From f6df09b9bf819f18d35310ca31ddb187a924b75f Mon Sep 17 00:00:00 2001 From: Faiaz Sanaulla <105630300+fsdvh@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:24:20 +0200 Subject: [PATCH] VTX-666: Sync from upstream (#51) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Preallocate for `FixedSizeList` in `concat` (#5862) * Add specific fixed size list concat test * Add fixed size list concat benchmark * Improve `FixedSizeList` concat performance for large list * `cargo fmt` * Increase size of `FixedSizeList` benchmark data * Get capacity recursively for `FixedSizeList` * Reuse `Capacities::List` to avoid breaking change * Use correct default capacities * Avoid a `Box::new()` when not needed * format --------- Co-authored-by: Will Jones * Add eq benchmark for StringArray/StringViewArray (#5924) * add neq/eq benchmark for String/ViewArray * move bench to comparsion kernel * clean unnecessary dep * make clippy happy * Add the ability for Maps to cast to another case where the field names are different (#5703) * Add the ability for Maps to cast to another case where the field names are different. Arrow Maps have field names for the elements of the fields, the field names are allowed to be any value and do not affect the type of the data. This allows a Map where the field names are key_value, key, value to be mapped to a entries, keys, values. This can be helpful in merging record batches that may have come from different sources. This also makes maps behave similar to lists which also have a field to distinguish their elements. * Apply suggestions from code review Co-authored-by: Andrew Lamb * Feedback from code review - simplify map casting logic to reuse the entries - Added unit tests for negative cases - Use MapBuilder to make the intended type clearer. 
* fix formatting * Lint and format * correctly set the null fields --------- Co-authored-by: Andrew Lamb * fix(ipc): set correct row count when reading struct arrays with zero fields (#5918) * Update zstd-sys requirement from >=2.0.0, <2.0.10 to >=2.0.0, <2.0.12 (#5913) Updates the requirements on [zstd-sys](https://github.com/gyscos/zstd-rs) to permit the latest version. - [Release notes](https://github.com/gyscos/zstd-rs/releases) - [Commits](https://github.com/gyscos/zstd-rs/compare/zstd-sys-2.0.7...zstd-sys-2.0.11) --- updated-dependencies: - dependency-name: zstd-sys dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add `MultipartUpload` blanket implementation for `Box` (#5919) * add impl for box * update * another update * small fix * Fix typo in benchmarks (#5935) * row format benches for bool & nullable int (#5943) * Implement arrow-row encoding/decoding for view types (#5922) * implement arrow-row encoding/decoding for view types * add doc comments, better error msg, more test coverage * ensure no performance regression * update perf * fix bug * make fmt happy * Update arrow-array/src/array/byte_view_array.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * update * update comments * move cmp around * move things around and remove inline hint * Update arrow-array/src/array/byte_view_array.rs Co-authored-by: Andrew Lamb * Update arrow-ord/src/cmp.rs Co-authored-by: Andrew Lamb * return error instead of panic * remove unnecessary func --------- Co-authored-by: Andrew Lamb Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * Better document support for nested comparison (#5942) * Update quick-xml requirement from 0.32.0 to 0.33.0 in /object_store (#5946) Updates the requirements on [quick-xml](https://github.com/tafia/quick-xml) to permit the latest version. 
- [Release notes](https://github.com/tafia/quick-xml/releases) - [Changelog](https://github.com/tafia/quick-xml/blob/master/Changelog.md) - [Commits](https://github.com/tafia/quick-xml/compare/v0.32.0...v0.33.0) --- updated-dependencies: - dependency-name: quick-xml dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Implement like/ilike etc for StringViewArray (#5931) * like for string view array * fix bug * update doc * update tests * test: Add unit test for extending slice of list array (#5948) * test: Add unit test for extending slice of list array * For review * Update quick-xml requirement from 0.33.0 to 0.34.0 in /object_store (#5954) Updates the requirements on [quick-xml](https://github.com/tafia/quick-xml) to permit the latest version. - [Release notes](https://github.com/tafia/quick-xml/releases) - [Changelog](https://github.com/tafia/quick-xml/blob/master/Changelog.md) - [Commits](https://github.com/tafia/quick-xml/compare/v0.33.0...v0.34.0) --- updated-dependencies: - dependency-name: quick-xml dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Minor: fixup contribution guide (#5952) * chore(5797): change default data_page_row_limit to 20k (#5957) * Improve error message for unsupported nested comparison (#5961) * Improve error message for unsupported nested comparison * Update arrow-ord/src/cmp.rs Co-authored-by: Jay Zhan --------- Co-authored-by: Jay Zhan * feat: add max_bytes and min_bytes on PageIndex (#5950) * Faster primitive arrays encoding into row format (#5858) * skip iterator removed from primitive encoding * special cases for not-null primitives encoding * faster iterators for nullable columns * Document process for PRs with breaking changes (#5953) * Document process for PRs with breaking changes * ticket reference * Update CONTRIBUTING.md Co-authored-by: Xuanwo --------- Co-authored-by: Xuanwo * `like` benchmark for StringView (#5936) * Expose `IntervalMonthDayNano` and `IntervalDayTime` and update docs (#5928) * Expose IntervalMonthDayNano and IntervalDayMonth and update docs * fix doc test * implement sort for view types (#5963) * Fix FFI array offset handling (#5964) * Add benchmark for reading binary/binary view from parquet (#5968) * implement sort for view types * add bench for binary/binary view * Add view buffer for parquet reader (#5970) * implement sort for view types * add bench for binary/binary view * add view buffer, prepare for byte_view_array reader * make clippy happy * reuse make_view_unchecked * Update parquet/src/arrow/buffer/view_buffer.rs Co-authored-by: Andrew Lamb * update * rename and inline --------- Co-authored-by: Andrew Lamb * Handle flight dictionary ID assignment automatically (#5971) * failing test * Handle dict ID assignment during flight encoding/decoding * remove println * One more println * Make auto-assign optional * Update docs * Remove breaking change * Update arrow-ipc/src/writer.rs Co-authored-by: Andrew Lamb * Remove breaking change to 
DictionaryTracker ctor --------- Co-authored-by: Andrew Lamb * Make ObjectStoreScheme public (#5912) * Make ObjectStoreScheme public * Fix clippy, add docs and examples --------- Co-authored-by: Andrew Lamb * Add operation in ArrowNativeTypeOp::neg_check error message (#5944) (#5980) * feat: support reading OPTIONAL column in parquet_derive (#5717) * support def_level=1 but non-null column in reader * update comment, adapt ut to the uuid change --------- Co-authored-by: Ye Yuan * Update quick-xml requirement from 0.34.0 to 0.35.0 in /object_store (#5983) Updates the requirements on [quick-xml](https://github.com/tafia/quick-xml) to permit the latest version. - [Release notes](https://github.com/tafia/quick-xml/releases) - [Changelog](https://github.com/tafia/quick-xml/blob/master/Changelog.md) - [Commits](https://github.com/tafia/quick-xml/compare/v0.34.0...v0.35.0) --- updated-dependencies: - dependency-name: quick-xml dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Reduce repo size by removing accumulative commits in CI job (#5982) * Use force_orphan in the CI job Use force_orphan in the CI job * Update .github/workflows/docs.yml --------- Co-authored-by: Andrew Lamb * Minor: fix clippy complaint in parquet_derive (#5984) * Add user defined metadata (#5915) * Add metadata attribute * Add user-defined metadata for AWS/GCP/ABS `with_attributes` * Reads and writes both implemented * Add tests for GetClient * Fix an indentation * Placate clippy * Use `strip_prefix` and mutable attributes * Use static Cow for attribute metadata * Add error for value decode failure * Remove unnecessary into * Provide Arrow Schema Hint to Parquet Reader - Alternative 2 (#5939) * Adds option for providing a schema to the Arrow Parquet Reader. * Adds more complete tests. Adds a more detailed error message for incompatible columns. Adds nested fields to test_with_schema. 
Adds test for incompatible nested field. Updates documentation. * Add an example using showing how to use the with_schema option. --------- Co-authored-by: Eric Fredine * WriteMultipart Abort on MultipartUpload::complete Error (#5974) * update * another one * more update * another update * debug * debug * some updates * debug * debug * cleanup * cleanup * simplify * address some comments * cleanup on failure * restore abort method * docs * Implement directly build byte view array on top of parquet buffer (#5972) * implement sort for view types * add bench for binary/binary view * add view buffer, prepare for byte_view_array reader * make clippy happy * add byte view array reader * fix doc link * reuse make_view_unchecked * Update parquet/src/arrow/buffer/view_buffer.rs Co-authored-by: Andrew Lamb * update * rename and inline * Update parquet/src/arrow/array_reader/byte_view_array.rs Co-authored-by: Andrew Lamb * use unused * Revert "use unused" This reverts commit 5e6887095251066cfa9998cb16a9eea788f9e175. --------- Co-authored-by: Andrew Lamb * fix: error in case of invalid interval expression (#5987) This PR addresses an error that occurs when interval expressions contains invalid amount of components. The error type was previously unclear and confusing: `NotYetImplemented`. That doesn't seem correct, because such values are not going to be supported. 
Let's take a look at such example: ```sql INTERVAL '1 MONTH DAY' ``` This is an obvious typo/mistake which leads to such error, but in fact it's just invalid value (missing number before `DAY`) * Add ParquetMetadata::memory_size size estimation (#5965) * Add ParquetMetadata::memory_size size estimation * Require HeapSize for ParquetValueType * feat(5851): ArrowWriter memory usage (#5967) * refactor(5851): delineate the different memory estimates APIs for the ArrowWriter and column writers * feat(5851): add memory size estimates to the ColumnValueEncoder implementations and the DictEncoder * test(5851): add memory_size() to in-progress test * chore(5851): update docs to make it more explicit what is the difference btwn memory_size vs get_estimated_total_byte * feat(5851): clarify the ColumnValueEncoder::estimated_memory_size interface, and update impls to account for bloom filter size * feat(5851): account for stats array size in the ByteArrayEncoder * Refine documentation * More accurate memory estimation * Improve tests * Update accounting for non dict encoded data * Include more memory size calculations * clean up async writer * clippy * tweak --------- Co-authored-by: Andrew Lamb * Prepare arrow `52.1.0` (#5992) * Update version to 52.1.0 * Prepare arrow `52.1.0` * Update CHANGELOG * Implement dictionary support for reading ByteView from parquet (#5973) * implement dictionary encoding support * update comments * implement `DataType::try_form(&str)` (#5994) * implement "DataType::try_form(&str)" * add missing file * add FromStr as well as TryFrom<&str> * fmt * Add additional documentation and examples to DataType (#5997) * Automatically cleanup empty dirs in LocalFileSystem (#5978) * automatically cleanup empty dirs * automatic cleanup toggle * configurable cleanup * test for automatic dir deletion * clippy * more comments * Add FlightSqlServiceClient::new_from_inner (#6003) * fix doc ci in latest rust nightly version (#6012) * allow rustdoc::unportable_markdown 
in arrow-flight. * fix doc in sql_info.rs. * reduce scope of lint disable --------- Co-authored-by: Andrew Lamb * Deduplicate strings/binarys when building view types (#6005) * implement string view deduplication in builder * make clippy happy * Apply suggestions from code review Co-authored-by: Andrew Lamb * better coding style --------- Co-authored-by: Andrew Lamb * Fast utf8 validation when loading string view from parquet (#6009) * fast utf8 validation * better documentation * Update parquet/src/arrow/array_reader/byte_view_array.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb * Rename `Schema::all_fields` to `flattened_fields` (#6001) * Rename Schema::all_fields to flattened_fields * Add doc example for Schema::flattened_fields * fmt doc example * Update arrow-schema/src/schema.rs --------- Co-authored-by: Andrew Lamb * Complete `StringViewArray` and `BinaryViewArray` parquet decoder: implement delta byte array and delta length byte array encoding (#6004) * implement all encodings * address comments * fix bug * Update parquet/src/arrow/array_reader/byte_view_array.rs Co-authored-by: Andrew Lamb * fix test * update comments * update test * Only copy strings one --------- Co-authored-by: Andrew Lamb * Update zstd-sys requirement from >=2.0.0, <2.0.12 to >=2.0.0, <2.0.13 (#6019) Updates the requirements on [zstd-sys](https://github.com/gyscos/zstd-rs) to permit the latest version. - [Release notes](https://github.com/gyscos/zstd-rs/releases) - [Commits](https://github.com/gyscos/zstd-rs/compare/zstd-sys-2.0.7...zstd-sys-2.0.12) --- updated-dependencies: - dependency-name: zstd-sys dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Update clap test (#6028) * Unsafe improvements: core `parquet` crate. (#6024) * Unsafe improvements: core `parquet` crate. * Make FromBytes an unsafe trait. 
* Improve performance reading `ByteViewArray` from parquet by removing an implicit copy (#6031) * update byte view array to not implicit copy * Add small comments * Update quick-xml requirement from 0.35.0 to 0.36.0 in /object_store (#6032) Updates the requirements on [quick-xml](https://github.com/tafia/quick-xml) to permit the latest version. - [Release notes](https://github.com/tafia/quick-xml/releases) - [Changelog](https://github.com/tafia/quick-xml/blob/master/Changelog.md) - [Commits](https://github.com/tafia/quick-xml/compare/v0.35.0...v0.36.0) --- updated-dependencies: - dependency-name: quick-xml dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Fix `hashbrown` version in `arrow-array`, remove from `arrow-row` (#6035) * Additional tests for parquet reader utf8 validation (#6023) * Clean up unused code for view types in offset buffer (#6040) * clean up unused view types in offset buffer * make tests happy * Move avoid using copy-based buffer creation (#6039) * Fix 5592: Colon (:) in in object_store::path::{Path} is not handled on Windows (#5830) * Fix issue #5800: Handle missing files in list_with_delimiter * draft * cargo fmt * Handle leading colon * Add windows CI * Fix CI job * Only run local tests and set target family for failing tests * Run all tests without my changes and removed target os * Restore changes again * Add back newline (removed by mistake) * Fix test after merge with master * Minor API adjustments for StringViewBuilder (#6047) * minor update * add memory accounting * Update arrow-buffer/src/builder/null.rs Co-authored-by: Andrew Lamb * Update arrow-array/src/builder/generic_bytes_view_builder.rs Co-authored-by: Andrew Lamb * update comments --------- Co-authored-by: Andrew Lamb * Fix typo in GenericByteViewArray documentation (#6054) * Directly decode String/BinaryView types from arrow-row format (#6044) * add string view bench * 
check in new impl * add utf8 * quick utf8 validation * Update arrow-row/src/variable.rs Co-authored-by: Andrew Lamb * address comments * update * Revert "address comments" This reverts commit e2656c94dd5ff4fb2f486278feb346d44a7f5436. * addr comments --------- Co-authored-by: Andrew Lamb * Add begin/end_transaction methods in FlightSqlServiceClient (#6026) * Add begin/end_transaction methods in FlightSqlServiceClient * Add test * Remove unused imports * Implement min max support for string/binary view types (#6053) * add * implement min max support for string/binary view * update tests * Add parquet `StatisticsConverter` for arrow reader (#6046) * Adds arrow statistics converter for parquet statistics. * Adds integration tests for arrow statistics converter. * Fix linting, remove todo, re-use arrow code. * Remove commented out debug::log statements. * Move parquet_column to lib.rs * doc tweaks * Add benchmark * Add parquet_column_index and arrow_field accessors + test * Copy edit docs obsessively * clippy --------- Co-authored-by: Eric Fredine Co-authored-by: Andrew Lamb * StringView support in arrow-csv (#6062) * StringView support in arrow-csv * review and micro-benches * Minor: clarify the relationship between `file::metadata` and `format` (#6049) * Do not write `ColumnIndex` for null columns when not writing page statistics (#6011) * disable column_index_builder if no page stats are collected * add test * no need to clone descr --------- Co-authored-by: Andrew Lamb * Reorganize arrow-flight test code (#6065) * Reorganize test code * asf header * reuse TestFixture * .await * Create flight_sql_client.rs * remove code * remove unused import * Fix clippy lints * Sanitize error message for sensitive requests (#6074) * Sanitize error message for sensitive requests * Clippy * use GCE metadata server env var overrides (#6015) * use GCE metadata env var overrides * update docs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb * Correct timeout in
comment from 5s to 30s (#6073) * Prepare for object_store `0.10.2` release (#6079) * Prepare for `object_store 0.10.2` release * Add CHANGELOG * Historical changelog * Minor: Improve parquet PageIndex documentation (#6042) * Minor: Improve parquet PageIndex documentation * More improvements * Add reasons for data page being without null * Apply suggestions from code review Co-authored-by: Val Lorentz * Update parquet/src/file/page_index/index.rs --------- Co-authored-by: Val Lorentz * Enable casting from Utf8View (#6077) * Enable casting from Utf8View -> string or temporal types * save * implement casting utf8view -> timestamp/interval types, with tests * fix clippy * fmt --------- Co-authored-by: Andrew Lamb * Add PartialEq to ParquetMetaData and FileMetadata (#6082) Prep for #6000 * fix panic in `ParquetMetadata::memory_size`: check has_min_max_set before invoking min()/max() (#6092) * fix: check has_min_max_set before invoking min()/max() * chore: add unit test for statistics heap size * Fixup test --------- Co-authored-by: Andrew Lamb * Optimize `max_boolean` by operating on u64 chunks (#6098) * Optimize `max_boolean` Operate on bit chunks instead of individual booleans, which can lead to massive speedups while not regressing the short-circuiting behavior of the existing implementation. `cargo bench --bench aggregate_kernels -- "bool/max"` shows throughput improvements between 50% and 23390% on my machine.
* add tests exercising u64 chunk code * add benchmark to track performance (#6101) * Make bool_or an alias for max_boolean (#6100) Improves `cargo bench --bench aggregate_kernels -- "bool/or"` throughput by 68%-22366% on my machine * Faster `GenericByteView` construction (#6102) * add benchmark to track performance * fast byte view construction * make doc happy * fix clippy * update comments * Implement specialized min/max for `GenericBinaryView` (`StringView` and `BinaryView`) (#6089) * implement better min/max for string view * Apply suggestions from code review Co-authored-by: Andrew Lamb * address review comments --------- Co-authored-by: Andrew Lamb * Prepare `52.2.0` release (#6110) * Update version to 52.2.0 * Update CHANGELOG for 52.2.0 * touchups * manual tweaks * manual tweaks * added a flush method to IPC writers (#6108) While the writers expose `get_ref` and `get_mut` to access the underlying `io::Write` writer, there is an internal layer of a `BufWriter` that is not accessible. Because of that, there is no way to ensure that all messages written thus far to the `StreamWriter` or `FileWriter` have actually been passed to the underlying writer. Here we expose a `flush` method that flushes the internal buffer and the underlying writer. See #6099 for the discussion. 
* Fix Clippy for the Rust 1.80 release (#6116) * Fix clippy lints in arrow-data * Fix clippy errors in arrow-array * fix clippy in concat * Clippy in arrow-string * remove unnecessary feature in arrow-array * fix clippy in arrow-cast * Fix clippy in parquet crate * Fix clippy in arrow-flight * Fix clippy in object_store crate (#6120) * Fix clippy in object_store crate * clippy ignore * Merge `53.0.0-dev` dev branch to main (#6126) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` (#6041) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` Signed-off-by: Bugen Zhao * fix example tests Signed-off-by: Bugen Zhao --------- Signed-off-by: Bugen Zhao * Remove `impl<T: AsRef<[u8]>> From<T> for Buffer` that easily accidentally copies data (#6043) * deprecate auto copy, ask explicit reference * update comments * make cargo doc happy * Make display of interval types more pretty (#6006) * improve display for interval. * update test in pretty, and fix display problem. * tmp * fix tests in arrow-cast. * fix tests in pretty. * fix style. * Update snafu (#5930) * Update Parquet thrift generated structures (#6045) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * Revert "Revert "Write Bloom filters between row groups instead of the end (#…" (#5933) This reverts commit 22e0b4432c9838f2536284015271d3de9a165135. * Revert "Update snafu (#5930)" (#6069) This reverts commit 756b1fb26d1702f36f446faf9bb40a4869c3e840. * Update pyo3 requirement from 0.21.1 to 0.22.1 (fixed) (#6075) * Update pyo3 requirement from 0.21.1 to 0.22.1 Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version.
- [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/main/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.1) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... Signed-off-by: dependabot[bot] * refactor: remove deprecated `FromPyArrow::from_pyarrow` "GIL Refs" are being phased out. * chore: update `pyo3` in integration tests --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * remove repeated codes to make the codes more concise. (#6080) * Add `unencoded_byte_array_data_bytes` to `ParquetMetaData` (#6068) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * add support for unencoded_byte_array_data_bytes * add comments * change sig of ColumnMetrics::update_variable_length_bytes() * rename ParquetOffsetIndex to OffsetSizeIndex * rename some functions * suggestion from review Co-authored-by: Andrew Lamb * add Default trait to ColumnMetrics as suggested in review * rename OffsetSizeIndex to OffsetIndexMetaData --------- Co-authored-by: Andrew Lamb * Update pyo3 requirement from 0.21.1 to 0.22.2 (#6085) Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. - [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/v0.22.2/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.2) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Deprecate read_page_locations() and simplify offset index in `ParquetMetaData` (#6095) * deprecate read_page_locations * add to_thrift() to OffsetIndexMetaData * Update parquet/src/column/writer/mod.rs Co-authored-by: Ed Seidl --------- Signed-off-by: Bugen Zhao Signed-off-by: dependabot[bot] Co-authored-by: Bugen Zhao Co-authored-by: Xiangpeng Hao Co-authored-by: kamille Co-authored-by: Jesse Co-authored-by: Ed Seidl Co-authored-by: Marco Neumann Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add support for level histograms added in PARQUET-2261 to `ParquetMetaData` (#6105) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` (#6041) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` Signed-off-by: Bugen Zhao * fix example tests Signed-off-by: Bugen Zhao --------- Signed-off-by: Bugen Zhao * Remove `impl> From for Buffer` that easily accidentally copies data (#6043) * deprecate auto copy, ask explicit reference * update comments * make cargo doc happy * Make display of interval types more pretty (#6006) * improve dispaly for interval. * update test in pretty, and fix display problem. * tmp * fix tests in arrow-cast. * fix tests in pretty. * fix style. * Update snafu (#5930) * Update Parquet thrift generated structures (#6045) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * Revert "Revert "Write Bloom filters between row groups instead of the end (#…" (#5933) This reverts commit 22e0b4432c9838f2536284015271d3de9a165135. * Revert "Update snafu (#5930)" (#6069) This reverts commit 756b1fb26d1702f36f446faf9bb40a4869c3e840. 
* Update pyo3 requirement from 0.21.1 to 0.22.1 (fixed) (#6075) * Update pyo3 requirement from 0.21.1 to 0.22.1 Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. - [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/main/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.1) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... Signed-off-by: dependabot[bot] * refactor: remove deprecated `FromPyArrow::from_pyarrow` "GIL Refs" are being phased out. * chore: update `pyo3` in integration tests --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * remove repeated codes to make the codes more concise. (#6080) * Add `unencoded_byte_array_data_bytes` to `ParquetMetaData` (#6068) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * add support for unencoded_byte_array_data_bytes * add comments * change sig of ColumnMetrics::update_variable_length_bytes() * rename ParquetOffsetIndex to OffsetSizeIndex * rename some functions * suggestion from review Co-authored-by: Andrew Lamb * add Default trait to ColumnMetrics as suggested in review * rename OffsetSizeIndex to OffsetIndexMetaData --------- Co-authored-by: Andrew Lamb * deprecate read_page_locations * add level histograms to metadata * add to_thrift() to OffsetIndexMetaData * Update pyo3 requirement from 0.21.1 to 0.22.2 (#6085) Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. 
- [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/v0.22.2/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.2) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Deprecate read_page_locations() and simplify offset index in `ParquetMetaData` (#6095) * deprecate read_page_locations * add to_thrift() to OffsetIndexMetaData * move valid test into ColumnIndexBuilder::append_histograms * move update_histogram() inside ColumnMetrics * Update parquet/src/column/writer/mod.rs Co-authored-by: Ed Seidl * Implement LevelHistograms as a struct * formatting * fix error in docs --------- Signed-off-by: Bugen Zhao Signed-off-by: dependabot[bot] Co-authored-by: Bugen Zhao Co-authored-by: Xiangpeng Hao Co-authored-by: kamille Co-authored-by: Jesse Co-authored-by: Andrew Lamb Co-authored-by: Marco Neumann Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add ArrowError::ArithmeticError (#6130) * Implement date_part for intervals (#6071) Signed-off-by: Nick Cameron Co-authored-by: Andrew Lamb * Remove `SchemaBuilder` dependency from `StructArray` constructors (#6139) * Remove automatic buffering in `ipc::reader::FileReader` for consistent buffering (#6132) * change ipc::reader and writer APIs for consistent buffering Current writer API automatically wraps the supplied std::io::Writer impl into a BufWriter. It is cleaner and more idiomatic to have the default be using the supplied impl directly, as the user might already have a BufWriter or an impl that doesn't actually benefit from buffering at all. StreamReader does a similar thing, but it also exposes a `try_new_unbuffered` that bypasses the internal wrap.
Here we propose a consistent and non-buffered by default API: - `try_new` does not wrap the passed reader/writer, - `try_new_buffered` is a convenience function that does wrap the reader/writer into a BufReader/BufWriter, - all four publicly exposed IPC reader/writers follow the above consistently, i.e. `StreamReader`, `FileReader`, `StreamWriter`, `FileWriter`. Those are breaking changes. An additional tweak: removed the generic type bounds from struct definitions on the four types, as that is the idiomatic Rust approach (see e.g. stdlib's HashMap that has no bounds on the struct definition, only the impl requires Hash + Eq). See #6099 for the discussion. * improvements to docs in `arrow::ipc::reader` and `writer` Applied a few suggestions, made `Error` sections more consistent. * Use `LevelHistogram` in `PageIndex` (#6135) * use LevelHistogram in PageIndex and ColumnIndexBuilder * revert changes to OffsetIndexBuilder * Fix comparison kernel benchmarks (#6147) * fix comparison kernel benchmarks * add comment as suggested by @alamb * Implement exponential block size growing strategy for `StringViewBuilder` (#6136) * new block size growing strategy * Update arrow-array/src/builder/generic_bytes_view_builder.rs Co-authored-by: Andrew Lamb * update function name, deprecate old function * update comments --------- Co-authored-by: Andrew Lamb * improve LIKE regex (#6145) * Improve `LIKE` performance for "contains" style queries (#6128) * improve "contains" performance * add tests * cargo fmt :disappointed: --------- Co-authored-by: Andrew Lamb * improvements to `(i)starts_with` and `(i)ends_with` performance (#6118) * improvements to "starts_with" and "ends_with" * add tests and refactor slightly * add comments * Add `BooleanArray::new_from_packed` and `BooleanArray::new_from_u8` (#6127) * Support construct BooleanArray from &[u8] * fix doc * add new_from_packed and new_from_u8; delete impl From<&[u8]> for BooleanArray and BooleanBuffer * Update object store MSRV to 
`1.64` (#6123) * Update MSRV to 1.64 * Revert "clippy ignore" This reverts commit 7a4b760bfb2a63c7778b20a4710c2828224f9565. * Upgrade protobuf definitions to flightsql 17.0 (#6133) (#6169) * Update FlightSql.proto to version 17.0 Adds new message CommandStatementIngest and removes `experimental` from other messages. * Regenerate flight sql protocol This upgrades the file to version 17.0 of the protobuf definition. Co-authored-by: Douglas Anderson * Add additional documentation and examples to ArrayAccessor (#6141) * Minor: Update release schedule in README (#6125) * Minor: Update release schedule in README * prettier * fixp * Optimize `take` kernel for `BinaryViewArray` and `StringViewArray` (#6168) * improve speed of view take kernel * ArrayData -> new_unchecked * Update arrow-select/src/take.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb * Minor: improve comments in temporal.rs tests (#6140) * Support `StringView` and `BinaryView` in CDataInterface (#6171) * fix round-trip for view schema in CFFI * add * Make object_store errors non-exhaustive (#6165) * Update snafu (#5930) (#6070) Co-authored-by: Jesse * Update sysinfo requirement from 0.30.12 to 0.31.2 (#6182) * Update sysinfo requirement from 0.30.12 to 0.31.2 Updates the requirements on [sysinfo](https://github.com/GuillaumeGomez/sysinfo) to permit the latest version. - [Changelog](https://github.com/GuillaumeGomez/sysinfo/blob/master/CHANGELOG.md) - [Commits](https://github.com/GuillaumeGomez/sysinfo/compare/v0.30.13...v0.31.2) --- updated-dependencies: - dependency-name: sysinfo dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * Update example for new sysinfo API --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb * No longer write Parquet column metadata after column chunks *and* in the footer (#6117) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` (#6041) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` Signed-off-by: Bugen Zhao * fix example tests Signed-off-by: Bugen Zhao --------- Signed-off-by: Bugen Zhao * Remove `impl<T: AsRef<[u8]>> From<T> for Buffer` that easily accidentally copies data (#6043) * deprecate auto copy, ask explicit reference * update comments * make cargo doc happy * Make display of interval types more pretty (#6006) * improve display for interval. * update test in pretty, and fix display problem. * tmp * fix tests in arrow-cast. * fix tests in pretty. * fix style. * Update snafu (#5930) * Update Parquet thrift generated structures (#6045) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * Revert "Revert "Write Bloom filters between row groups instead of the end (#…" (#5933) This reverts commit 22e0b4432c9838f2536284015271d3de9a165135. * Revert "Update snafu (#5930)" (#6069) This reverts commit 756b1fb26d1702f36f446faf9bb40a4869c3e840. * Update pyo3 requirement from 0.21.1 to 0.22.1 (fixed) (#6075) * Update pyo3 requirement from 0.21.1 to 0.22.1 Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. - [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/main/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.1) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ...
Signed-off-by: dependabot[bot] * refactor: remove deprecated `FromPyArrow::from_pyarrow` "GIL Refs" are being phased out. * chore: update `pyo3` in integration tests --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * remove repeated codes to make the codes more concise. (#6080) * Add `unencoded_byte_array_data_bytes` to `ParquetMetaData` (#6068) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * add support for unencoded_byte_array_data_bytes * add comments * change sig of ColumnMetrics::update_variable_length_bytes() * rename ParquetOffsetIndex to OffsetSizeIndex * rename some functions * suggestion from review Co-authored-by: Andrew Lamb * add Default trait to ColumnMetrics as suggested in review * rename OffsetSizeIndex to OffsetIndexMetaData --------- Co-authored-by: Andrew Lamb * Update pyo3 requirement from 0.21.1 to 0.22.2 (#6085) Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. - [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/v0.22.2/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.2) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Deprecate read_page_locations() and simplify offset index in `ParquetMetaData` (#6095) * deprecate read_page_locations * add to_thrift() to OffsetIndexMetaData * no longer write inline column metadata * Update parquet/src/column/writer/mod.rs Co-authored-by: Ed Seidl * suggestion from review Co-authored-by: Andrew Lamb * add some more documentation * remove write_metadata from PageWriter --------- Signed-off-by: Bugen Zhao Signed-off-by: dependabot[bot] Co-authored-by: Bugen Zhao Co-authored-by: Xiangpeng Hao Co-authored-by: kamille Co-authored-by: Jesse Co-authored-by: Andrew Lamb Co-authored-by: Marco Neumann Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * add filter benchmark for fsb (#6186) * Add support for `StringView` and `BinaryView` statistics in `StatisticsConverter` (#6181) * Add StringView and BinaryView support for the macro `get_statistics` * Add StringView and BinaryView support for the macro `get_data_page_statistics` * add tests to cover the support for StringView and BinaryView in the macro get_data_page_statistics * found potential bugs and ignore the tests * fake alarm! 
no bugs, fix the code by initiating all batches to have 5 rows * make the get_stat StringView and BinaryView tests cover bytes greater than 12 * Benchmarks for `bool_and` (#6189) * Fix typo in documentation of Float64Array (#6188) * feat(parquet): Implement AsyncFileWriter for `object_store::buffered::BufWriter` (#6013) * feat(parquet): Implement AsyncFileWriter for object_store::BufWriter Signed-off-by: Xuanwo * Fix build Signed-off-by: Xuanwo * Bump object_store Signed-off-by: Xuanwo * Apply suggestions from code review Co-authored-by: Andrew Lamb * Address comments Signed-off-by: Xuanwo * Add comments Signed-off-by: Xuanwo * Make it better to read Signed-off-by: Xuanwo * Fix docs Signed-off-by: Xuanwo --------- Signed-off-by: Xuanwo Co-authored-by: Andrew Lamb * Support Parquet `BYTE_STREAM_SPLIT` for INT32, INT64, and FIXED_LEN_BYTE_ARRAY primitive types (#6159) * add todos to help trace flow * add support for byte_stream_split encoding for INT32 and INT64 data * byte_stream_split encoding for fixed_len_byte_array * revert changes to Decoder and add VariableWidthByteStreamSplitDecoder * remove set_type_width as it is now unused * begin implementing roundtrip test * move test * clean up some documentation * add test of byte_stream_split with flba * add check for and test of mismatched sizes * remove type_length from Encoder and add VariableWidthByteStreamSplitEncoder * fix clippy error * change type of argument to new() * formatting * add another test * add variable to split/join streams for FLBA * more informative error message * avoid buffer copies in decoder per suggestion from review * add roundtrip test * optimized version...but clippy complains * clippy was right...replace loop with copy_from_slice * fix test * optimize split_streams_variable for long type widths * Reduce bounds check in `RowIter`, add `unsafe Rows::row_unchecked` (#6142) * update * update comment * update row-iter bench * make clippy happy * Update zstd-sys requirement from >=2.0.0,
<2.0.13 to >=2.0.0, <2.0.14 (#6196) Updates the requirements on [zstd-sys](https://github.com/gyscos/zstd-rs) to permit the latest version. - [Release notes](https://github.com/gyscos/zstd-rs/releases) - [Commits](https://github.com/gyscos/zstd-rs/commits) --- updated-dependencies: - dependency-name: zstd-sys dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add `ThriftMetadataWriter` for writing Parquet metadata (#6197) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` (#6041) * bump `tonic` to 0.12 and `prost` to 0.13 for `arrow-flight` Signed-off-by: Bugen Zhao * fix example tests Signed-off-by: Bugen Zhao --------- Signed-off-by: Bugen Zhao * Remove `impl> From for Buffer` that easily accidentally copies data (#6043) * deprecate auto copy, ask explicit reference * update comments * make cargo doc happy * Make display of interval types more pretty (#6006) * improve dispaly for interval. * update test in pretty, and fix display problem. * tmp * fix tests in arrow-cast. * fix tests in pretty. * fix style. * Update snafu (#5930) * Update Parquet thrift generated structures (#6045) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * Revert "Revert "Write Bloom filters between row groups instead of the end (#…" (#5933) This reverts commit 22e0b4432c9838f2536284015271d3de9a165135. * Revert "Update snafu (#5930)" (#6069) This reverts commit 756b1fb26d1702f36f446faf9bb40a4869c3e840. * Update pyo3 requirement from 0.21.1 to 0.22.1 (fixed) (#6075) * Update pyo3 requirement from 0.21.1 to 0.22.1 Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. 
- [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/main/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.1) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... Signed-off-by: dependabot[bot] * refactor: remove deprecated `FromPyArrow::from_pyarrow` "GIL Refs" are being phased out. * chore: update `pyo3` in integration tests --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * remove repeated codes to make the codes more concise. (#6080) * Add `unencoded_byte_array_data_bytes` to `ParquetMetaData` (#6068) * update to latest thrift (as of 11 Jul 2024) from parquet-format * pass None for optional size statistics * escape HTML tags * don't need to escape brackets in arrays * add support for unencoded_byte_array_data_bytes * add comments * change sig of ColumnMetrics::update_variable_length_bytes() * rename ParquetOffsetIndex to OffsetSizeIndex * rename some functions * suggestion from review Co-authored-by: Andrew Lamb * add Default trait to ColumnMetrics as suggested in review * rename OffsetSizeIndex to OffsetIndexMetaData --------- Co-authored-by: Andrew Lamb * Update pyo3 requirement from 0.21.1 to 0.22.2 (#6085) Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. - [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/v0.22.2/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.21.1...v0.22.2) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Deprecate read_page_locations() and simplify offset index in `ParquetMetaData` (#6095) * deprecate read_page_locations * add to_thrift() to OffsetIndexMetaData * Update parquet/src/column/writer/mod.rs Co-authored-by: Ed Seidl * Upgrade protobuf definitions to flightsql 17.0 (#6133) * Update FlightSql.proto to version 17.0 Adds new message CommandStatementIngest and removes `experimental` from other messages. * Regenerate flight sql protocol This upgrades the file to version 17.0 of the protobuf definition. * Add `ParquetMetadataWriter` allow ad-hoc encoding of `ParquetMetadata` * fix loading in test by etseidl Co-authored-by: Ed Seidl * add rough equivalence test * one more check * make clippy happy * separate tests that require arrow into a separate module * add histograms to to_thrift() --------- Signed-off-by: Bugen Zhao Signed-off-by: dependabot[bot] Co-authored-by: Bugen Zhao Co-authored-by: Xiangpeng Hao Co-authored-by: kamille Co-authored-by: Jesse Co-authored-by: Ed Seidl Co-authored-by: Andrew Lamb Co-authored-by: Marco Neumann Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Douglas Anderson Co-authored-by: Ed Seidl * Add (more) Parquet Metadata Documentation (#6184) * Minor: Add (more) Parquet Metadata Documenation * fix clippy * fix parquet type is_optional comment (#6192) Co-authored-by: jp0317 * Remove duplicated statistics tests in parquet (#6190) * move all tests to parquet/tests/arrow_reader/statistics.rs, and leave a comment in original file * remove duplicated tests and adjust the empty tests * data file tests brought folders changes * fix lint * add comments Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb * fix: interleave docs suggests itself, not take (#6210) * fix: Correctly handle take on dense union of a single selected type (#6209) * fix: use filter 
instead of filter_primitive * fix: remove pub(crate) from filter_primitive * fix: run cargo fmt * fix: clippy * Make it clear that StatisticsConverter can not panic (#6187) * Optimize `min_boolean` and `bool_and` (#6144) * Optimize `min_boolean` and `bool_and` Closes https://github.com/apache/arrow-rs/issues/6103 * use any * Add benchmarks for `BYTE_STREAM_SPLIT` encoded Parquet `FIXED_LEN_BYTE_ARRAY` data (#6204) * save type_width for fixed_len_byte_array * add decimal128 and float16 byte_stream_split benches * add f16 * add decimal128 flba(16) bench * fix(arrow): restrict the range of temporal values produced via `data_gen` (#6205) * fix: random timestamp array * fix: restrict range of randomly generated temporal values * fix: exclusive range used * Support casting between BinaryView <--> Utf8 and LargeUtf8 (#6180) * support cast between binaryview and string * update impl. and add benchmark * Add unit tests for views * Apply comments * feat(object_store): add `PermissionDenied` variant to top-level error (#6194) * feat(object_store): add `PermissionDenied` variant to top-level error * Update object_store/src/lib.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * refactor: add additional error variant for unauthenticated ops * fix: include path in unauthenticated error --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * update BYTE_STREAM_SPLIT documentation (#6212) * Add time dictionary coercions (#6208) * Add time dictionary coercions * format * Pass through primitive values * use spaces not tabs everywhere (#6217) * Implement specialized filter kernel for `FixedSizeByteArray` (#6178) * refactor filter for FixedSizeByteArray * fix expect * remove benchmark code * fix * remove from_trusted_len_iter_slice_u8 * fmt --------- Co-authored-by: Andrew Lamb * fix: lexsort_to_indices should not fallback to non-lexical sort if the datatype is not supported (#6225) * fix: lexsort_to_indices should not
fallback to non-lexical sort if the datatype is not supported * fix clippy * Check error message * Prepare for object_store `0.11.0` release (#6227) * Update version to 0.11.0 * Changelog for 0.11.0 * Remove irrelevant content from changelog * Improve interval parsing (#6211) * improve interval parsing * rename * cleanup * fix formatting * make IntervalParseConfig public * add debug to IntervalParseConfig * fmt * Add LICENSE and NOTICE files to object_store (#6234) * Add LICENSE and NOTICE files to object_store * Update object_store/NOTICE.txt Co-authored-by: Xuanwo * Update object_store/LICENSE.txt --------- Co-authored-by: Xuanwo * Update changelog for object_store 0.11.0 release (#6238) * Minor: Remove non standard footer from LICENSE.txt (#6237) * Minor: Improve Type documentation (#6224) * Minor: Improve XXXType documentation * Update arrow-array/src/types.rs Co-authored-by: Marco Neumann --------- Co-authored-by: Marco Neumann * Add "take" workflow for self-assigning tickets, add "how to find issues" to contributor guide (#6059) * Add "take" workflow for contributors to assign themselves to tickets * Copy datafusion Finding and Creating Issues to work on * Move `ParquetMetadataWriter` to its own module, update documentation (#6202) * Move `ThriftMetadataWriter` and `ParquetMetadataWriter` to a new module * Improve documentation, make pub(crate) * Apply suggestions from code review Co-authored-by: Ed Seidl * Add comment side effect of writing column and offset indexes * Document how to write bloom filters * Update parquet/src/file/metadata/writer.rs Co-authored-by: Ed Seidl --------- Co-authored-by: Ed Seidl * Modest improvement to FixedLenByteArray BYTE_STREAM_SPLIT arrow decoder (#6222) * replace reserve/push with resize/direct access * remove import * make a bit faster * Improve performance of `FixedLengthBinary` decoding (#6220) * add set_from_bytes to ParquetValueType * change naming of FLBA types so critcmp will work * minor enhance doc for ParquetField 
(#6239) * Remove unnecessary null buffer construction when converting arrays to a different type (#6244) * create primitive array from iter and nulls * clippy * speed up some more decimals * add optimizations for byte_stream_split * decimal256 * Revert "add optimizations for byte_stream_split" This reverts commit 5d4ae0dc09f95ee9079b46b117fb554f63157564. * add comments * Add examples to `StringViewBuilder` and `BinaryViewBuilder` (#6240) * Add examples to `StringViewBuilder` and `BinaryViewBuilder` * add doc link * Implement PartialEq for GenericBinaryArray (#6241) * parquet Statistics - deprecate `has_*` APIs and add `_opt` functions that return `Option` (#6216) * update public api Statistics::min to return an option. I first re-named the existing method to `min_unchecked` and made it internal to the crate. I then added a `pub min(&self) -> Option<&T>` method. I figure we can first change the public API before deciding what to do about internal usage. Ref: https://github.com/apache/arrow-rs/issues/6093 * update public api Statistics::max to return an option. I first re-named the existing method to `max_unchecked` and made it internal to the crate. I then added a `pub max(&self) -> Option<&T>` method. I figure we can first change the public API before deciding what to do about internal usage. Ref: https://github.com/apache/arrow-rs/issues/6093 * cargo fmt * remove Statistics::has_min_max_set from the public api Ref: https://github.com/apache/arrow-rs/issues/6093 * update impl HeapSize for ValueStatistics to use new min and max api * migrate all tests to new Statistics min and max api * make Statistics::null_count return Option<u64> This removes the ambiguity between all values being non-null and the null count stat simply being missing Ref: https://github.com/apache/arrow-rs/issues/6215 * update expected metadata memory size tests Changing null_count from u64 to Option<u64> increases the memory size and layout of the metadata.
I included these tests as a separate commit to call extra attention to them. * add TODO question on is_min_max_backwards_compatible * Apply suggestions from code review Co-authored-by: Andrew Lamb * update ValueStatistics::max docs * rename new optional ValueStatistics::max to max_opt Per PR review, we will deprecate the old API instead of introducing a breaking change. Ref: https://github.com/apache/arrow-rs/pull/6216#pullrequestreview-2236537291 * rename new optional ValueStatistics::min to min_opt * add Statistics::{min,max}_bytes_opt This adds the API and migrates all of the test usage. The old APIs will be deprecated next. * update make_stats_iterator macro to use *_opt methods * deprecate non *_opt Statistics and ValueStatistics methods * remove stale TODO comments * remove has_min_max_set check from make_decimal_stats_iterator The check is unnecessary now that the stats funcs return Option when unset. * deprecate has_min_max_set An internal version was also created because it is used so extensively in testing. * switch to null_count_opt and reintroduce deprecated null_count and has_nulls * remove redundant test assertions of stats._internal_has_min_max_set This removes the assertion from any test that subsequently unwraps both min_opt and max_opt. * replace negated test assertions of stats._internal_has_min_max_set with assertions on min_opt and max_opt This removes all use of Statistics::_internal_has_min_max_set from the code base, and so it is also removed. * Revert changes to parquet writing, update comments --------- Co-authored-by: Andrew Lamb * Minor: Update DataType::Date64 docs (#6223) * feat(object_store): add support for server-side encryption with customer-provided keys (SSE-C) (#6230) * Add support for server-side encryption with customer-provided keys (SSE-C). * Add SSE-C test using MinIO.
* Visibility change * add nocapture to verify the test indeed runs * cargo fmt * Update object_store/src/aws/mod.rs use environment variables Co-authored-by: Will Jones * Update object_store/CONTRIBUTING.md use environment variables Co-authored-by: Will Jones * Fix api --------- Co-authored-by: Will Jones * Expose bulk ingest in flight sql client and server (#6201) * Expose CommandStatementIngest as pub in sql module * Add do_put_statement_ingest to FlightSqlService Dispatch this handler for the new CommandStatementIngest command. * Sort list * Implement stub do_put_statement_ingest in example * Refactor helper functions into tests/common/utils * Implement execute_ingest for flight sql client I referenced the C++ implementation here: https://github.com/apache/arrow/commit/0d1ea5db1f9312412fe2cc28363e8c9deb2521ba * Add integration test for sql client execute_ingest * Fix lint clippy::new_without_default * Allow streaming ingest for FlightClient::execute_ingest * Properly return client errors --------- Co-authored-by: Andrew Lamb * docs: Add parquet_opendal in related projects (#6236) * docs: Add parquet_opendal in related projects * Fix spaces * Avoid infinite loop in bad parquet by checking the number of rep levels (#6232) * check the number of rep levels read from page * minor fix on typo Co-authored-by: Andrew Lamb * add check on record_read as well --------- Co-authored-by: jp0317 Co-authored-by: Andrew Lamb * Make the bearer token visible in FlightSqlServiceClient (#6254) * Make the bearer token visible in FlightSqlServiceClient * Update client.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb * Add tests for bad parquet files (#6262) * Add tests for bad parquet files * Reenable test * Add test for very subtly different file * Update parquet object_store dependency to 0.11.0 (#6264) * Implement date_part for durations (#6246) Signed-off-by: Nick Cameron * feat: further TLS options on ClientOptions: #5034 (#6148) * feat: further TLS options
on ClientOptions: #5034 * Rename to Certificate and with_root_certificate, add docs --------- Co-authored-by: Andrew Lamb * Improve documentation for MutableArrayData (#6272) * Do not print compression level in schema printer (#6271) The compression level is only used during compression, not decompression, and isn't actually stored in the metadata. Printing it is misleading. * Add `Statistics::distinct_count_opt` and deprecate `Statistics::distinct_count` (#6259) * Fix accessing name from ffi schema (#6273) * Fix accessing name from ffi schema * Add test * ci: use octokit to add assignee (#6267) * Only add encryption headers for for SSE-C in get. (#6260) * Minor: move `FallibleRequestStream` and `FallibleTonicResponseStream` to a module (#6258) * Minor: move FallibleRequestStream and FallibleTonicResponseStream to their own modules * Improve documentation and add links * Minor: `pub use ByteView` in arrow and improve documentation (#6275) * Minor: `pub use ByteView` in arrow and improve documentation * clarify docs more * ci: simplify octokit add assignee (#6280) * Update tower requirement from 0.4.13 to 0.5.0 (#6250) * Update tower requirement from 0.4.13 to 0.5.0 Updates the requirements on [tower](https://github.com/tower-rs/tower) to permit the latest version. - [Release notes](https://github.com/tower-rs/tower/releases) - [Commits](https://github.com/tower-rs/tower/compare/tower-0.4.13...tower-0.5.0) --- updated-dependencies: - dependency-name: tower dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * Add tower version --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb * Fix panic in comparison_kernel benchmarks (#6284) * Fix panic in comparison_kernel benchmarks * Add other special case equality kernels * Add other benchmarks * fix reference in doctest to size_of which is not imported by default (#6286) This corrects an issue with this doctest noticed on FreeBSD/amd64 with rustc 1.77.0 * Use `unary()` for array conversion in Parquet array readers, speed up `Decimal128`, `Decimal256` and `Float16` (#6252) * add unary to FixedSizeBinaryArray; use unary for transformations * clean up documentation some * add from_unary to PrimitiveArray * use from_unary for converting byte array to decimal * rework from_unary to skip vector initialization * add example to from_unary docstring * fix broken link * add comments per review suggestion * Support writing UTC adjusted time arrays to parquet (#6278) * check if time is adjusted to utc from metadata * add test * add roundtrip test * cargo fmt * Fix regression --------- Co-authored-by: Andrew Lamb * Minor: improve `RowFilter` and `ArrowPredicate` docs (#6301) * Minor: improve `RowFilter` and `ArrowPredicate` docs * tweak * Specialize Prefix/Suffix Match for `Like/ILike` between Array and Scalar for StringViewArray (#6231) * v2 impl * Add bench * fix clippy * fix endswith * Finalize the prefix_v2 implementation * stop reverse string for ends_with * Fix comments * fix bad comment * Correct equals sematics * Err on `try_from_le_slice` (#6295) * Err on try_from_le_slice, fix #3577 * format and changes * small cleanup * fix clippy * add bad metadata test * run test only if feature is enabled * add MRE test * fmt * feat(parquet): add union method to RowSelection (#6308) Complements the existing RowSelection::intersection method. 
Useful for Or-ing row selections together, in contrast to intersection's use when AND-ing selections * Minor: Improve comments on GenericByteViewArray::bytes_iter(), prefix_iter() and suffix_iter() (#6306) * Update tonic-build requirement from =0.12.0 to =0.12.2 (#6314) * Update tonic-build requirement from =0.12.0 to =0.12.2 Updates the requirements on [tonic-build](https://github.com/hyperium/tonic) to permit the latest version. - [Release notes](https://github.com/hyperium/tonic/releases) - [Changelog](https://github.com/hyperium/tonic/blob/master/CHANGELOG.md) - [Commits](https://github.com/hyperium/tonic/compare/v0.12.0...v0.12.2) --- updated-dependencies: - dependency-name: tonic-build dependency-type: direct:production ... Signed-off-by: dependabot[bot] * regenerate vendored code --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb * docs[object_store]: clarify the backoff strategy that is actually implemented (#6325) * Clarify the backoff strategy that is actually implemented * Update object_store/src/client/backoff.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * Pass empty vectors as min/max for all null pages when building ColumnIndex (#6316) * pass empty vecs as min/max for all null pages * add test * add some comments to test * Minor: improve filter documentation (#6317) * Minor: improve filter documentation * less space * Fix writing of invalid Parquet ColumnIndex when row group contains null pages (#6319) * Fix writing of invalid Parquet ColumnIndex when row group contains null pages Co-authored-by: Ed Seidl * fix lint * more rusty Co-authored-by: Ed Seidl * re-enable tests --------- Co-authored-by: Ed Seidl Co-authored-by: Andrew Lamb * Derive PartialEq and Eq for parquet::arrow::ProjectionMask (#6330) * Support zero column 
`RecordBatch`es in pyarrow integration (use RecordBatchOptions when converting a pyarrow RecordBatch) (#6320) * use RecordBatchOptions when converting a pyarrow RecordBatch Ref: https://github.com/apache/arrow-rs/issues/6318 * add assertion that num_rows persists through the round trip * add implementation comment * nicer creation of empty recordbatch in test_empty_recordbatch_with_row_count * use len provided by pycapsule interface when available * update test comment * parquet_derive: Match fields by name, support reading selected fields rather than all (#6269) * support reading pruned parquet * add pruned parquet reading test * better unit test * update comments * deref instead of clone * do not panic * copy integer * restore struct name * update comments --------- Co-authored-by: Ye Yuan Co-authored-by: Andrew Lamb * Specialize filter for structs and sparse unions (#6304) * specialize filter for structs and sparse unions * fix: move nested function to top level * fix: clarify optimization cases --------- Co-authored-by: Andrew Lamb * Prepare arrow/parquet `53.0.0` release (#6338) * Update version to 53.0.0 * Update changelog script * Update CHANGELOG.md * update changelog * Workaround new bug in parquet (#6344) * fix: azure sas token visible in logs (#6323) * fix: clippy warnings from nightly rust 1.82 (#6348) Signed-off-by: Ruihang Xia * [object_store] Propagate env vars as object store client options (#6334) * [object_store] Propagate env vars as object store client options * [object_store] Include the missing variants in the FromStr implementation of ClientConfigKey * cargo fmt * Remove vestigal conbench integration (#6339) * Remove vestigal conbench scripts * rat * feat: add catalog/schema subcommands to flight_sql_client. (#6332) * feat: add catalog/schema subcommands to flight_sql_client. With this change basic commands are added to query the catalogs and schemas of a Flight SQL server. 
* fix: adds tests for flight_sql_client cli Additionally adds a builder pattern for the CommandGetTableTypes similar to CommandGetDbSchemas, while its implementation is trivial it helps to have a pattern to follow when implementing the command. * fix: add default to GetTableTypesBuilder * Benchmark for bit_mask (set_bits) (#6353) * Benchmark for bit_mask (set_bits) * address review comments * impl `From>` for `Buffer` (#6355) * `object_store::GetOptions` derive `Clone` (#6361) * object_store::GetOptions derive Clone * undo wrong submodule * bump * object_store/delimited: Fix `TrailingEscape` condition (#6265) This seems like a copy-paste mistake since checking `is_quote` twice is probably wrong... * Add breaking change from `#6043` to `CHANGELOG` (#6354) * Manually run fmt on all files under parquet (#6328) * manually run fmt on all file under parquet * apply formatting that had been skipped * add more to comment * update formatting instructions * Update chrono-tz requirement from 0.9 to 0.10 (#6371) Updates the requirements on [chrono-tz](https://github.com/chronotope/chrono-tz) to permit the latest version. - [Release notes](https://github.com/chronotope/chrono-tz/releases) - [Commits](https://github.com/chronotope/chrono-tz/compare/v0.9.0...v0.10.0) --- updated-dependencies: - dependency-name: chrono-tz dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add support for Utf8View in arrow_string::length (#6345) * Add support for Utf8View in arrow_string::length #6305 * Cargo fmt. * Add support for BinaryView in arrow_string::length (#6359) * Add support for BinaryView in arrow_string::length * Adding a longer binary value to binary view test. 
* Improve `GenericStringBuilder` documentation (#6372) * Improve GenericStringBuilder documentation * Update arrow-array/src/builder/generic_bytes_builder.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * Update prost-build requirement from =0.13.1 to =0.13.2 (#6350) * Update prost-build requirement from =0.13.1 to =0.13.2 Updates the requirements on [prost-build](https://github.com/tokio-rs/prost) to permit the latest version. - [Release notes](https://github.com/tokio-rs/prost/releases) - [Changelog](https://github.com/tokio-rs/prost/blob/master/CHANGELOG.md) - [Commits](https://github.com/tokio-rs/prost/compare/v0.13.1...v0.13.2) --- updated-dependencies: - dependency-name: prost-build dependency-type: direct:production ... Signed-off-by: dependabot[bot] * chore: update vendored code --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb * add "ARROW_VERSION" const (#6379) * Support StringViewArray interop with python: fix lingering C Data Interface issues for *ViewArray (#6368) * fix lingering C Data Interface issues for *ViewArray Fixes https://github.com/apache/arrow-rs/issues/6366 * report views length in elements -> bytes * use pyarrow 17 * use only good versions * fix support for View arrays in C FFI, add test * update comment in github action * more ffi test cases * more byte_view tests for into_pyarrow * parquet writer: Raise an error when the row_group_index overflows i16 (#6378) This caused confusing panics down the line because 'ordinal' is negative. 
* impl `From>` for `Buffer` (#6389) * Clear string-tracking hash table when ByteView deduplication is enabled (#6385) * Improve performance of set_bits by avoiding to set individual bits (#6288) * bench * fix: Optimize set_bits * clippy * clippyj * miri * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * miri * miri * miri * miri * miri * miri * miri * miri * miri * miri * miri * address review comments * address review comments * address review comments * Revert "address review comments" This reverts commit ef2864fe15d2c856c05eae70693d68eb2ae00fa8. * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * Revert "address review comments" This reverts commit a15db144effdfdae7dad4d93c8fb6eb93216dab0. 
* address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * stop panic in `MetadataLoader` on invalid data (#6367) * stop panic in MetadataLoader on invalid data * better check for invalid prefetch * limit hint instead of erroring * import FOOTER_SIZE * Update lexical-core requirement from 0.8 to 1.0 (to resolve RUSTSEC-2023-0086) (#6402) * Update lexical-core requirement from 0.8 to 1.0 * Remove safety comment * Remove "NOT YET FULLY SUPPORTED" comment from DataType::Utf8View/BinaryView (#6380) * Move lifetime of `take_iter` from iterator to its items (#6403) * Derive `Clone` for `object_store::aws::AmazonS3` (#6414) * fix: binary_mut should work if only one input array has null buffer (#6396) * fix: binary_mut should work if only one input array has null buffer * Avoid copying null buffer in binary_mut * Update arrow-arith/src/arity.rs Co-authored-by: Andrew Lamb * Update arrow-arith/src/arity.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb * Fix encoding/decoding REE Dicts when using streaming IPC (#6399) * arrow-ipc: Add test for streaming IPC with REE dicts * arrow-schema: Include child fields of REE fields * fix: Stop losing precision and scale when casting decimal to dictionary (#6383) * Stop losing precision and scale when casting decimal to dictionary * address feedback * Rephrase doc comment (#6421) * docs: rephrase some Signed-off-by: Ruihang Xia * fix all warnings Signed-off-by: Ruihang Xia * big letter at the beginning Signed-off-by: Ruihang Xia * Apply suggestions from code review Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * Update arrow/src/pyarrow.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * Update arrow-array/src/types.rs Co-authored-by: Matthijs Brobbel --------- Signed-off-by: Ruihang Xia Co-authored-by: Raphael Taylor-Davies 
<1781103+tustvold@users.noreply.github.com> Co-authored-by: Matthijs Brobbel * fix: don't panic in IPC reader if struct child arrays have different … (#6417) * fix: don't panic in IPC reader if struct child arrays have different lengths * fix: clippy * test: add ipc read invalid struct test * Add `set_bits` fuzz test (#6394) * Implement set_bits fuzz test * Update arrow-buffer/src/util/bit_mask.rs * Update arrow-buffer/src/util/bit_mask.rs * fix import * Reduce integration test matrix (#6407) Fix #6406 * chore: add docs, part of #37 (#6424) * chore: add docs, part of #37 - add pragma `#![warn(missing_docs)]` to `arrow`, `arrow-arith`, `arrow-avro` - add docs to the same to remove lint warnings * chore: add docs, part of #37 - add pragma `#![warn(missing_docs)]` to `arrow-buffer`, `arrow-cast`, `arrow-csv` - add docs to the same to remove lint warnings * chore: update docs, resolve PR comments * Add RowSelection::skipped_row_count (#6429) * silence warnings (#6432) * object_store: Support Azure Fabric OAuth Provider (#6382) * Update Azure dependencies and add support for Fabric token authentication * Refactor Azure credential provider to support Fabric token authentication * Refactor Azure credential provider to remove unnecessary print statements and improve token handling * Bump object_store version to 0.11.0 * Refactor Azure credential provider to remove unnecessary print statements and improve token handling * perf: Faster decimal precision overflow checks (#6419) * add benchmark * add optimization * fix * fix * cargo fmt * clippy * Update arrow-data/src/decimal.rs Co-authored-by: Liang-Chi Hsieh * optimize to avoid allocating an idx variable * revert change to public api * fix error in rustdoc --------- Co-authored-by: Liang-Chi Hsieh * Implement native support StringViewArray for `regexp_is_match` and `regexp_is_match_scalar` function, deprecate `regexp_is_match_utf8` and `regexp_is_match_utf8_scalar` (#6376) * Implement native support StringViewArray for 
regex_is_match function * Update test cases cover StringViewArray length more than 12 bytes * Add StringView benchmark for regexp_is_match Signed-off-by: Tai Le Manh * Implement native support StringViewArray for regex_is_match function Signed-off-by: Tai Le Manh * Remove duplicate implementation, fix clippy, add docs more --------- Signed-off-by: Tai Le Manh Co-authored-by: Andrew Lamb * bump arrow-flight msrv to 1.71.1 (#6437) * feat: expose HTTP/2 max frame size in `object_store` (#6442) Especially when transferring large amounts of data over HTTP/2, this can massively reduce the overhead. * chore: add docs, part of #37 (#6433) * chore: add docs, part of #37 - add pragma `#![warn(missing_docs)]` to the following - `arrow-array` - `arrow-cast` - `arrow-csv` - `arrow-data` - `arrow-json` - `arrow-ord` - `arrow-pyarrow-integration-testing` - `arrow-row` - `arrow-schema` - `arrow-select` - `arrow-string` - `arrow` - `parquet_derive` - add docs to those that generated lint warnings - Remove `bitflags` workaround in `arrow-schema` At some point, a change in `bitflags v2.3.0` had started generating lint warnings in `arrow-schema`. This was handled using a [workaround](https://github.com/apache/arrow-rs/pull/4233) [Issue](https://github.com/bitflags/bitflags/issues/356) `bitflags v2.3.1` fixed the issue hence the workaround is no longer needed. * fix: resolve comments on PR #6433 * Fix doc "bit width" to "byte width" (#6434) * Minor: Add some missing documentation to fix CI errors (#6445) * fix CI errors * apply suggestion from review Co-authored-by: ngli-me <107162634+ngli-me@users.noreply.github.com> --------- Co-authored-by: ngli-me <107162634+ngli-me@users.noreply.github.com> * Update prost-build requirement from =0.13.2 to =0.13.3 (#6440) * Update prost-build requirement from =0.13.2 to =0.13.3 Updates the requirements on [prost-build](https://github.com/tokio-rs/prost) to permit the latest version. 
- [Release notes](https://github.com/tokio-rs/prost/releases) - [Changelog](https://github.com/tokio-rs/prost/blob/master/CHANGELOG.md) - [Commits](https://github.com/tokio-rs/prost/compare/v0.13.2...v0.13.3) --- updated-dependencies: - dependency-name: prost-build dependency-type: direct:production ... Signed-off-by: dependabot[bot] * update vendored code --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb * Add `ParquetMetaDataReader` (#6431) * add ParquetMetaDataReader * clippy * Apply suggestions from code review Co-authored-by: Andrew Lamb * formatting * add ParquetMetaDataReader to module documentation * document errors returned from `try_parse_sized` * oops * rename methods per review suggestion --------- Co-authored-by: Andrew Lamb * throw arrow error instead of panic (#6456) * Disable rust<>nanoarrow integration test in CI (#6449) * Add `union_extract` kernel (#6387) * feat: add union_extract kernel * fix: reexport union_extract in arrow crate * add tests, improve docs, simplify code --------- Co-authored-by: Andrew Lamb * Add `IpcSchemaEncoder`, deprecate ipc schema functions, Fix IPC not respecting not preserving dict ID (#6444) * arrow-ipc: Add test for non preserving dict ID behavior with same ID * arrow-ipc: Always set dict ID in IPC from dictionary tracker This decouples dictionary IDs that end up in IPC from the schema further because the dictionary tracker always first gathers the dict ID for each field whether it is pre-defined and preserved or not. Then when actually writing the IPC bytes the dictionary ID is always taken from the dictionary tracker as opposed to falling back to the `Field` of the `Schema`. 
* arrow-ipc: Read dictionary IDs from dictionary tracker in correct order When dictionary IDs are not preserved, then they are assigned depth first, however, when reading them from the dictionary tracker to write the IPC bytes, they were previously read from the dictionary tracker in the order that the schema is traversed (first come first serve), which caused an incorrect order of dictionaries serialized in IPC. * Refine IpcSchemaEncoder API and docs * reduce repeated code * Fix lints --------- Co-authored-by: Andrew Lamb * Add additional documentation and builder APIs to `SortOptions` (#6441) * Minor: Add additional documentation and builder APIs to `SortOptions` * Port some uses * Update defaults * Add nulls_first() and nulls_last() and more examples * Workaround for missing Parquet page indexes in `ParquetMetaDataReader` (#6450) * workaround for missing page indexes * remove empty line * Apply suggestions from code review Co-authored-by: Andrew Lamb * fmt --------- Co-authored-by: Andrew Lamb * Support cast between Durations + between Durations and all numeric types (#6452) * Support cast between Durations Signed-off-by: tison * Support cast between Durations and all numeric type Signed-off-by: tison * Impl cast between Durations Signed-off-by: tison * Add test_cast_between_durations Signed-off-by: tison * add test cases Signed-off-by: tison * cargo clippy Signed-off-by: tison --------- Signed-off-by: tison * Update Cargo.toml (#6459) * remove dup * empty --------- Signed-off-by: dependabot[bot] Signed-off-by: Bugen Zhao Signed-off-by: Nick Cameron Signed-off-by: Xuanwo Signed-off-by: Ruihang Xia Signed-off-by: Tai Le Manh Signed-off-by: tison Co-authored-by: Judah Rand <17158624+judahrand@users.noreply.github.com> Co-authored-by: Will Jones Co-authored-by: Xiangpeng Hao Co-authored-by: Michael Maletich Co-authored-by: Andrew Lamb Co-authored-by: Tomoaki Kawada Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Eduard 
Karacharov Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Co-authored-by: Liang-Chi Hsieh Co-authored-by: wiedld Co-authored-by: Jay Zhan Co-authored-by: Trent Hauck Co-authored-by: Xuanwo Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> Co-authored-by: Tom Forbes Co-authored-by: Zhao Gang Co-authored-by: double-free Co-authored-by: Ye Yuan Co-authored-by: Owen Leung Co-authored-by: Chris Riccomini Co-authored-by: Eric Fredine Co-authored-by: Eric Fredine Co-authored-by: Xiangpeng Hao Co-authored-by: Artem Medvedev Co-authored-by: Samuel Colvin Co-authored-by: 张林伟 Co-authored-by: kamille Co-authored-by: Luca Versari Co-authored-by: Matthijs Brobbel Co-authored-by: Hesam Pakdaman <14890379+hesampakdaman@users.noreply.github.com> Co-authored-by: Val Lorentz Co-authored-by: Yongting You <2010youy01@gmail.com> Co-authored-by: Ed Seidl Co-authored-by: barronw <141040627+barronw@users.noreply.github.com> Co-authored-by: Trung Dinh Co-authored-by: Val Lorentz Co-authored-by: Andrew Duffy Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Co-authored-by: Fischer <89784872+Fischer0522@users.noreply.github.com> Co-authored-by: Simon Vandel Sillesen Co-authored-by: V0ldek Co-authored-by: Bugen Zhao Co-authored-by: Jesse Co-authored-by: Marco Neumann Co-authored-by: Andy Grove Co-authored-by: Nick Cameron Co-authored-by: Alexander Rafferty Co-authored-by: pn <13125187405@163.com> Co-authored-by: Douglas Anderson Co-authored-by: kf zheng <100595273+Kev1n8@users.noreply.github.com> Co-authored-by: Daniel Mesejo Co-authored-by: Ed Seidl Co-authored-by: Jinpeng Co-authored-by: jp0317 Co-authored-by: gstvg <28798827+gstvg@users.noreply.github.com> Co-authored-by: Kyle McCarthy Co-authored-by: Xin Li <33629085+xinlifoobar@users.noreply.github.com> Co-authored-by: mwish Co-authored-by: Michael J Ward Co-authored-by: Jiacheng Yang <92543367+jiachengdb@users.noreply.github.com> 
Co-authored-by: Costi Ciudatu Co-authored-by: ByteBaker <42913098+ByteBaker@users.noreply.github.com> Co-authored-by: Thomas ten Cate Co-authored-by: Kyle Barron Co-authored-by: dsgibbons Co-authored-by: R. Tyler Croy Co-authored-by: aykut-bozkurt <51649454+aykut-bozkurt@users.noreply.github.com> Co-authored-by: Scott Donnelly Co-authored-by: Weston Pace Co-authored-by: Alex Wilcoxson Co-authored-by: Ruihang Xia Co-authored-by: Nathaniel Cook Co-authored-by: KAZUYUKI TANIMURA Co-authored-by: Tobias Bieniek Co-authored-by: Bruce Ritchie Co-authored-by: Shane Sveller Co-authored-by: KAZUYUKI TANIMURA Co-authored-by: Dario Curreri <48800335+dariocurr@users.noreply.github.com> Co-authored-by: Tzu Gwo Co-authored-by: Frederic Branczyk Co-authored-by: Sutou Kouhei Co-authored-by: Robin Lin <128118209+RobinLin666@users.noreply.github.com> Co-authored-by: Tai Le Manh <49281946+tlm365@users.noreply.github.com> Co-authored-by: ngli-me <107162634+ngli-me@users.noreply.github.com> Co-authored-by: Jax Liu Co-authored-by: tison Co-authored-by: Alexander Shtuchkin --- .asf.yaml | 8 +- .gitattributes | 9 +- .github/actions/setup-builder/action.yaml | 21 +- .github/dependabot.yml | 11 +- .github/workflows/arrow.yml | 200 +- .github/workflows/arrow_flight.yml | 47 +- .github/workflows/audit.yml | 43 + .github/workflows/cancel.yml | 54 - .github/workflows/coverage.yml | 8 +- .github/workflows/dev.yml | 14 +- .github/workflows/dev_pr.yml | 11 +- .github/workflows/dev_pr/labeler.yml | 34 +- .github/workflows/docs.yml | 52 +- .github/workflows/integration.yml | 114 +- .github/workflows/miri.sh | 13 +- .github/workflows/miri.yaml | 22 +- .github/workflows/object_store.yml | 159 +- .github/workflows/parquet.yml | 129 +- .github/workflows/parquet_derive.yml | 24 +- .github/workflows/rust.yml | 67 +- .github/workflows/take.yml | 39 + .github_changelog_generator | 2 +- .gitignore | 7 +- CHANGELOG-old.md | 2821 ++++- CHANGELOG.md | 230 +- CONTRIBUTING.md | 110 +- Cargo.toml | 77 +- 
LICENSE.txt | 10 - README.md | 113 +- arrow-arith/Cargo.toml | 45 + arrow-arith/src/aggregate.rs | 1704 +++ arrow-arith/src/arithmetic.rs | 341 + arrow-arith/src/arity.rs | 668 ++ arrow-arith/src/bitwise.rs | 392 + .../kernels => arrow-arith/src}/boolean.rs | 732 +- .../mod.rs => arrow-arith/src/lib.rs | 18 +- arrow-arith/src/numeric.rs | 1523 +++ arrow-arith/src/temporal.rs | 2088 ++++ arrow-array/Cargo.toml | 77 + arrow-array/benches/decimal_overflow.rs | 53 + arrow-array/benches/fixed_size_list_array.rs | 51 + arrow-array/benches/gc_view_types.rs | 48 + arrow-array/benches/occupancy.rs | 57 + arrow-array/src/arithmetic.rs | 867 ++ .../src/array/binary_array.rs | 435 +- arrow-array/src/array/boolean_array.rs | 700 ++ arrow-array/src/array/byte_array.rs | 617 ++ arrow-array/src/array/byte_view_array.rs | 999 ++ arrow-array/src/array/dictionary_array.rs | 1383 +++ .../src/array/fixed_size_binary_array.rs | 489 +- .../src/array/fixed_size_list_array.rs | 693 ++ arrow-array/src/array/list_array.rs | 1184 ++ arrow-array/src/array/map_array.rs | 801 ++ arrow-array/src/array/mod.rs | 1143 ++ arrow-array/src/array/null_array.rs | 197 + arrow-array/src/array/primitive_array.rs | 2734 +++++ arrow-array/src/array/run_array.rs | 1093 ++ arrow-array/src/array/string_array.rs | 566 + arrow-array/src/array/struct_array.rs | 734 ++ .../src/array/union_array.rs | 873 +- .../src}/builder/boolean_builder.rs | 159 +- arrow-array/src/builder/buffer_builder.rs | 225 + .../src}/builder/fixed_size_binary_builder.rs | 107 +- .../src/builder/fixed_size_list_builder.rs | 492 + .../src/builder/generic_byte_run_builder.rs | 514 + .../src/builder/generic_bytes_builder.rs | 543 + .../generic_bytes_dictionary_builder.rs | 630 ++ .../src/builder/generic_bytes_view_builder.rs | 733 ++ .../src/builder/generic_list_builder.rs | 806 ++ arrow-array/src/builder/map_builder.rs | 380 + arrow-array/src/builder/mod.rs | 325 + arrow-array/src/builder/null_builder.rs | 182 + 
.../src}/builder/primitive_builder.rs | 279 +- .../builder/primitive_dictionary_builder.rs | 402 + .../src/builder/primitive_run_builder.rs | 311 + arrow-array/src/builder/struct_builder.rs | 730 ++ .../src}/builder/union_builder.rs | 156 +- arrow-array/src/cast.rs | 1021 ++ arrow-array/src/delta.rs | 285 + arrow-array/src/ffi.rs | 1694 +++ {arrow => arrow-array}/src/ffi_stream.rs | 296 +- .../src/array => arrow-array/src}/iterator.rs | 75 +- arrow-array/src/lib.rs | 243 + .../csv/mod.rs => arrow-array/src/numeric.rs | 12 +- {arrow => arrow-array}/src/record_batch.rs | 737 +- arrow-array/src/run_iterator.rs | 384 + arrow-array/src/scalar.rs | 152 + arrow-array/src/temporal_conversions.rs | 351 + arrow-array/src/timezone.rs | 339 + .../util => arrow-array/src}/trusted_len.rs | 8 +- arrow-array/src/types.rs | 1637 +++ arrow-avro/Cargo.toml | 56 + arrow-avro/src/codec.rs | 315 + arrow-avro/src/compression.rs | 83 + arrow-avro/src/lib.rs | 41 + arrow-avro/src/reader/block.rs | 141 + arrow-avro/src/reader/header.rs | 345 + arrow-avro/src/reader/mod.rs | 107 + arrow-avro/src/reader/vlq.rs | 46 + arrow-avro/src/schema.rs | 512 + arrow-buffer/Cargo.toml | 57 + arrow-buffer/benches/bit_mask.rs | 58 + arrow-buffer/benches/i256.rs | 86 + arrow-buffer/benches/offset.rs | 49 + .../src/alloc/alignment.rs | 18 +- arrow-buffer/src/alloc/mod.rs | 70 + arrow-buffer/src/arith.rs | 77 + arrow-buffer/src/bigint/div.rs | 302 + arrow-buffer/src/bigint/mod.rs | 1267 +++ arrow-buffer/src/buffer/boolean.rs | 428 + .../src/buffer/immutable.rs | 436 +- arrow-buffer/src/buffer/mod.rs | 35 + {arrow => arrow-buffer}/src/buffer/mutable.rs | 391 +- arrow-buffer/src/buffer/null.rs | 261 + arrow-buffer/src/buffer/offset.rs | 242 + {arrow => arrow-buffer}/src/buffer/ops.rs | 73 +- arrow-buffer/src/buffer/run.rs | 230 + arrow-buffer/src/buffer/scalar.rs | 339 + .../src/builder/boolean.rs | 174 +- .../src/builder/mod.rs | 260 +- .../src/builder/null.rs | 75 +- arrow-buffer/src/builder/offset.rs | 125 + 
{arrow => arrow-buffer}/src/bytes.rs | 65 +- arrow-buffer/src/interval.rs | 579 + arrow-buffer/src/lib.rs | 45 + arrow-buffer/src/native.rs | 356 + .../src/util/bit_chunk_iterator.rs | 60 +- .../src/util/bit_iterator.rs | 140 +- arrow-buffer/src/util/bit_mask.rs | 432 + {arrow => arrow-buffer}/src/util/bit_util.rs | 95 +- arrow-buffer/src/util/mod.rs | 21 + arrow-cast/Cargo.toml | 79 + arrow-cast/benches/parse_date.rs | 34 + arrow-cast/benches/parse_decimal.rs | 56 + arrow-cast/benches/parse_time.rs | 42 + .../benches/parse_timestamp.rs | 39 +- arrow-cast/src/base64.rs | 120 + arrow-cast/src/cast/decimal.rs | 569 + arrow-cast/src/cast/dictionary.rs | 452 + arrow-cast/src/cast/list.rs | 182 + arrow-cast/src/cast/map.rs | 74 + arrow-cast/src/cast/mod.rs | 9693 +++++++++++++++++ arrow-cast/src/cast/string.rs | 380 + arrow-cast/src/display.rs | 1222 +++ arrow-cast/src/lib.rs | 28 + arrow-cast/src/parse.rs | 2772 +++++ {arrow/src/util => arrow-cast/src}/pretty.rs | 647 +- arrow-csv/Cargo.toml | 53 + arrow-csv/examples/README.md | 21 + arrow-csv/examples/csv_calculation.rs | 56 + arrow-csv/src/lib.rs | 46 + arrow-csv/src/reader/mod.rs | 2682 +++++ arrow-csv/src/reader/records.rs | 387 + arrow-csv/src/writer.rs | 866 ++ arrow-csv/test/data/custom_null_test.csv | 6 + .../test/data/decimal_test.csv | 0 arrow-csv/test/data/example.csv | 4 + arrow-csv/test/data/init_null_test.csv | 6 + {arrow => arrow-csv}/test/data/null_test.csv | 0 .../test/data/scientific_notation_test.csv | 19 + arrow-csv/test/data/truncated_rows.csv | 8 + {arrow => arrow-csv}/test/data/uk_cities.csv | 0 .../test/data/uk_cities_with_headers.csv | 0 .../test/data/various_types.csv | 0 .../test/data/various_types_invalid.csv | 0 arrow-data/Cargo.toml | 57 + arrow-data/src/byte_view.rs | 131 + arrow-data/src/data.rs | 2254 ++++ arrow-data/src/decimal.rs | 900 ++ .../array => arrow-data/src}/equal/boolean.rs | 46 +- arrow-data/src/equal/byte_view.rs | 74 + .../src}/equal/dictionary.rs | 20 +- 
arrow-data/src/equal/fixed_binary.rs | 99 + .../src}/equal/fixed_list.rs | 20 +- .../array => arrow-data/src}/equal/list.rs | 73 +- arrow-data/src/equal/mod.rs | 166 + .../array => arrow-data/src}/equal/null.rs | 2 +- arrow-data/src/equal/primitive.rs | 97 + arrow-data/src/equal/run.rs | 86 + .../src}/equal/structure.rs | 18 +- .../array => arrow-data/src}/equal/union.rs | 36 +- .../array => arrow-data/src}/equal/utils.rs | 55 +- .../src}/equal/variable_size.rs | 59 +- arrow-data/src/ffi.rs | 341 + arrow-data/src/lib.rs | 36 + .../src}/transform/boolean.rs | 4 +- .../src}/transform/fixed_binary.rs | 38 +- .../src}/transform/fixed_size_list.rs | 44 +- arrow-data/src/transform/list.rs | 54 + arrow-data/src/transform/mod.rs | 839 ++ .../src}/transform/null.rs | 3 +- .../src}/transform/primitive.rs | 9 +- arrow-data/src/transform/structure.rs | 37 + .../src}/transform/union.rs | 16 +- .../src}/transform/utils.rs | 29 +- arrow-data/src/transform/variable_size.rs | 69 + arrow-flight/CONTRIBUTING.md | 41 + arrow-flight/Cargo.toml | 88 +- arrow-flight/README.md | 45 +- arrow-flight/build.rs | 100 - arrow-flight/examples/data/ca.pem | 28 + arrow-flight/examples/data/client1.key | 28 + arrow-flight/examples/data/client1.pem | 19 + arrow-flight/examples/data/client_ca.pem | 19 + arrow-flight/examples/data/server.key | 28 + arrow-flight/examples/data/server.pem | 27 + arrow-flight/examples/flight_sql_server.rs | 837 +- arrow-flight/examples/server.rs | 63 +- arrow-flight/gen/Cargo.toml | 37 + arrow-flight/gen/src/main.rs | 86 + .../benchmarks.py => arrow-flight/regen.sh | 28 +- arrow-flight/src/arrow.flight.protocol.rs | 887 +- arrow-flight/src/bin/flight_sql_client.rs | 429 + arrow-flight/src/client.rs | 673 ++ arrow-flight/src/decode.rs | 434 + arrow-flight/src/encode.rs | 1626 +++ arrow-flight/src/error.rs | 148 + arrow-flight/src/lib.rs | 650 +- .../src/sql/arrow.flight.protocol.sql.rs | 2900 +++-- arrow-flight/src/sql/client.rs | 776 ++ 
arrow-flight/src/sql/metadata/catalogs.rs | 100 + arrow-flight/src/sql/metadata/db_schemas.rs | 286 + arrow-flight/src/sql/metadata/mod.rs | 77 + arrow-flight/src/sql/metadata/sql_info.rs | 561 + arrow-flight/src/sql/metadata/table_types.rs | 158 + arrow-flight/src/sql/metadata/tables.rs | 476 + arrow-flight/src/sql/metadata/xdbc_info.rs | 428 + arrow-flight/src/sql/mod.rs | 271 +- arrow-flight/src/sql/server.rs | 1055 +- arrow-flight/src/streams.rs | 134 + arrow-flight/src/trailers.rs | 92 + arrow-flight/src/utils.rs | 94 +- arrow-flight/tests/client.rs | 1151 ++ arrow-flight/tests/common/fixture.rs | 118 + arrow-flight/tests/common/mod.rs | 21 + arrow-flight/tests/common/server.rs | 502 + arrow-flight/tests/common/trailers_layer.rs | 124 + arrow-flight/tests/common/utils.rs | 118 + arrow-flight/tests/encode_decode.rs | 503 + arrow-flight/tests/flight_sql_client.rs | 216 + arrow-flight/tests/flight_sql_client_cli.rs | 757 ++ arrow-integration-test/Cargo.toml | 44 + .../data/integration.json | 0 arrow-integration-test/src/datatype.rs | 373 + arrow-integration-test/src/field.rs | 568 + .../src/lib.rs | 528 +- arrow-integration-test/src/schema.rs | 728 ++ .../Cargo.toml | 29 +- .../README.md | 2 +- .../src/bin/arrow-file-to-stream.rs | 2 +- .../src/bin/arrow-json-integration-test.rs | 74 +- .../src/bin/arrow-stream-to-file.rs | 0 .../src/bin/flight-test-integration-client.rs | 7 +- .../src/bin/flight-test-integration-server.rs | 4 +- .../src/flight_client_scenarios.rs | 0 .../auth_basic_proto.rs | 24 +- .../integration_test.rs | 79 +- .../src/flight_client_scenarios/middleware.rs | 16 +- .../src/flight_server_scenarios.rs | 6 +- .../auth_basic_proto.rs | 40 +- .../integration_test.rs | 92 +- .../src/flight_server_scenarios/middleware.rs | 15 +- arrow-integration-testing/src/lib.rs | 302 + arrow-integration-testing/tests/ipc_reader.rs | 228 + arrow-integration-testing/tests/ipc_writer.rs | 256 + arrow-ipc/CONTRIBUTING.md | 37 + arrow-ipc/Cargo.toml | 51 + {arrow => 
arrow-ipc}/regen.sh | 32 +- .../codec.rs => arrow-ipc/src/compression.rs | 176 +- arrow-ipc/src/convert.rs | 1221 +++ {arrow/src/ipc => arrow-ipc/src}/gen/File.rs | 245 +- .../src/ipc => arrow-ipc/src}/gen/Message.rs | 670 +- .../src/ipc => arrow-ipc/src}/gen/Schema.rs | 2421 ++-- .../ipc => arrow-ipc/src}/gen/SparseTensor.rs | 879 +- .../src/ipc => arrow-ipc/src}/gen/Tensor.rs | 441 +- {arrow/src/ipc => arrow-ipc/src}/gen/mod.rs | 0 arrow/src/ipc/mod.rs => arrow-ipc/src/lib.rs | 5 +- arrow-ipc/src/reader.rs | 2341 ++++ arrow-ipc/src/reader/stream.rs | 377 + arrow-ipc/src/writer.rs | 2775 +++++ arrow-json/Cargo.toml | 63 + arrow-json/benches/serde.rs | 62 + arrow-json/src/lib.rs | 159 + arrow-json/src/reader/boolean_array.rs | 43 + arrow-json/src/reader/decimal_array.rs | 102 + arrow-json/src/reader/list_array.rs | 113 + arrow-json/src/reader/map_array.rs | 154 + arrow-json/src/reader/mod.rs | 2319 ++++ arrow-json/src/reader/null_array.rs | 35 + arrow-json/src/reader/primitive_array.rs | 159 + arrow-json/src/reader/schema.rs | 742 ++ arrow-json/src/reader/serializer.rs | 420 + arrow-json/src/reader/string_array.rs | 129 + arrow-json/src/reader/struct_array.rs | 161 + arrow-json/src/reader/tape.rs | 911 ++ arrow-json/src/reader/timestamp_array.rs | 110 + arrow-json/src/writer.rs | 1808 +++ arrow-json/src/writer/encoder.rs | 547 + {arrow => arrow-json}/test/data/arrays.json | 0 arrow-json/test/data/basic.json | 12 + .../test/data/basic_nulls.json | 0 .../test/data/list_string_dict_nested.json | 0 .../data/list_string_dict_nested_nulls.json | 0 .../test/data/mixed_arrays.json | 0 .../test/data/mixed_arrays.json.gz | Bin .../test/data/nested_structs.json | 0 arrow-json/test/data/nested_with_nulls.json | 4 + arrow-ord/Cargo.toml | 46 + arrow-ord/src/cmp.rs | 855 ++ arrow-ord/src/comparison.rs | 3365 ++++++ arrow-ord/src/lib.rs | 53 + arrow-ord/src/ord.rs | 924 ++ arrow-ord/src/partition.rs | 317 + arrow-ord/src/rank.rs | 191 + .../compute/kernels => 
arrow-ord/src}/sort.rs | 2976 +++-- arrow-pyarrow-integration-testing/Cargo.toml | 10 +- arrow-pyarrow-integration-testing/README.md | 2 + .../pyproject.toml | 4 +- arrow-pyarrow-integration-testing/src/lib.rs | 97 +- .../tests/test_sql.py | 309 +- arrow-row/Cargo.toml | 56 + arrow-row/src/fixed.rs | 509 + arrow-row/src/lib.rs | 2361 ++++ arrow-row/src/list.rs | 186 + arrow-row/src/variable.rs | 350 + arrow-schema/Cargo.toml | 49 + arrow-schema/src/datatype.rs | 1092 ++ arrow-schema/src/datatype_parse.rs | 783 ++ arrow-schema/src/error.rs | 171 + arrow-schema/src/ffi.rs | 963 ++ arrow-schema/src/field.rs | 1007 ++ arrow-schema/src/fields.rs | 591 + arrow-schema/src/lib.rs | 209 + arrow-schema/src/schema.rs | 1013 ++ arrow-select/Cargo.toml | 48 + arrow-select/src/concat.rs | 853 ++ arrow-select/src/dictionary.rs | 331 + .../kernels => arrow-select/src}/filter.rs | 898 +- arrow-select/src/interleave.rs | 403 + arrow-select/src/lib.rs | 29 + arrow-select/src/nullif.rs | 512 + .../kernels => arrow-select/src}/take.rs | 1709 +-- arrow-select/src/union_extract.rs | 1236 +++ .../kernels => arrow-select/src}/window.rs | 20 +- arrow-select/src/zip.rs | 230 + arrow-string/Cargo.toml | 45 + .../src}/concat_elements.rs | 228 +- .../kernels => arrow-string/src}/length.rs | 575 +- .../kernels/mod.rs => arrow-string/src/lib.rs | 21 +- arrow-string/src/like.rs | 1740 +++ arrow-string/src/predicate.rs | 461 + arrow-string/src/regexp.rs | 808 ++ .../kernels => arrow-string/src}/substring.rs | 307 +- arrow/CONTRIBUTING.md | 29 - arrow/Cargo.toml | 170 +- arrow/README.md | 28 +- arrow/benches/aggregate_kernels.rs | 165 +- arrow/benches/arithmetic_kernels.rs | 149 +- arrow/benches/array_data_validate.rs | 11 +- arrow/benches/array_from_vec.rs | 14 +- arrow/benches/bitwise_kernel.rs | 117 + arrow/benches/boolean_append_packed.rs | 1 - arrow/benches/buffer_bit_ops.rs | 18 +- arrow/benches/buffer_create.rs | 13 +- arrow/benches/builder.rs | 31 +- arrow/benches/cast_kernels.rs | 109 +- 
arrow/benches/comparison_kernels.rs | 409 +- arrow/benches/concatenate_kernel.rs | 42 + arrow/benches/csv_reader.rs | 184 + arrow/benches/csv_writer.rs | 7 +- arrow/benches/decimal_validate.rs | 29 +- arrow/benches/equal.rs | 3 + arrow/benches/filter_kernels.rs | 101 +- arrow/benches/interleave_kernels.rs | 112 + arrow/benches/json_reader.rs | 98 +- arrow/benches/json_writer.rs | 198 + arrow/benches/lexsort.rs | 221 + arrow/benches/mutable_array.rs | 3 +- arrow/benches/partition_kernels.rs | 45 +- arrow/benches/primitive_run_accessor.rs | 54 + arrow/benches/primitive_run_take.rs | 77 + arrow/benches/regexp_kernels.rs | 51 + arrow/benches/row_format.rs | 169 + arrow/benches/sort_kernel.rs | 224 +- arrow/benches/string_dictionary_builder.rs | 13 +- arrow/benches/string_run_builder.rs | 60 + arrow/benches/string_run_iterator.rs | 82 + arrow/benches/take_kernels.rs | 90 +- arrow/examples/README.md | 5 +- arrow/examples/builders.rs | 32 +- arrow/examples/collect.rs | 87 + arrow/examples/dynamic_types.rs | 18 +- arrow/examples/read_csv.rs | 11 +- arrow/examples/read_csv_infer_schema.rs | 15 +- arrow/examples/tensor_builder.rs | 11 +- parquet/build.rs => arrow/examples/version.rs | 10 +- arrow/src/alloc/mod.rs | 155 - arrow/src/alloc/types.rs | 73 - arrow/src/array/array.rs | 1012 -- arrow/src/array/array_boolean.rs | 422 - arrow/src/array/array_decimal.rs | 972 -- arrow/src/array/array_dictionary.rs | 809 -- arrow/src/array/array_fixed_size_list.rs | 388 - arrow/src/array/array_list.rs | 956 -- arrow/src/array/array_map.rs | 532 - arrow/src/array/array_primitive.rs | 1146 -- arrow/src/array/array_string.rs | 829 -- arrow/src/array/array_struct.rs | 540 - arrow/src/array/builder/decimal_builder.rs | 383 - .../array/builder/fixed_size_list_builder.rs | 236 - .../array/builder/generic_binary_builder.rs | 232 - .../src/array/builder/generic_list_builder.rs | 320 - .../array/builder/generic_string_builder.rs | 193 - arrow/src/array/builder/map_builder.rs | 253 - 
arrow/src/array/builder/mod.rs | 151 - .../builder/primitive_dictionary_builder.rs | 252 - .../builder/string_dictionary_builder.rs | 374 - arrow/src/array/builder/struct_builder.rs | 464 - arrow/src/array/cast.rs | 512 - arrow/src/array/data.rs | 2894 ----- arrow/src/array/equal/decimal.rs | 74 - arrow/src/array/equal/fixed_binary.rs | 73 - arrow/src/array/equal/mod.rs | 1464 --- arrow/src/array/equal/primitive.rs | 70 - arrow/src/array/ffi.rs | 303 - arrow/src/array/mod.rs | 635 +- arrow/src/array/null.rs | 153 - arrow/src/array/ord.rs | 342 - arrow/src/array/raw_pointer.rs | 67 - arrow/src/array/transform/list.rs | 99 - arrow/src/array/transform/mod.rs | 1715 --- arrow/src/array/transform/structure.rs | 64 - arrow/src/array/transform/variable_size.rs | 105 - arrow/src/bitmap.rs | 156 - arrow/src/buffer/mod.rs | 72 - arrow/src/buffer/scalar.rs | 149 - arrow/src/compute/kernels.rs | 35 + arrow/src/compute/kernels/aggregate.rs | 1167 -- arrow/src/compute/kernels/arithmetic.rs | 2022 ---- arrow/src/compute/kernels/arity.rs | 271 - arrow/src/compute/kernels/cast.rs | 5556 ---------- arrow/src/compute/kernels/cast_utils.rs | 300 - arrow/src/compute/kernels/comparison.rs | 6501 ----------- arrow/src/compute/kernels/concat.rs | 572 - arrow/src/compute/kernels/limit.rs | 206 - arrow/src/compute/kernels/partition.rs | 408 - arrow/src/compute/kernels/regexp.rs | 157 - arrow/src/compute/kernels/temporal.rs | 1041 -- arrow/src/compute/kernels/zip.rs | 87 - arrow/src/compute/mod.rs | 7 +- arrow/src/compute/util.rs | 494 - arrow/src/csv/reader.rs | 2052 ---- arrow/src/csv/writer.rs | 820 -- arrow/src/datatypes/datatype.rs | 1499 --- arrow/src/datatypes/delta.rs | 182 - arrow/src/datatypes/ffi.rs | 460 - arrow/src/datatypes/field.rs | 881 -- arrow/src/datatypes/mod.rs | 1522 +-- arrow/src/datatypes/native.rs | 326 - arrow/src/datatypes/numeric.rs | 492 - arrow/src/datatypes/schema.rs | 467 - arrow/src/datatypes/types.rs | 569 - arrow/src/error.rs | 116 +- arrow/src/ffi.rs | 
1484 --- arrow/src/ipc/compression/stub.rs | 63 - arrow/src/ipc/convert.rs | 1021 -- arrow/src/ipc/reader.rs | 1704 --- arrow/src/ipc/writer.rs | 1903 ---- arrow/src/json/mod.rs | 82 - arrow/src/json/reader.rs | 3364 ------ arrow/src/json/writer.rs | 1521 --- arrow/src/lib.rs | 208 +- arrow/src/pyarrow.rs | 479 +- arrow/src/temporal_conversions.rs | 248 - arrow/src/tensor.rs | 90 +- arrow/src/util/bench_util.rs | 276 +- arrow/src/util/bit_mask.rs | 190 - arrow/src/util/data_gen.rs | 537 +- arrow/src/util/decimal.rs | 474 - arrow/src/util/display.rs | 467 - arrow/src/util/mod.rs | 20 +- arrow/src/util/reader_parser.rs | 144 - arrow/src/util/string_writer.rs | 27 +- arrow/src/util/test_util.rs | 11 +- arrow/test/data/basic.json | 12 - arrow/tests/arithmetic.rs | 190 + arrow/tests/array_cast.rs | 599 + arrow/tests/array_equal.rs | 1299 +++ arrow/tests/array_transform.rs | 1153 ++ arrow/tests/array_validation.rs | 1101 ++ arrow/tests/csv.rs | 60 + arrow/tests/pyarrow.rs | 109 + arrow/tests/schema.rs | 11 +- arrow/tests/timezone.rs | 81 + conbench/.flake8 | 2 - conbench/.gitignore | 130 - conbench/.isort.cfg | 2 - conbench/README.md | 251 - conbench/_criterion.py | 98 - conbench/benchmarks.json | 8 - conbench/requirements-test.txt | 3 - conbench/requirements.txt | 1 - dev/release/README.md | 52 +- dev/release/create-tarball.sh | 5 +- dev/release/file_release_pr.sh | 40 + dev/release/label_issues.py | 153 + dev/release/rat_exclude_files.txt | 10 +- dev/release/update_change_log.sh | 41 +- dev/release/verify-release-candidate.sh | 27 +- format/Flight.proto | 827 +- format/FlightSql.proto | 3204 +++--- format/Message.fbs | 21 +- format/Schema.fbs | 228 +- integration-testing/src/lib.rs | 98 - integration-testing/tests/ipc_reader.rs | 293 - integration-testing/tests/ipc_writer.rs | 314 - object_store/.github_changelog_generator | 2 +- object_store/CHANGELOG-old.md | 741 ++ object_store/CHANGELOG.md | 50 +- object_store/CONTRIBUTING.md | 105 +- object_store/Cargo.toml | 47 
+- object_store/LICENSE.txt | 204 + object_store/NOTICE.txt | 5 + object_store/README.md | 14 +- object_store/dev/release/release-tarball.sh | 3 + .../dev/release/remove-old-releases.sh | 45 + object_store/dev/release/update_change_log.sh | 9 +- .../dev/release/verify-release-candidate.sh | 2 +- object_store/src/attributes.rs | 248 + object_store/src/aws/builder.rs | 1471 +++ object_store/src/aws/checksum.rs | 63 + object_store/src/aws/client.rs | 918 +- object_store/src/aws/credential.rs | 732 +- object_store/src/aws/dynamo.rs | 593 + object_store/src/aws/mod.rs | 1078 +- object_store/src/aws/precondition.rs | 252 + object_store/src/aws/resolve.rs | 103 + object_store/src/azure/builder.rs | 1203 ++ object_store/src/azure/client.rs | 716 +- object_store/src/azure/credential.rs | 1061 +- object_store/src/azure/mod.rs | 826 +- object_store/src/buffered.rs | 662 ++ object_store/src/chunked.rs | 232 + object_store/src/client/backoff.rs | 21 +- object_store/src/client/get.rs | 432 + object_store/src/client/header.rs | 141 + object_store/src/client/list.rs | 126 + object_store/src/client/mock_server.rs | 92 +- object_store/src/client/mod.rs | 933 +- object_store/src/client/parts.rs | 48 + object_store/src/client/retry.rs | 557 +- object_store/src/client/s3.rs | 128 + object_store/src/client/token.rs | 26 +- object_store/src/config.rs | 143 + object_store/src/delimited.rs | 269 + object_store/src/gcp/builder.rs | 687 ++ object_store/src/gcp/client.rs | 663 ++ object_store/src/gcp/credential.rs | 805 +- object_store/src/gcp/mod.rs | 1062 +- object_store/src/http/client.rs | 457 + object_store/src/http/mod.rs | 271 + object_store/src/integration.rs | 1105 ++ object_store/src/lib.rs | 1743 ++- object_store/src/limit.rs | 185 +- object_store/src/local.rs | 1316 ++- object_store/src/memory.rs | 512 +- object_store/src/multipart.rs | 240 +- object_store/src/parse.rs | 365 + object_store/src/path/mod.rs | 267 +- object_store/src/path/parts.rs | 23 +- object_store/src/payload.rs 
| 328 + object_store/src/prefix.rs | 273 + object_store/src/signer.rs | 50 + object_store/src/tags.rs | 60 + object_store/src/throttle.rs | 287 +- object_store/src/upload.rs | 341 + object_store/src/util.rs | 266 +- object_store/tests/get_range_file.rs | 122 + parquet-testing | 2 +- parquet/CONTRIBUTING.md | 13 +- parquet/Cargo.toml | 165 +- parquet/README.md | 22 +- parquet/benches/arrow_reader.rs | 870 +- parquet/benches/arrow_statistics.rs | 269 + parquet/benches/arrow_writer.rs | 116 +- parquet/benches/compression.rs | 96 + parquet/benches/encoding.rs | 105 + parquet/benches/metadata.rs | 42 + parquet/examples/async_read_parquet.rs | 69 + parquet/examples/read_parquet.rs | 43 + parquet/examples/read_with_rowgroup.rs | 182 + parquet/examples/write_parquet.rs | 139 + parquet/pytest/requirements.in | 20 + parquet/pytest/requirements.txt | 176 + parquet/pytest/test_parquet_integration.py | 110 + parquet/regen.sh | 43 + parquet/src/arrow/array_reader/builder.rs | 305 +- parquet/src/arrow/array_reader/byte_array.rs | 202 +- .../array_reader/byte_array_dictionary.rs | 116 +- .../src/arrow/array_reader/byte_view_array.rs | 751 ++ parquet/src/arrow/array_reader/empty_array.rs | 7 +- .../array_reader/fixed_len_byte_array.rs | 258 +- .../array_reader/fixed_size_list_array.rs | 641 ++ parquet/src/arrow/array_reader/list_array.rs | 97 +- parquet/src/arrow/array_reader/map_array.rs | 37 +- parquet/src/arrow/array_reader/mod.rs | 83 +- parquet/src/arrow/array_reader/null_array.rs | 23 +- .../src/arrow/array_reader/primitive_array.rs | 258 +- .../src/arrow/array_reader/struct_array.rs | 51 +- parquet/src/arrow/array_reader/test_util.rs | 40 +- parquet/src/arrow/arrow_reader/filter.rs | 40 +- parquet/src/arrow/arrow_reader/mod.rs | 2490 ++++- parquet/src/arrow/arrow_reader/selection.rs | 823 +- parquet/src/arrow/arrow_reader/statistics.rs | 1595 +++ parquet/src/arrow/arrow_writer/byte_array.rs | 234 +- parquet/src/arrow/arrow_writer/levels.rs | 949 +- 
parquet/src/arrow/arrow_writer/mod.rs | 2146 +++- parquet/src/arrow/async_reader.rs | 1170 -- parquet/src/arrow/async_reader/metadata.rs | 372 + parquet/src/arrow/async_reader/mod.rs | 1996 ++++ parquet/src/arrow/async_reader/store.rs | 190 + parquet/src/arrow/async_writer/mod.rs | 468 + parquet/src/arrow/async_writer/store.rs | 157 + parquet/src/arrow/buffer/bit_util.rs | 13 +- parquet/src/arrow/buffer/dictionary_buffer.rs | 138 +- parquet/src/arrow/buffer/mod.rs | 1 + parquet/src/arrow/buffer/offset_buffer.rs | 105 +- parquet/src/arrow/buffer/view_buffer.rs | 193 + parquet/src/arrow/decoder/delta_byte_array.rs | 24 +- parquet/src/arrow/decoder/dictionary_index.rs | 20 +- parquet/src/arrow/mod.rs | 137 +- parquet/src/arrow/record_reader/buffer.rs | 217 +- .../arrow/record_reader/definition_levels.rs | 178 +- parquet/src/arrow/record_reader/mod.rs | 327 +- parquet/src/arrow/schema/complex.rs | 101 +- .../src/arrow/{schema.rs => schema/mod.rs} | 995 +- parquet/src/arrow/schema/primitive.rs | 109 +- parquet/src/basic.rs | 918 +- parquet/src/bin/parquet-concat.rs | 118 + parquet/src/bin/parquet-fromcsv-help.txt | 107 +- parquet/src/bin/parquet-fromcsv.rs | 263 +- parquet/src/bin/parquet-index.rs | 174 + parquet/src/bin/parquet-layout.rs | 236 + parquet/src/bin/parquet-read.rs | 15 +- parquet/src/bin/parquet-rewrite.rs | 279 + parquet/src/bin/parquet-rowcount.rs | 10 +- parquet/src/bin/parquet-schema.rs | 7 +- parquet/src/bin/parquet-show-bloom-filter.rs | 118 + parquet/src/bloom_filter/mod.rs | 525 + parquet/src/column/mod.rs | 50 +- parquet/src/column/page.rs | 186 +- parquet/src/column/reader.rs | 563 +- parquet/src/column/reader/decoder.rs | 453 +- parquet/src/column/writer/encoder.rs | 149 +- parquet/src/column/writer/mod.rs | 2206 +++- parquet/src/compression.rs | 594 +- parquet/src/data_type.rs | 328 +- parquet/src/encodings/decoding.rs | 576 +- .../decoding/byte_stream_split_decoder.rs | 256 + .../encoding/byte_stream_split_encoder.rs | 231 + 
.../src/encodings/encoding/dict_encoder.rs | 34 +- parquet/src/encodings/encoding/mod.rs | 221 +- parquet/src/encodings/levels.rs | 419 +- parquet/src/encodings/rle.rs | 299 +- parquet/src/errors.rs | 71 +- parquet/src/file/footer.rs | 111 +- parquet/src/file/metadata.rs | 1086 -- parquet/src/file/metadata/memory.rs | 239 + parquet/src/file/metadata/mod.rs | 1731 +++ parquet/src/file/metadata/reader.rs | 989 ++ parquet/src/file/metadata/writer.rs | 674 ++ parquet/src/file/mod.rs | 18 +- parquet/src/file/page_encoding_stats.rs | 18 +- parquet/src/file/page_index/index.rs | 363 +- parquet/src/file/page_index/index_reader.rs | 223 +- parquet/src/file/page_index/mod.rs | 8 +- parquet/src/file/page_index/offset_index.rs | 59 + parquet/src/file/page_index/range.rs | 475 - parquet/src/file/properties.rs | 912 +- parquet/src/file/reader.rs | 143 +- parquet/src/file/serialized_reader.rs | 642 +- parquet/src/file/statistics.rs | 654 +- parquet/src/file/writer.rs | 1636 ++- parquet/src/format.rs | 5481 ++++++++++ parquet/src/lib.rs | 85 +- parquet/src/record/api.rs | 369 +- parquet/src/record/mod.rs | 5 +- parquet/src/record/reader.rs | 460 +- parquet/src/record/record_reader.rs | 31 + parquet/src/record/record_writer.rs | 15 +- parquet/src/record/triplet.rs | 108 +- parquet/src/schema/mod.rs | 16 +- parquet/src/schema/parser.rs | 867 +- parquet/src/schema/printer.rs | 178 +- parquet/src/schema/types.rs | 663 +- parquet/src/schema/visitor.rs | 26 +- parquet/src/thrift.rs | 282 + parquet/src/util/bit_pack.rs | 8 +- parquet/src/util/bit_util.rs | 231 +- parquet/src/util/interner.rs | 13 + parquet/src/util/io.rs | 246 - parquet/src/util/memory.rs | 143 - parquet/src/util/mod.rs | 2 - parquet/src/util/test_common/file_util.rs | 3 +- parquet/src/util/test_common/mod.rs | 2 +- parquet/src/util/test_common/page_util.rs | 85 +- parquet/src/util/test_common/rand_gen.rs | 9 +- parquet/tests/arrow_reader/bad_data.rs | 162 + .../tests/arrow_reader/bad_raw_metadata.bin | Bin 0 -> 35456 
bytes parquet/tests/arrow_reader/mod.rs | 1045 ++ parquet/tests/arrow_reader/statistics.rs | 2624 +++++ parquet/tests/arrow_writer_layout.rs | 546 + parquet/tests/boolean_writer.rs | 89 - parquet_derive/Cargo.toml | 20 +- parquet_derive/README.md | 55 +- parquet_derive/src/lib.rs | 157 +- parquet_derive/src/parquet_field.rs | 677 +- parquet_derive_test/Cargo.toml | 21 +- parquet_derive_test/src/lib.rs | 232 +- pre-commit.sh | 6 +- rustfmt.toml | 6 - testing | 2 +- 722 files changed, 211417 insertions(+), 91141 deletions(-) create mode 100644 .github/workflows/audit.yml delete mode 100644 .github/workflows/cancel.yml create mode 100644 .github/workflows/take.yml create mode 100644 arrow-arith/Cargo.toml create mode 100644 arrow-arith/src/aggregate.rs create mode 100644 arrow-arith/src/arithmetic.rs create mode 100644 arrow-arith/src/arity.rs create mode 100644 arrow-arith/src/bitwise.rs rename {arrow/src/compute/kernels => arrow-arith/src}/boolean.rs (56%) rename arrow/src/ipc/compression/mod.rs => arrow-arith/src/lib.rs (75%) create mode 100644 arrow-arith/src/numeric.rs create mode 100644 arrow-arith/src/temporal.rs create mode 100644 arrow-array/Cargo.toml create mode 100644 arrow-array/benches/decimal_overflow.rs create mode 100644 arrow-array/benches/fixed_size_list_array.rs create mode 100644 arrow-array/benches/gc_view_types.rs create mode 100644 arrow-array/benches/occupancy.rs create mode 100644 arrow-array/src/arithmetic.rs rename arrow/src/array/array_binary.rs => arrow-array/src/array/binary_array.rs (60%) create mode 100644 arrow-array/src/array/boolean_array.rs create mode 100644 arrow-array/src/array/byte_array.rs create mode 100644 arrow-array/src/array/byte_view_array.rs create mode 100644 arrow-array/src/array/dictionary_array.rs rename arrow/src/array/array_fixed_size_binary.rs => arrow-array/src/array/fixed_size_binary_array.rs (57%) create mode 100644 arrow-array/src/array/fixed_size_list_array.rs create mode 100644 
arrow-array/src/array/list_array.rs create mode 100644 arrow-array/src/array/map_array.rs create mode 100644 arrow-array/src/array/mod.rs create mode 100644 arrow-array/src/array/null_array.rs create mode 100644 arrow-array/src/array/primitive_array.rs create mode 100644 arrow-array/src/array/run_array.rs create mode 100644 arrow-array/src/array/string_array.rs create mode 100644 arrow-array/src/array/struct_array.rs rename arrow/src/array/array_union.rs => arrow-array/src/array/union_array.rs (52%) rename {arrow/src/array => arrow-array/src}/builder/boolean_builder.rs (62%) create mode 100644 arrow-array/src/builder/buffer_builder.rs rename {arrow/src/array => arrow-array/src}/builder/fixed_size_binary_builder.rs (64%) create mode 100644 arrow-array/src/builder/fixed_size_list_builder.rs create mode 100644 arrow-array/src/builder/generic_byte_run_builder.rs create mode 100644 arrow-array/src/builder/generic_bytes_builder.rs create mode 100644 arrow-array/src/builder/generic_bytes_dictionary_builder.rs create mode 100644 arrow-array/src/builder/generic_bytes_view_builder.rs create mode 100644 arrow-array/src/builder/generic_list_builder.rs create mode 100644 arrow-array/src/builder/map_builder.rs create mode 100644 arrow-array/src/builder/mod.rs create mode 100644 arrow-array/src/builder/null_builder.rs rename {arrow/src/array => arrow-array/src}/builder/primitive_builder.rs (52%) create mode 100644 arrow-array/src/builder/primitive_dictionary_builder.rs create mode 100644 arrow-array/src/builder/primitive_run_builder.rs create mode 100644 arrow-array/src/builder/struct_builder.rs rename {arrow/src/array => arrow-array/src}/builder/union_builder.rs (70%) create mode 100644 arrow-array/src/cast.rs create mode 100644 arrow-array/src/delta.rs create mode 100644 arrow-array/src/ffi.rs rename {arrow => arrow-array}/src/ffi_stream.rs (65%) rename {arrow/src/array => arrow-array/src}/iterator.rs (73%) create mode 100644 arrow-array/src/lib.rs rename arrow/src/csv/mod.rs 
=> arrow-array/src/numeric.rs (73%) rename {arrow => arrow-array}/src/record_batch.rs (60%) create mode 100644 arrow-array/src/run_iterator.rs create mode 100644 arrow-array/src/scalar.rs create mode 100644 arrow-array/src/temporal_conversions.rs create mode 100644 arrow-array/src/timezone.rs rename {arrow/src/util => arrow-array/src}/trusted_len.rs (94%) create mode 100644 arrow-array/src/types.rs create mode 100644 arrow-avro/Cargo.toml create mode 100644 arrow-avro/src/codec.rs create mode 100644 arrow-avro/src/compression.rs create mode 100644 arrow-avro/src/lib.rs create mode 100644 arrow-avro/src/reader/block.rs create mode 100644 arrow-avro/src/reader/header.rs create mode 100644 arrow-avro/src/reader/mod.rs create mode 100644 arrow-avro/src/reader/vlq.rs create mode 100644 arrow-avro/src/schema.rs create mode 100644 arrow-buffer/Cargo.toml create mode 100644 arrow-buffer/benches/bit_mask.rs create mode 100644 arrow-buffer/benches/i256.rs create mode 100644 arrow-buffer/benches/offset.rs rename {arrow => arrow-buffer}/src/alloc/alignment.rs (91%) create mode 100644 arrow-buffer/src/alloc/mod.rs create mode 100644 arrow-buffer/src/arith.rs create mode 100644 arrow-buffer/src/bigint/div.rs create mode 100644 arrow-buffer/src/bigint/mod.rs create mode 100644 arrow-buffer/src/buffer/boolean.rs rename {arrow => arrow-buffer}/src/buffer/immutable.rs (56%) create mode 100644 arrow-buffer/src/buffer/mod.rs rename {arrow => arrow-buffer}/src/buffer/mutable.rs (70%) create mode 100644 arrow-buffer/src/buffer/null.rs create mode 100644 arrow-buffer/src/buffer/offset.rs rename {arrow => arrow-buffer}/src/buffer/ops.rs (71%) create mode 100644 arrow-buffer/src/buffer/run.rs create mode 100644 arrow-buffer/src/buffer/scalar.rs rename arrow/src/array/builder/boolean_buffer_builder.rs => arrow-buffer/src/builder/boolean.rs (65%) rename arrow/src/array/builder/buffer_builder.rs => arrow-buffer/src/builder/mod.rs (58%) rename arrow/src/array/builder/null_buffer_builder.rs => 
arrow-buffer/src/builder/null.rs (74%) create mode 100644 arrow-buffer/src/builder/offset.rs rename {arrow => arrow-buffer}/src/bytes.rs (71%) create mode 100644 arrow-buffer/src/interval.rs create mode 100644 arrow-buffer/src/lib.rs create mode 100644 arrow-buffer/src/native.rs rename {arrow => arrow-buffer}/src/util/bit_chunk_iterator.rs (92%) rename {arrow => arrow-buffer}/src/util/bit_iterator.rs (53%) create mode 100644 arrow-buffer/src/util/bit_mask.rs rename {arrow => arrow-buffer}/src/util/bit_util.rs (77%) create mode 100644 arrow-buffer/src/util/mod.rs create mode 100644 arrow-cast/Cargo.toml create mode 100644 arrow-cast/benches/parse_date.rs create mode 100644 arrow-cast/benches/parse_decimal.rs create mode 100644 arrow-cast/benches/parse_time.rs rename arrow/src/util/serialization.rs => arrow-cast/benches/parse_timestamp.rs (51%) create mode 100644 arrow-cast/src/base64.rs create mode 100644 arrow-cast/src/cast/decimal.rs create mode 100644 arrow-cast/src/cast/dictionary.rs create mode 100644 arrow-cast/src/cast/list.rs create mode 100644 arrow-cast/src/cast/map.rs create mode 100644 arrow-cast/src/cast/mod.rs create mode 100644 arrow-cast/src/cast/string.rs create mode 100644 arrow-cast/src/display.rs create mode 100644 arrow-cast/src/lib.rs create mode 100644 arrow-cast/src/parse.rs rename {arrow/src/util => arrow-cast/src}/pretty.rs (50%) create mode 100644 arrow-csv/Cargo.toml create mode 100644 arrow-csv/examples/README.md create mode 100644 arrow-csv/examples/csv_calculation.rs create mode 100644 arrow-csv/src/lib.rs create mode 100644 arrow-csv/src/reader/mod.rs create mode 100644 arrow-csv/src/reader/records.rs create mode 100644 arrow-csv/src/writer.rs create mode 100644 arrow-csv/test/data/custom_null_test.csv rename {arrow => arrow-csv}/test/data/decimal_test.csv (100%) create mode 100644 arrow-csv/test/data/example.csv create mode 100644 arrow-csv/test/data/init_null_test.csv rename {arrow => arrow-csv}/test/data/null_test.csv (100%) create 
mode 100644 arrow-csv/test/data/scientific_notation_test.csv create mode 100644 arrow-csv/test/data/truncated_rows.csv rename {arrow => arrow-csv}/test/data/uk_cities.csv (100%) rename {arrow => arrow-csv}/test/data/uk_cities_with_headers.csv (100%) rename {arrow => arrow-csv}/test/data/various_types.csv (100%) rename {arrow => arrow-csv}/test/data/various_types_invalid.csv (100%) create mode 100644 arrow-data/Cargo.toml create mode 100644 arrow-data/src/byte_view.rs create mode 100644 arrow-data/src/data.rs create mode 100644 arrow-data/src/decimal.rs rename {arrow/src/array => arrow-data/src}/equal/boolean.rs (65%) create mode 100644 arrow-data/src/equal/byte_view.rs rename {arrow/src/array => arrow-data/src}/equal/dictionary.rs (75%) create mode 100644 arrow-data/src/equal/fixed_binary.rs rename {arrow/src/array => arrow-data/src}/equal/fixed_list.rs (75%) rename {arrow/src/array => arrow-data/src}/equal/list.rs (69%) create mode 100644 arrow-data/src/equal/mod.rs rename {arrow/src/array => arrow-data/src}/equal/null.rs (97%) create mode 100644 arrow-data/src/equal/primitive.rs create mode 100644 arrow-data/src/equal/run.rs rename {arrow/src/array => arrow-data/src}/equal/structure.rs (74%) rename {arrow/src/array => arrow-data/src}/equal/union.rs (79%) rename {arrow/src/array => arrow-data/src}/equal/utils.rs (69%) rename {arrow/src/array => arrow-data/src}/equal/variable_size.rs (62%) create mode 100644 arrow-data/src/ffi.rs create mode 100644 arrow-data/src/lib.rs rename {arrow/src/array => arrow-data/src}/transform/boolean.rs (95%) rename {arrow/src/array => arrow-data/src}/transform/fixed_binary.rs (54%) rename {arrow/src/array => arrow-data/src}/transform/fixed_size_list.rs (53%) create mode 100644 arrow-data/src/transform/list.rs create mode 100644 arrow-data/src/transform/mod.rs rename {arrow/src/array => arrow-data/src}/transform/null.rs (97%) rename {arrow/src/array => arrow-data/src}/transform/primitive.rs (91%) create mode 100644 
arrow-data/src/transform/structure.rs rename {arrow/src/array => arrow-data/src}/transform/union.rs (82%) rename {arrow/src/array => arrow-data/src}/transform/utils.rs (66%) create mode 100644 arrow-data/src/transform/variable_size.rs create mode 100644 arrow-flight/CONTRIBUTING.md delete mode 100644 arrow-flight/build.rs create mode 100644 arrow-flight/examples/data/ca.pem create mode 100644 arrow-flight/examples/data/client1.key create mode 100644 arrow-flight/examples/data/client1.pem create mode 100644 arrow-flight/examples/data/client_ca.pem create mode 100644 arrow-flight/examples/data/server.key create mode 100644 arrow-flight/examples/data/server.pem create mode 100644 arrow-flight/gen/Cargo.toml create mode 100644 arrow-flight/gen/src/main.rs rename conbench/benchmarks.py => arrow-flight/regen.sh (60%) mode change 100644 => 100755 create mode 100644 arrow-flight/src/bin/flight_sql_client.rs create mode 100644 arrow-flight/src/client.rs create mode 100644 arrow-flight/src/decode.rs create mode 100644 arrow-flight/src/encode.rs create mode 100644 arrow-flight/src/error.rs create mode 100644 arrow-flight/src/sql/client.rs create mode 100644 arrow-flight/src/sql/metadata/catalogs.rs create mode 100644 arrow-flight/src/sql/metadata/db_schemas.rs create mode 100644 arrow-flight/src/sql/metadata/mod.rs create mode 100644 arrow-flight/src/sql/metadata/sql_info.rs create mode 100644 arrow-flight/src/sql/metadata/table_types.rs create mode 100644 arrow-flight/src/sql/metadata/tables.rs create mode 100644 arrow-flight/src/sql/metadata/xdbc_info.rs create mode 100644 arrow-flight/src/streams.rs create mode 100644 arrow-flight/src/trailers.rs create mode 100644 arrow-flight/tests/client.rs create mode 100644 arrow-flight/tests/common/fixture.rs create mode 100644 arrow-flight/tests/common/mod.rs create mode 100644 arrow-flight/tests/common/server.rs create mode 100644 arrow-flight/tests/common/trailers_layer.rs create mode 100644 arrow-flight/tests/common/utils.rs 
create mode 100644 arrow-flight/tests/encode_decode.rs create mode 100644 arrow-flight/tests/flight_sql_client.rs create mode 100644 arrow-flight/tests/flight_sql_client_cli.rs create mode 100644 arrow-integration-test/Cargo.toml rename {integration-testing => arrow-integration-test}/data/integration.json (100%) create mode 100644 arrow-integration-test/src/datatype.rs create mode 100644 arrow-integration-test/src/field.rs rename integration-testing/src/util.rs => arrow-integration-test/src/lib.rs (73%) create mode 100644 arrow-integration-test/src/schema.rs rename {integration-testing => arrow-integration-testing}/Cargo.toml (69%) rename {integration-testing => arrow-integration-testing}/README.md (99%) rename {integration-testing => arrow-integration-testing}/src/bin/arrow-file-to-stream.rs (97%) rename {integration-testing => arrow-integration-testing}/src/bin/arrow-json-integration-test.rs (66%) rename {integration-testing => arrow-integration-testing}/src/bin/arrow-stream-to-file.rs (100%) rename {integration-testing => arrow-integration-testing}/src/bin/flight-test-integration-client.rs (95%) rename {integration-testing => arrow-integration-testing}/src/bin/flight-test-integration-server.rs (96%) rename {integration-testing => arrow-integration-testing}/src/flight_client_scenarios.rs (100%) rename {integration-testing => arrow-integration-testing}/src/flight_client_scenarios/auth_basic_proto.rs (84%) rename {integration-testing => arrow-integration-testing}/src/flight_client_scenarios/integration_test.rs (80%) rename {integration-testing => arrow-integration-testing}/src/flight_client_scenarios/middleware.rs (90%) rename {integration-testing => arrow-integration-testing}/src/flight_server_scenarios.rs (89%) rename {integration-testing => arrow-integration-testing}/src/flight_server_scenarios/auth_basic_proto.rs (88%) rename {integration-testing => arrow-integration-testing}/src/flight_server_scenarios/integration_test.rs (81%) rename {integration-testing => 
arrow-integration-testing}/src/flight_server_scenarios/middleware.rs (92%) create mode 100644 arrow-integration-testing/src/lib.rs create mode 100644 arrow-integration-testing/tests/ipc_reader.rs create mode 100644 arrow-integration-testing/tests/ipc_writer.rs create mode 100644 arrow-ipc/CONTRIBUTING.md create mode 100644 arrow-ipc/Cargo.toml rename {arrow => arrow-ipc}/regen.sh (83%) rename arrow/src/ipc/compression/codec.rs => arrow-ipc/src/compression.rs (54%) create mode 100644 arrow-ipc/src/convert.rs rename {arrow/src/ipc => arrow-ipc/src}/gen/File.rs (69%) rename {arrow/src/ipc => arrow-ipc/src}/gen/Message.rs (66%) rename {arrow/src/ipc => arrow-ipc/src}/gen/Schema.rs (62%) rename {arrow/src/ipc => arrow-ipc/src}/gen/SparseTensor.rs (69%) rename {arrow/src/ipc => arrow-ipc/src}/gen/Tensor.rs (64%) rename {arrow/src/ipc => arrow-ipc/src}/gen/mod.rs (100%) rename arrow/src/ipc/mod.rs => arrow-ipc/src/lib.rs (89%) create mode 100644 arrow-ipc/src/reader.rs create mode 100644 arrow-ipc/src/reader/stream.rs create mode 100644 arrow-ipc/src/writer.rs create mode 100644 arrow-json/Cargo.toml create mode 100644 arrow-json/benches/serde.rs create mode 100644 arrow-json/src/lib.rs create mode 100644 arrow-json/src/reader/boolean_array.rs create mode 100644 arrow-json/src/reader/decimal_array.rs create mode 100644 arrow-json/src/reader/list_array.rs create mode 100644 arrow-json/src/reader/map_array.rs create mode 100644 arrow-json/src/reader/mod.rs create mode 100644 arrow-json/src/reader/null_array.rs create mode 100644 arrow-json/src/reader/primitive_array.rs create mode 100644 arrow-json/src/reader/schema.rs create mode 100644 arrow-json/src/reader/serializer.rs create mode 100644 arrow-json/src/reader/string_array.rs create mode 100644 arrow-json/src/reader/struct_array.rs create mode 100644 arrow-json/src/reader/tape.rs create mode 100644 arrow-json/src/reader/timestamp_array.rs create mode 100644 arrow-json/src/writer.rs create mode 100644 
arrow-json/src/writer/encoder.rs rename {arrow => arrow-json}/test/data/arrays.json (100%) create mode 100644 arrow-json/test/data/basic.json rename {arrow => arrow-json}/test/data/basic_nulls.json (100%) rename {arrow => arrow-json}/test/data/list_string_dict_nested.json (100%) rename {arrow => arrow-json}/test/data/list_string_dict_nested_nulls.json (100%) rename {arrow => arrow-json}/test/data/mixed_arrays.json (100%) rename {arrow => arrow-json}/test/data/mixed_arrays.json.gz (100%) rename {arrow => arrow-json}/test/data/nested_structs.json (100%) create mode 100644 arrow-json/test/data/nested_with_nulls.json create mode 100644 arrow-ord/Cargo.toml create mode 100644 arrow-ord/src/cmp.rs create mode 100644 arrow-ord/src/comparison.rs create mode 100644 arrow-ord/src/lib.rs create mode 100644 arrow-ord/src/ord.rs create mode 100644 arrow-ord/src/partition.rs create mode 100644 arrow-ord/src/rank.rs rename {arrow/src/compute/kernels => arrow-ord/src}/sort.rs (50%) create mode 100644 arrow-row/Cargo.toml create mode 100644 arrow-row/src/fixed.rs create mode 100644 arrow-row/src/lib.rs create mode 100644 arrow-row/src/list.rs create mode 100644 arrow-row/src/variable.rs create mode 100644 arrow-schema/Cargo.toml create mode 100644 arrow-schema/src/datatype.rs create mode 100644 arrow-schema/src/datatype_parse.rs create mode 100644 arrow-schema/src/error.rs create mode 100644 arrow-schema/src/ffi.rs create mode 100644 arrow-schema/src/field.rs create mode 100644 arrow-schema/src/fields.rs create mode 100644 arrow-schema/src/lib.rs create mode 100644 arrow-schema/src/schema.rs create mode 100644 arrow-select/Cargo.toml create mode 100644 arrow-select/src/concat.rs create mode 100644 arrow-select/src/dictionary.rs rename {arrow/src/compute/kernels => arrow-select/src}/filter.rs (65%) create mode 100644 arrow-select/src/interleave.rs create mode 100644 arrow-select/src/lib.rs create mode 100644 arrow-select/src/nullif.rs rename {arrow/src/compute/kernels => 
arrow-select/src}/take.rs (54%) create mode 100644 arrow-select/src/union_extract.rs rename {arrow/src/compute/kernels => arrow-select/src}/window.rs (93%) create mode 100644 arrow-select/src/zip.rs create mode 100644 arrow-string/Cargo.toml rename {arrow/src/compute/kernels => arrow-string/src}/concat_elements.rs (55%) rename {arrow/src/compute/kernels => arrow-string/src}/length.rs (55%) rename arrow/src/compute/kernels/mod.rs => arrow-string/src/lib.rs (74%) create mode 100644 arrow-string/src/like.rs create mode 100644 arrow-string/src/predicate.rs create mode 100644 arrow-string/src/regexp.rs rename {arrow/src/compute/kernels => arrow-string/src}/substring.rs (80%) create mode 100644 arrow/benches/bitwise_kernel.rs create mode 100644 arrow/benches/csv_reader.rs create mode 100644 arrow/benches/interleave_kernels.rs create mode 100644 arrow/benches/json_writer.rs create mode 100644 arrow/benches/lexsort.rs create mode 100644 arrow/benches/primitive_run_accessor.rs create mode 100644 arrow/benches/primitive_run_take.rs create mode 100644 arrow/benches/regexp_kernels.rs create mode 100644 arrow/benches/row_format.rs create mode 100644 arrow/benches/string_run_builder.rs create mode 100644 arrow/benches/string_run_iterator.rs create mode 100644 arrow/examples/collect.rs rename parquet/build.rs => arrow/examples/version.rs (73%) delete mode 100644 arrow/src/alloc/mod.rs delete mode 100644 arrow/src/alloc/types.rs delete mode 100644 arrow/src/array/array.rs delete mode 100644 arrow/src/array/array_boolean.rs delete mode 100644 arrow/src/array/array_decimal.rs delete mode 100644 arrow/src/array/array_dictionary.rs delete mode 100644 arrow/src/array/array_fixed_size_list.rs delete mode 100644 arrow/src/array/array_list.rs delete mode 100644 arrow/src/array/array_map.rs delete mode 100644 arrow/src/array/array_primitive.rs delete mode 100644 arrow/src/array/array_string.rs delete mode 100644 arrow/src/array/array_struct.rs delete mode 100644 
arrow/src/array/builder/decimal_builder.rs delete mode 100644 arrow/src/array/builder/fixed_size_list_builder.rs delete mode 100644 arrow/src/array/builder/generic_binary_builder.rs delete mode 100644 arrow/src/array/builder/generic_list_builder.rs delete mode 100644 arrow/src/array/builder/generic_string_builder.rs delete mode 100644 arrow/src/array/builder/map_builder.rs delete mode 100644 arrow/src/array/builder/mod.rs delete mode 100644 arrow/src/array/builder/primitive_dictionary_builder.rs delete mode 100644 arrow/src/array/builder/string_dictionary_builder.rs delete mode 100644 arrow/src/array/builder/struct_builder.rs delete mode 100644 arrow/src/array/cast.rs delete mode 100644 arrow/src/array/data.rs delete mode 100644 arrow/src/array/equal/decimal.rs delete mode 100644 arrow/src/array/equal/fixed_binary.rs delete mode 100644 arrow/src/array/equal/mod.rs delete mode 100644 arrow/src/array/equal/primitive.rs delete mode 100644 arrow/src/array/ffi.rs delete mode 100644 arrow/src/array/null.rs delete mode 100644 arrow/src/array/ord.rs delete mode 100644 arrow/src/array/raw_pointer.rs delete mode 100644 arrow/src/array/transform/list.rs delete mode 100644 arrow/src/array/transform/mod.rs delete mode 100644 arrow/src/array/transform/structure.rs delete mode 100644 arrow/src/array/transform/variable_size.rs delete mode 100644 arrow/src/bitmap.rs delete mode 100644 arrow/src/buffer/mod.rs delete mode 100644 arrow/src/buffer/scalar.rs create mode 100644 arrow/src/compute/kernels.rs delete mode 100644 arrow/src/compute/kernels/aggregate.rs delete mode 100644 arrow/src/compute/kernels/arithmetic.rs delete mode 100644 arrow/src/compute/kernels/arity.rs delete mode 100644 arrow/src/compute/kernels/cast.rs delete mode 100644 arrow/src/compute/kernels/cast_utils.rs delete mode 100644 arrow/src/compute/kernels/comparison.rs delete mode 100644 arrow/src/compute/kernels/concat.rs delete mode 100644 arrow/src/compute/kernels/limit.rs delete mode 100644 
arrow/src/compute/kernels/partition.rs delete mode 100644 arrow/src/compute/kernels/regexp.rs delete mode 100644 arrow/src/compute/kernels/temporal.rs delete mode 100644 arrow/src/compute/kernels/zip.rs delete mode 100644 arrow/src/compute/util.rs delete mode 100644 arrow/src/csv/reader.rs delete mode 100644 arrow/src/csv/writer.rs delete mode 100644 arrow/src/datatypes/datatype.rs delete mode 100644 arrow/src/datatypes/delta.rs delete mode 100644 arrow/src/datatypes/ffi.rs delete mode 100644 arrow/src/datatypes/field.rs delete mode 100644 arrow/src/datatypes/native.rs delete mode 100644 arrow/src/datatypes/numeric.rs delete mode 100644 arrow/src/datatypes/schema.rs delete mode 100644 arrow/src/datatypes/types.rs delete mode 100644 arrow/src/ffi.rs delete mode 100644 arrow/src/ipc/compression/stub.rs delete mode 100644 arrow/src/ipc/convert.rs delete mode 100644 arrow/src/ipc/reader.rs delete mode 100644 arrow/src/ipc/writer.rs delete mode 100644 arrow/src/json/mod.rs delete mode 100644 arrow/src/json/reader.rs delete mode 100644 arrow/src/json/writer.rs delete mode 100644 arrow/src/temporal_conversions.rs delete mode 100644 arrow/src/util/bit_mask.rs delete mode 100644 arrow/src/util/decimal.rs delete mode 100644 arrow/src/util/display.rs delete mode 100644 arrow/src/util/reader_parser.rs delete mode 100644 arrow/test/data/basic.json create mode 100644 arrow/tests/arithmetic.rs create mode 100644 arrow/tests/array_cast.rs create mode 100644 arrow/tests/array_equal.rs create mode 100644 arrow/tests/array_transform.rs create mode 100644 arrow/tests/array_validation.rs create mode 100644 arrow/tests/csv.rs create mode 100644 arrow/tests/pyarrow.rs create mode 100644 arrow/tests/timezone.rs delete mode 100644 conbench/.flake8 delete mode 100755 conbench/.gitignore delete mode 100644 conbench/.isort.cfg delete mode 100644 conbench/README.md delete mode 100644 conbench/_criterion.py delete mode 100644 conbench/benchmarks.json delete mode 100644 
conbench/requirements-test.txt delete mode 100644 conbench/requirements.txt create mode 100644 dev/release/file_release_pr.sh create mode 100755 dev/release/label_issues.py delete mode 100644 integration-testing/src/lib.rs delete mode 100644 integration-testing/tests/ipc_reader.rs delete mode 100644 integration-testing/tests/ipc_writer.rs create mode 100644 object_store/CHANGELOG-old.md create mode 100644 object_store/LICENSE.txt create mode 100644 object_store/NOTICE.txt create mode 100755 object_store/dev/release/remove-old-releases.sh create mode 100644 object_store/src/attributes.rs create mode 100644 object_store/src/aws/builder.rs create mode 100644 object_store/src/aws/checksum.rs create mode 100644 object_store/src/aws/dynamo.rs create mode 100644 object_store/src/aws/precondition.rs create mode 100644 object_store/src/aws/resolve.rs create mode 100644 object_store/src/azure/builder.rs create mode 100644 object_store/src/buffered.rs create mode 100644 object_store/src/chunked.rs create mode 100644 object_store/src/client/get.rs create mode 100644 object_store/src/client/header.rs create mode 100644 object_store/src/client/list.rs create mode 100644 object_store/src/client/parts.rs create mode 100644 object_store/src/client/s3.rs create mode 100644 object_store/src/config.rs create mode 100644 object_store/src/delimited.rs create mode 100644 object_store/src/gcp/builder.rs create mode 100644 object_store/src/gcp/client.rs create mode 100644 object_store/src/http/client.rs create mode 100644 object_store/src/http/mod.rs create mode 100644 object_store/src/integration.rs create mode 100644 object_store/src/parse.rs create mode 100644 object_store/src/payload.rs create mode 100644 object_store/src/prefix.rs create mode 100644 object_store/src/signer.rs create mode 100644 object_store/src/tags.rs create mode 100644 object_store/src/upload.rs create mode 100644 object_store/tests/get_range_file.rs create mode 100644 parquet/benches/arrow_statistics.rs create mode 
100644 parquet/benches/compression.rs create mode 100644 parquet/benches/encoding.rs create mode 100644 parquet/benches/metadata.rs create mode 100644 parquet/examples/async_read_parquet.rs create mode 100644 parquet/examples/read_parquet.rs create mode 100644 parquet/examples/read_with_rowgroup.rs create mode 100644 parquet/examples/write_parquet.rs create mode 100644 parquet/pytest/requirements.in create mode 100644 parquet/pytest/requirements.txt create mode 100755 parquet/pytest/test_parquet_integration.py create mode 100755 parquet/regen.sh create mode 100644 parquet/src/arrow/array_reader/byte_view_array.rs create mode 100644 parquet/src/arrow/array_reader/fixed_size_list_array.rs create mode 100644 parquet/src/arrow/arrow_reader/statistics.rs delete mode 100644 parquet/src/arrow/async_reader.rs create mode 100644 parquet/src/arrow/async_reader/metadata.rs create mode 100644 parquet/src/arrow/async_reader/mod.rs create mode 100644 parquet/src/arrow/async_reader/store.rs create mode 100644 parquet/src/arrow/async_writer/mod.rs create mode 100644 parquet/src/arrow/async_writer/store.rs create mode 100644 parquet/src/arrow/buffer/view_buffer.rs rename parquet/src/arrow/{schema.rs => schema/mod.rs} (65%) create mode 100644 parquet/src/bin/parquet-concat.rs create mode 100644 parquet/src/bin/parquet-index.rs create mode 100644 parquet/src/bin/parquet-layout.rs create mode 100644 parquet/src/bin/parquet-rewrite.rs create mode 100644 parquet/src/bin/parquet-show-bloom-filter.rs create mode 100644 parquet/src/bloom_filter/mod.rs create mode 100644 parquet/src/encodings/decoding/byte_stream_split_decoder.rs create mode 100644 parquet/src/encodings/encoding/byte_stream_split_encoder.rs delete mode 100644 parquet/src/file/metadata.rs create mode 100644 parquet/src/file/metadata/memory.rs create mode 100644 parquet/src/file/metadata/mod.rs create mode 100644 parquet/src/file/metadata/reader.rs create mode 100644 parquet/src/file/metadata/writer.rs create mode 100644 
parquet/src/file/page_index/offset_index.rs delete mode 100644 parquet/src/file/page_index/range.rs create mode 100644 parquet/src/format.rs create mode 100644 parquet/src/record/record_reader.rs create mode 100644 parquet/src/thrift.rs delete mode 100644 parquet/src/util/io.rs delete mode 100644 parquet/src/util/memory.rs create mode 100644 parquet/tests/arrow_reader/bad_data.rs create mode 100644 parquet/tests/arrow_reader/bad_raw_metadata.bin create mode 100644 parquet/tests/arrow_reader/mod.rs create mode 100644 parquet/tests/arrow_reader/statistics.rs create mode 100644 parquet/tests/arrow_writer_layout.rs delete mode 100644 parquet/tests/boolean_writer.rs diff --git a/.asf.yaml b/.asf.yaml index 968c6779215a..9541db89daf8 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -38,4 +38,10 @@ github: # require branches to be up-to-date before merging strict: true # don't require any jobs to pass - contexts: [] \ No newline at end of file + contexts: [] + +# publishes the content of the `asf-site` branch to +# https://arrow.apache.org/rust/ +publish: + whoami: asf-site + subdir: rust diff --git a/.gitattributes b/.gitattributes index fac7bf85a77f..b7b0d51ff478 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,6 +1,3 @@ -r/R/RcppExports.R linguist-generated=true -r/R/arrowExports.R linguist-generated=true -r/src/RcppExports.cpp linguist-generated=true -r/src/arrowExports.cpp linguist-generated=true -r/man/*.Rd linguist-generated=true - +parquet/src/format.rs linguist-generated +arrow-flight/src/arrow.flight.protocol.rs linguist-generated +arrow-flight/src/sql/arrow.flight.protocol.sql.rs linguist-generated diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 0ef6532da477..aa1d1d9c14da 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -20,8 +20,12 @@ description: 'Prepare Rust Build Environment' inputs: rust-version: description: 'version of rust to install (e.g. 
stable)' - required: true + required: false default: 'stable' + target: + description: 'target architecture(s)' + required: false + default: 'x86_64-unknown-linux-gnu' runs: using: "composite" steps: @@ -51,6 +55,17 @@ runs: shell: bash run: | echo "Installing ${{ inputs.rust-version }}" - rustup toolchain install ${{ inputs.rust-version }} + rustup toolchain install ${{ inputs.rust-version }} --target ${{ inputs.target }} rustup default ${{ inputs.rust-version }} - echo "CARGO_TARGET_DIR=/github/home/target" >> $GITHUB_ENV + - name: Disable debuginfo generation + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + shell: bash + run: echo "RUSTFLAGS=-C debuginfo=1" >> $GITHUB_ENV + - name: Enable backtraces + shell: bash + run: echo "RUST_BACKTRACE=1" >> $GITHUB_ENV + - name: Fixup git permissions + # https://github.com/actions/checkout/issues/766 + shell: bash + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9c4cda5d034d..ffde5378da93 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -6,10 +6,17 @@ updates: interval: daily open-pull-requests-limit: 10 target-branch: master - labels: [auto-dependencies] + labels: [ auto-dependencies, arrow ] + - package-ecosystem: cargo + directory: "/object_store" + schedule: + interval: daily + open-pull-requests-limit: 10 + target-branch: master + labels: [ auto-dependencies, object_store ] - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" open-pull-requests-limit: 10 - labels: [auto-dependencies] + labels: [ auto-dependencies ] diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index d34ee3b49b5c..d3b2526740fa 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -18,6 +18,10 @@ # tests for arrow crate name: arrow +concurrency: + group: ${{ 
github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + on: # always trigger push: @@ -25,8 +29,23 @@ on: - master pull_request: paths: - - arrow/** - .github/** + - arrow-arith/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-csv/** + - arrow-data/** + - arrow-integration-test/** + - arrow-ipc/** + - arrow-json/** + - arrow-avro/** + - arrow-ord/** + - arrow-row/** + - arrow-schema/** + - arrow-select/** + - arrow-string/** + - arrow/** jobs: @@ -36,24 +55,46 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Test - run: | - cargo test -p arrow - - name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict - run: | - cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict + - name: Test arrow-buffer with all features + run: cargo test -p arrow-buffer --all-features + - name: Test arrow-data with all features + run: cargo test -p arrow-data --all-features + - name: Test arrow-schema with all features + run: cargo test -p arrow-schema --all-features + - name: Test arrow-array with all features + run: cargo test -p arrow-array --all-features + - name: Test arrow-select with all features + run: cargo test -p arrow-select --all-features + - name: Test arrow-cast with all features + run: cargo test -p arrow-cast --all-features + - name: Test arrow-ipc with all features + run: cargo test -p arrow-ipc --all-features + - name: Test arrow-csv with all features + run: cargo test -p arrow-csv --all-features + - name: Test arrow-json with all features + 
run: cargo test -p arrow-json --all-features + - name: Test arrow-avro with all features + run: cargo test -p arrow-avro --all-features + - name: Test arrow-string with all features + run: cargo test -p arrow-string --all-features + - name: Test arrow-ord with all features + run: cargo test -p arrow-ord --all-features + - name: Test arrow-arith with all features + run: cargo test -p arrow-arith --all-features + - name: Test arrow-row with all features + run: cargo test -p arrow-row --all-features + - name: Test arrow-integration-test with all features + run: cargo test -p arrow-integration-test --all-features + - name: Test arrow with default features + run: cargo test -p arrow + - name: Test arrow with all features except pyarrow + run: cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,chrono-tz - name: Run examples run: | # Test arrow examples @@ -64,99 +105,52 @@ jobs: - name: Run non-archery based integration-tests run: cargo test -p arrow-integration-testing - # test compilaton features + # test compilation features linux-features: name: Check Compilation runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Check compilation - run: | - cargo check -p arrow + run: cargo check -p arrow - name: Check compilation --no-default-features - run: | - cargo check -p arrow --no-default-features + run: cargo check -p arrow --no-default-features - name: Check compilation --all-targets - run: | - cargo check -p arrow --all-targets + run: cargo check -p arrow --all-targets - name: Check compilation --no-default-features --all-targets - run: | - cargo check -p arrow --no-default-features --all-targets + run: cargo check -p arrow --no-default-features --all-targets - name: Check compilation --no-default-features --all-targets --features test_utils - run: | - cargo check -p arrow --no-default-features --all-targets --features test_utils - - # test the --features "simd" of the arrow crate. This requires nightly Rust. - linux-test-simd: - name: Test SIMD on AMD64 Rust ${{ matrix.rust }} - runs-on: ubuntu-latest - container: - image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" - steps: - - uses: actions/checkout@v3 - with: - submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: nightly - - name: Run tests --features "simd" - run: | - cargo test -p arrow --features "simd" - - name: Check compilation --features "simd" - run: | - cargo check -p arrow --features simd - - name: Check compilation --features simd --all-targets - run: | - cargo check -p arrow --features simd --all-targets + run: cargo check -p arrow --no-default-features --all-targets --features test_utils + - name: Check compilation --no-default-features --all-targets --features ffi + run: cargo check -p arrow --no-default-features --all-targets --features ffi + - name: Check compilation --no-default-features --all-targets --features chrono-tz + run: cargo check -p arrow --no-default-features --all-targets --features chrono-tz - # test the arrow crate builds against wasm32 in stable rust + # test the arrow crate builds against wasm32 in nightly rust wasm32-build: name: Build wasm32 runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder with: - path: /github/home/.cargo - key: cargo-wasm32-cache3- - - name: Setup Rust toolchain for WASM - run: | - rustup toolchain install nightly - rustup override set nightly - rustup target add wasm32-unknown-unknown - rustup target add wasm32-wasi - - name: Build - run: | - cd arrow - cargo build --no-default-features --features=json,csv,ipc,simd,ffi --target wasm32-unknown-unknown - cargo build --no-default-features --features=json,csv,ipc,simd,ffi --target wasm32-wasi + target: wasm32-unknown-unknown,wasm32-wasi + - name: Build wasm32-unknown-unknown + run: cargo build -p arrow --no-default-features --features=json,csv,ipc,ffi --target wasm32-unknown-unknown + - name: Build wasm32-wasi + run: cargo build -p arrow --no-default-features --features=json,csv,ipc,ffi --target wasm32-wasi clippy: name: Clippy @@ -164,14 +158,42 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy - - name: Run clippy - run: | - cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings + run: rustup component add clippy + - name: Clippy arrow-buffer with all features + run: cargo clippy -p arrow-buffer --all-targets --all-features -- -D warnings + - name: Clippy arrow-data with all features + run: cargo clippy -p arrow-data --all-targets --all-features -- -D warnings + - name: Clippy arrow-schema with all features + run: cargo clippy -p arrow-schema --all-targets --all-features -- -D warnings + - name: Clippy arrow-array with all features + run: cargo clippy -p arrow-array 
--all-targets --all-features -- -D warnings + - name: Clippy arrow-select with all features + run: cargo clippy -p arrow-select --all-targets --all-features -- -D warnings + - name: Clippy arrow-cast with all features + run: cargo clippy -p arrow-cast --all-targets --all-features -- -D warnings + - name: Clippy arrow-ipc with all features + run: cargo clippy -p arrow-ipc --all-targets --all-features -- -D warnings + - name: Clippy arrow-csv with all features + run: cargo clippy -p arrow-csv --all-targets --all-features -- -D warnings + - name: Clippy arrow-json with all features + run: cargo clippy -p arrow-json --all-targets --all-features -- -D warnings + - name: Clippy arrow-avro with all features + run: cargo clippy -p arrow-avro --all-targets --all-features -- -D warnings + - name: Clippy arrow-string with all features + run: cargo clippy -p arrow-string --all-targets --all-features -- -D warnings + - name: Clippy arrow-ord with all features + run: cargo clippy -p arrow-ord --all-targets --all-features -- -D warnings + - name: Clippy arrow-arith with all features + run: cargo clippy -p arrow-arith --all-targets --all-features -- -D warnings + - name: Clippy arrow-row with all features + run: cargo clippy -p arrow-row --all-targets --all-features -- -D warnings + - name: Clippy arrow with all features + run: cargo clippy -p arrow --all-features --all-targets -- -D warnings + - name: Clippy arrow-integration-test with all features + run: cargo clippy -p arrow-integration-test --all-targets --all-features -- -D warnings + - name: Clippy arrow-integration-testing with all features + run: cargo clippy -p arrow-integration-testing --all-targets --all-features -- -D warnings diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml index 86a67ff9a6a4..242e0f2a3b0d 100644 --- a/.github/workflows/arrow_flight.yml +++ b/.github/workflows/arrow_flight.yml @@ -19,6 +19,9 @@ # tests for arrow_flight crate name: arrow_flight +concurrency: + group: 
${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true # trigger for all PRs that touch certain files and changes to master on: @@ -27,35 +30,51 @@ on: - master pull_request: paths: - - arrow/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-data/** - arrow-flight/** + - arrow-ipc/** + - arrow-schema/** + - arrow-select/** - .github/** jobs: - # test the crate linux-test: name: Test runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Test run: | cargo test -p arrow-flight - name: Test --all-features run: | cargo test -p arrow-flight --all-features + - name: Test --examples + run: | + cargo test -p arrow-flight --features=flight-sql-experimental,tls --examples + + vendor: + name: Verify Vendored Code + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Run gen + run: ./arrow-flight/regen.sh + - name: Verify workspace clean (if this fails, run ./arrow-flight/regen.sh and check in results) + run: git diff --exit-code clippy: name: Clippy @@ -63,14 +82,10 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy + run: rustup component add clippy - name: Run clippy - run: | - cargo clippy -p arrow-flight --all-features -- -D warnings + run: cargo clippy -p arrow-flight --all-targets 
--all-features -- -D warnings diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml new file mode 100644 index 000000000000..2c1dcdfd2100 --- /dev/null +++ b/.github/workflows/audit.yml @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: audit + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +# trigger for all PRs that touch certain files and changes to master +on: + push: + branches: + - master + pull_request: + paths: + - '**/Cargo.toml' + - '**/Cargo.lock' + +jobs: + cargo-audit: + name: Audit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install cargo-audit + run: cargo install cargo-audit + - name: Run audit check + run: cargo audit diff --git a/.github/workflows/cancel.yml b/.github/workflows/cancel.yml deleted file mode 100644 index a98c8ee5d225..000000000000 --- a/.github/workflows/cancel.yml +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Attempt to cancel stale workflow runs to save github actions runner time -name: cancel - -on: - workflow_run: - # The name of another workflow (whichever one) that always runs on PRs - workflows: ['Dev'] - types: ['requested'] - -jobs: - cancel-stale-workflow-runs: - name: "Cancel stale workflow runs" - runs-on: ubuntu-latest - steps: - # Unfortunately, we need to define a separate cancellation step for - # each workflow where we want to cancel stale runs. 
- - uses: potiuk/cancel-workflow-runs@master - name: "Cancel stale Dev runs" - with: - cancelMode: allDuplicates - token: ${{ secrets.GITHUB_TOKEN }} - workflowFileName: dev.yml - skipEventTypes: '["push", "schedule"]' - - uses: potiuk/cancel-workflow-runs@master - name: "Cancel stale Integration runs" - with: - cancelMode: allDuplicates - token: ${{ secrets.GITHUB_TOKEN }} - workflowFileName: integration.yml - skipEventTypes: '["push", "schedule"]' - - uses: potiuk/cancel-workflow-runs@master - name: "Cancel stale Rust runs" - with: - cancelMode: allDuplicates - token: ${{ secrets.GITHUB_TOKEN }} - workflowFileName: rust.yml - skipEventTypes: '["push", "schedule"]' diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index e688428e187c..37d697dc3440 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -17,6 +17,10 @@ name: coverage +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # Trigger only on pushes to master, not pull requests on: push: @@ -32,7 +36,7 @@ jobs: # otherwise we get this error: # Failed to run tests: ASLR disable failed: EPERM: Operation not permitted steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -47,7 +51,7 @@ jobs: curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/protoc-21.4-linux-x86_64.zip unzip protoc-21.4-linux-x86_64.zip - name: Cache Cargo - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: /home/runner/.cargo key: cargo-coverage-cache3- diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 57dc19482761..2026e257ab29 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -17,6 +17,10 @@ name: dev +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all 
PRs and changes to master on: push: @@ -34,9 +38,9 @@ jobs: name: Release Audit Tool (RAT) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.8 - name: Audit licenses @@ -46,12 +50,12 @@ jobs: name: Markdown format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-node@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 with: node-version: "14" - name: Prettier check run: | # if you encounter error, run the command below and commit the changes - npx prettier@2.3.2 --write {arrow,arrow-flight,dev,integration-testing,parquet}/**/*.md README.md CODE_OF_CONDUCT.md CONTRIBUTING.md + npx prettier@2.3.2 --write {arrow,arrow-flight,dev,arrow-integration-testing,parquet}/**/*.md README.md CODE_OF_CONDUCT.md CONTRIBUTING.md git diff --exit-code diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 38bb39390097..0d60ae006796 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -17,6 +17,10 @@ name: dev_pr +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # Trigger whenever a PR is changed (title as well as new / changed commits) on: pull_request_target: @@ -29,15 +33,18 @@ jobs: process: name: Process runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v4.0.1 + uses: actions/labeler@v5.0.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml 
index aadf9c377c64..cae015018eac 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -16,16 +16,40 @@ # under the License. arrow: - - arrow/**/* + - changed-files: + - any-glob-to-any-file: + - 'arrow-arith/**/*' + - 'arrow-array/**/*' + - 'arrow-buffer/**/*' + - 'arrow-cast/**/*' + - 'arrow-csv/**/*' + - 'arrow-data/**/*' + - 'arrow-flight/**/*' + - 'arrow-integration-test/**/*' + - 'arrow-integration-testing/**/*' + - 'arrow-ipc/**/*' + - 'arrow-json/**/*' + - 'arrow-avro/**/*' + - 'arrow-ord/**/*' + - 'arrow-row/**/*' + - 'arrow-schema/**/*' + - 'arrow-select/**/*' + - 'arrow-string/**/*' + - 'arrow/**/*' arrow-flight: - - arrow-flight/**/* + - changed-files: + - any-glob-to-any-file: + - 'arrow-flight/**/*' parquet: - - parquet/**/* + - changed-files: + - any-glob-to-any-file: [ 'parquet/**/*' ] parquet-derive: - - parquet_derive/**/* + - changed-files: + - any-glob-to-any-file: [ 'parquet_derive/**/*' ] object-store: - - object_store/**/* + - changed-files: + - any-glob-to-any-file: [ 'object_store/**/*' ] diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5e82d76febe6..08d287bcceb2 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,6 +17,10 @@ name: docs +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs and changes to master on: push: @@ -37,19 +41,59 @@ jobs: container: image: ${{ matrix.arch }}/rust env: - RUSTDOCFLAGS: "-Dwarnings" + RUSTDOCFLAGS: "-Dwarnings --enable-index-page -Zunstable-options" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install python dev run: | apt update - apt install -y libpython3.9-dev + apt install -y libpython3.11-dev - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: ${{ matrix.rust }} - name: Run cargo doc + run: cargo doc 
--document-private-items --no-deps --workspace --all-features + - name: Fix file permissions + shell: sh + run: | + chmod -c -R +rX "target/doc" | + while read line; do + echo "::warning title=Invalid file permissions automatically fixed::$line" + done + - name: Upload artifacts + uses: actions/upload-pages-artifact@v3 + with: + name: crate-docs + path: target/doc + + deploy: + # Only deploy if a push to master + if: github.ref_name == 'master' && github.event_name == 'push' + needs: docs + permissions: + contents: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Download crate docs + uses: actions/download-artifact@v4 + with: + name: crate-docs + path: website/build + - name: Prepare website run: | - cargo doc --document-private-items --no-deps --workspace --all-features + tar -xf website/build/artifact.tar -C website/build + rm website/build/artifact.tar + cp .asf.yaml ./website/build/.asf.yaml + - name: Deploy to gh-pages + uses: peaceiris/actions-gh-pages@v4.0.0 + if: github.event_name == 'push' && github.ref_name == 'master' + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: website/build + publish_branch: asf-site + # Avoid accumulating history of in progress API jobs: https://github.com/apache/arrow-rs/issues/5908 + force_orphan: true diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 10a8e30212a9..868729a168e8 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -17,6 +17,10 @@ name: integration +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs that touch certain files and changes to master on: push: @@ -24,13 +28,26 @@ on: - master pull_request: paths: - - arrow/** - - arrow-pyarrow-integration-testing/** - - integration-testing/** - .github/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-csv/** + - 
arrow-data/** + - arrow-integration-test/** + - arrow-integration-testing/** + - arrow-ipc/** + - arrow-json/** + - arrow-avro/** + - arrow-ord/** + - arrow-pyarrow-integration-testing/** + - arrow-schema/** + - arrow-select/** + - arrow-sort/** + - arrow-string/** + - arrow/** jobs: - integration: name: Archery test With other arrows runs-on: ubuntu-latest @@ -39,7 +56,19 @@ jobs: env: ARROW_USE_CCACHE: OFF ARROW_CPP_EXE_PATH: /build/cpp/debug + ARROW_NANOARROW_PATH: /build/nanoarrow + ARROW_RUST_EXE_PATH: /build/rust/debug BUILD_DOCS_CPP: OFF + ARROW_INTEGRATION_CPP: ON + ARROW_INTEGRATION_CSHARP: ON + ARROW_INTEGRATION_GO: ON + ARROW_INTEGRATION_JAVA: ON + ARROW_INTEGRATION_JS: ON + ARCHERY_INTEGRATION_TARGET_IMPLEMENTATIONS: "rust" + # Disable nanoarrow integration, due to https://github.com/apache/arrow-rs/issues/5052 + ARCHERY_INTEGRATION_WITH_NANOARROW: "0" + # https://github.com/apache/arrow/pull/38403/files#r1371281630 + ARCHERY_INTEGRATION_WITH_RUST: "1" # These are necessary because the github runner overrides $HOME # https://github.com/actions/runner/issues/863 RUSTUP_HOME: /root/.rustup @@ -59,49 +88,30 @@ jobs: - name: Check cmake run: which cmake - name: Checkout Arrow - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: apache/arrow submodules: true fetch-depth: 0 - name: Checkout Arrow Rust - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: path: rust fetch-depth: 0 - - name: Make build directory - run: mkdir /build - - name: Build Rust - run: conda run --no-capture-output ci/scripts/rust_build.sh $PWD /build - - name: Build C++ - run: conda run --no-capture-output ci/scripts/cpp_build.sh $PWD /build - - name: Build C# - run: conda run --no-capture-output ci/scripts/csharp_build.sh $PWD /build - - name: Build Go - run: conda run --no-capture-output ci/scripts/go_build.sh $PWD - - name: Build Java - run: conda run --no-capture-output ci/scripts/java_build.sh $PWD /build - # Temporarily disable JS 
https://issues.apache.org/jira/browse/ARROW-17410 - # - name: Build JS - # run: conda run --no-capture-output ci/scripts/js_build.sh $PWD /build - - name: Install archery - run: conda run --no-capture-output pip install -e dev/archery - - name: Run integration tests - run: | - conda run --no-capture-output archery integration \ - --run-flight \ - --with-cpp=1 \ - --with-csharp=1 \ - --with-java=1 \ - --with-js=0 \ - --with-go=1 \ - --with-rust=1 \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/0.14.1 \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/0.17.1 \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/1.0.0-bigendian \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/1.0.0-littleendian \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/2.0.0-compression \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/4.0.0-shareddict + - name: Checkout Arrow nanoarrow + uses: actions/checkout@v4 + with: + repository: apache/arrow-nanoarrow + path: nanoarrow + fetch-depth: 0 + # Workaround https://github.com/rust-lang/rust/issues/125067 + - name: Downgrade rust + working-directory: rust + run: rustup override set 1.77 + - name: Build + run: conda run --no-capture-output ci/scripts/integration_arrow_build.sh $PWD /build + - name: Run + run: conda run --no-capture-output ci/scripts/integration_arrow.sh $PWD /build # test FFI against the C-Data interface exposed by pyarrow pyarrow-integration-test: @@ -109,9 +119,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: [ stable ] + rust: [stable] + # PyArrow 15 was the first version to introduce StringView/BinaryView support + pyarrow: ["15", "16", "17"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -120,33 +132,33 @@ jobs: rustup default ${{ matrix.rust }} rustup component add rustfmt clippy - name: Cache Cargo - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: /home/runner/.cargo key: 
cargo-maturin-cache- - name: Cache Rust dependencies - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: /home/runner/target # this key is not equal because maturin uses different compilation flags. key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.8' - name: Upgrade pip and setuptools run: pip install --upgrade pip setuptools wheel virtualenv - name: Create virtualenv and install dependencies run: | virtualenv venv source venv/bin/activate - pip install maturin toml pytest pytz pyarrow>=5.0 + pip install maturin toml pytest pytz pyarrow==${{ matrix.pyarrow }} + - name: Run Rust tests + run: | + source venv/bin/activate + cargo test -p arrow --test pyarrow --features pyarrow - name: Run tests - env: - CARGO_HOME: "/home/runner/.cargo" - CARGO_TARGET_DIR: "/home/runner/target" run: | source venv/bin/activate - pushd arrow-pyarrow-integration-testing + cd arrow-pyarrow-integration-testing maturin develop pytest -v . - popd diff --git a/.github/workflows/miri.sh b/.github/workflows/miri.sh index 56da5c5c5d3e..86be2100ee67 100755 --- a/.github/workflows/miri.sh +++ b/.github/workflows/miri.sh @@ -5,13 +5,16 @@ # Must be run with nightly rust for example # rustup default nightly +set -e -# stacked borrows checking uses too much memory to run successfully in github actions -# re-enable if the CI is migrated to something more powerful (https://github.com/apache/arrow-rs/issues/1833) -# see also https://github.com/rust-lang/miri/issues/1367 -export MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-stacked-borrows" +export MIRIFLAGS="-Zmiri-disable-isolation" cargo miri setup cargo clean echo "Starting Arrow MIRI run..." 
-cargo miri test -p arrow -- --skip csv --skip ipc --skip json +cargo miri test -p arrow-buffer +cargo miri test -p arrow-data --features ffi +cargo miri test -p arrow-schema --features ffi +cargo miri test -p arrow-ord +cargo miri test -p arrow-array +cargo miri test -p arrow-arith \ No newline at end of file diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index b4669bbcccc0..19b432121b6f 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -17,6 +17,10 @@ name: miri +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs that touch certain files and changes to master on: push: @@ -24,15 +28,26 @@ on: - master pull_request: paths: - - arrow/** - .github/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-csv/** + - arrow-data/** + - arrow-ipc/** + - arrow-json/** + - arrow-avro/** + - arrow-schema/** + - arrow-select/** + - arrow-string/** + - arrow/** jobs: miri-checks: name: MIRI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -44,5 +59,4 @@ jobs: env: RUST_BACKTRACE: full RUST_LOG: "trace" - run: | - bash .github/workflows/miri.sh + run: bash .github/workflows/miri.sh diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index 6996aa706636..bdbfc0bec4bb 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -19,6 +19,10 @@ # tests for `object_store` crate name: object_store +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs that touch certain files and changes to master on: push: @@ -35,29 +39,51 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust + defaults: + run: + working-directory: object_store steps: - - uses: 
actions/checkout@v3 - - name: Setup Rust toolchain with clippy - run: | - rustup toolchain install stable - rustup default stable - rustup component add clippy + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Setup Clippy + run: rustup component add clippy # Run different tests for the library on its own as well as # all targets to ensure that it still works in the absence of # features that might be enabled by dev-dependencies of other # targets. - name: Run clippy with default features - run: cargo clippy -p object_store -- -D warnings + run: cargo clippy -- -D warnings - name: Run clippy with aws feature - run: cargo clippy -p object_store --features aws -- -D warnings + run: cargo clippy --features aws -- -D warnings - name: Run clippy with gcp feature - run: cargo clippy -p object_store --features gcp -- -D warnings + run: cargo clippy --features gcp -- -D warnings - name: Run clippy with azure feature - run: cargo clippy -p object_store --features azure -- -D warnings + run: cargo clippy --features azure -- -D warnings + - name: Run clippy with http feature + run: cargo clippy --features http -- -D warnings - name: Run clippy with all features - run: cargo clippy -p object_store --all-features -- -D warnings + run: cargo clippy --all-features -- -D warnings - name: Run clippy with all features and all targets - run: cargo clippy -p object_store --all-features --all-targets -- -D warnings + run: cargo clippy --all-features --all-targets -- -D warnings + + # test doc links still work + # + # Note that since object_store is not part of the main workspace, + # this needs a separate docs job as it is not covered by + # `cargo doc --workspace` + docs: + name: Rustdocs + runs-on: ubuntu-latest + defaults: + run: + working-directory: object_store + env: + RUSTDOCFLAGS: "-Dwarnings" + steps: + - uses: actions/checkout@v4 + - name: Run cargo doc + run: cargo doc --document-private-items --no-deps --workspace 
--all-features # test the crate # This runs outside a container to workaround lack of support for passing arguments @@ -65,47 +91,66 @@ jobs: linux-test: name: Emulator Tests runs-on: ubuntu-latest + defaults: + run: + working-directory: object_store env: # Disable full debug symbol generation to speed up CI build and keep memory down # "1" means line tables only, which is useful for panic tracebacks. RUSTFLAGS: "-C debuginfo=1" - # https://github.com/rust-lang/cargo/issues/10280 - CARGO_NET_GIT_FETCH_WITH_CLI: "true" RUST_BACKTRACE: "1" # Run integration tests TEST_INTEGRATION: 1 EC2_METADATA_ENDPOINT: http://localhost:1338 - AZURE_USE_EMULATOR: "1" + AZURE_CONTAINER_NAME: test-bucket + AZURE_STORAGE_USE_EMULATOR: "1" AZURITE_BLOB_STORAGE_URL: "http://localhost:10000" AZURITE_QUEUE_STORAGE_URL: "http://localhost:10001" + AWS_BUCKET: test-bucket + AWS_DEFAULT_REGION: "us-east-1" + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test + AWS_ENDPOINT: http://localhost:4566 + AWS_ALLOW_HTTP: true + AWS_COPY_IF_NOT_EXISTS: dynamo:test-table:2000 + AWS_CONDITIONAL_PUT: dynamo:test-table:2000 + AWS_SERVER_SIDE_ENCRYPTION: aws:kms + HTTP_URL: "http://localhost:8080" + GOOGLE_BUCKET: test-bucket GOOGLE_SERVICE_ACCOUNT: "/tmp/gcs.json" - OBJECT_STORE_BUCKET: test-bucket steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + # We are forced to use docker commands instead of service containers as we need to override the entrypoints + # which is currently not supported - https://github.com/actions/runner/discussions/1872 - name: Configure Fake GCS Server (GCP emulation) + # Custom image - see fsouza/fake-gcs-server#1164 run: | - docker run -d -p 4443:4443 fsouza/fake-gcs-server -scheme http + echo "GCS_CONTAINER=$(docker run -d -p 4443:4443 tustvold/fake-gcs-server -scheme http -backend memory -public-host localhost:4443)" >> $GITHUB_ENV + # Give the container a moment to start up prior to configuring it + sleep 1 curl -v -X POST --data-binary 
'{"name":"test-bucket"}' -H "Content-Type: application/json" "http://localhost:4443/storage/v1/b" - echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": ""}' > "$GOOGLE_SERVICE_ACCOUNT" + echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": "", "private_key_id": ""}' > "$GOOGLE_SERVICE_ACCOUNT" + + - name: Setup WebDav + run: docker run -d -p 8080:80 rclone/rclone serve webdav /data --addr :80 - name: Setup LocalStack (AWS emulation) - env: - AWS_DEFAULT_REGION: "us-east-1" - AWS_ACCESS_KEY_ID: test - AWS_SECRET_ACCESS_KEY: test - AWS_ENDPOINT: http://localhost:4566 run: | - docker run -d -p 4566:4566 localstack/localstack:0.14.4 - docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2 + echo "LOCALSTACK_CONTAINER=$(docker run -d -p 4566:4566 localstack/localstack:3.3.0)" >> $GITHUB_ENV + echo "EC2_METADATA_CONTAINER=$(docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2)" >> $GITHUB_ENV aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket + aws --endpoint-url=http://localhost:4566 dynamodb create-table --table-name test-table --key-schema AttributeName=path,KeyType=HASH AttributeName=etag,KeyType=RANGE --attribute-definitions AttributeName=path,AttributeType=S AttributeName=etag,AttributeType=S --provisioned-throughput ReadCapacityUnits=5,WriteCapacityUnits=5 + + KMS_KEY=$(aws --endpoint-url=http://localhost:4566 kms create-key --description "test key") + echo "AWS_SSE_KMS_KEY_ID=$(echo $KMS_KEY | jq -r .KeyMetadata.KeyId)" >> $GITHUB_ENV - name: Configure Azurite (Azure emulation) # the magical connection string is from # https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio#http-connection-strings run: | - docker run -d -p 10000:10000 -p 10001:10001 -p 10002:10002 mcr.microsoft.com/azure-storage/azurite + echo "AZURITE_CONTAINER=$(docker run -d -p 10000:10000 -p 
10001:10001 -p 10002:10002 mcr.microsoft.com/azure-storage/azurite)" >> $GITHUB_ENV az storage container create -n test-bucket --connection-string 'DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://localhost:10000/devstoreaccount1;QueueEndpoint=http://localhost:10001/devstoreaccount1;' - name: Setup Rust toolchain @@ -114,11 +159,55 @@ jobs: rustup default stable - name: Run object_store tests - env: - OBJECT_STORE_AWS_DEFAULT_REGION: "us-east-1" - OBJECT_STORE_AWS_ACCESS_KEY_ID: test - OBJECT_STORE_AWS_SECRET_ACCESS_KEY: test - OBJECT_STORE_AWS_ENDPOINT: http://localhost:4566 - run: | - # run tests - cargo test -p object_store --features=aws,azure,gcp + run: cargo test --features=aws,azure,gcp,http + + - name: GCS Output + if: ${{ !cancelled() }} + run: docker logs $GCS_CONTAINER + + - name: LocalStack Output + if: ${{ !cancelled() }} + run: docker logs $LOCALSTACK_CONTAINER + + - name: EC2 Metadata Output + if: ${{ !cancelled() }} + run: docker logs $EC2_METADATA_CONTAINER + + - name: Azurite Output + if: ${{ !cancelled() }} + run: docker logs $AZURITE_CONTAINER + + # test the object_store crate builds against wasm32 in stable rust + wasm32-build: + name: Build wasm32 + runs-on: ubuntu-latest + container: + image: amd64/rust + defaults: + run: + working-directory: object_store + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + target: wasm32-unknown-unknown,wasm32-wasi + - name: Build wasm32-unknown-unknown + run: cargo build --target wasm32-unknown-unknown + - name: Build wasm32-wasi + run: cargo build --target wasm32-wasi + + windows: + name: cargo test LocalFileSystem (win64) + runs-on: windows-latest + defaults: + run: + working-directory: object_store + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Run 
LocalFileSystem tests + run: cargo test local::tests diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml index 42cb06bb0a86..a4e654892662 100644 --- a/.github/workflows/parquet.yml +++ b/.github/workflows/parquet.yml @@ -19,6 +19,9 @@ # tests for parquet crate name: "parquet" +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true # trigger for all PRs that touch certain files and changes to master on: @@ -28,6 +31,16 @@ on: pull_request: paths: - arrow/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-data/** + - arrow-schema/** + - arrow-select/** + - arrow-ipc/** + - arrow-csv/** + - arrow-json/** + - arrow-avro/** - parquet/** - .github/** @@ -38,25 +51,22 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Test - run: | - cargo test -p parquet + run: cargo test -p parquet - name: Test --all-features + run: cargo test -p parquet --all-features + - name: Run examples run: | - cargo test -p parquet --all-features - + # Test parquet examples + cargo run -p parquet --example read_parquet + cargo run -p parquet --example async_read_parquet --features="async" + cargo run -p parquet --example read_with_rowgroup --features="async" # test compilation linux-features: @@ -64,18 +74,12 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable # Run different tests for the library on its own as well as # all targets to ensure that it still works in the absence of @@ -88,29 +92,78 @@ jobs: # 3. compiles with just arrow feature # 3. compiles with all features - name: Check compilation - run: | - cargo check -p parquet + run: cargo check -p parquet - name: Check compilation --no-default-features - run: | - cargo check -p parquet --no-default-features + run: cargo check -p parquet --no-default-features - name: Check compilation --no-default-features --features arrow - run: | - cargo check -p parquet --no-default-features --features arrow + run: cargo check -p parquet --no-default-features --features arrow - name: Check compilation --no-default-features --all-features - run: | - cargo check -p parquet --all-features + run: cargo check -p parquet --all-features - name: Check compilation --all-targets - run: | - cargo check -p parquet --all-targets + run: cargo check -p parquet --all-targets - name: Check compilation --all-targets --no-default-features - run: | - cargo check -p parquet --all-targets --no-default-features + run: cargo check -p parquet --all-targets --no-default-features - name: Check compilation --all-targets --no-default-features --features arrow - run: | - cargo check -p parquet --all-targets --no-default-features --features arrow + run: cargo check -p parquet --all-targets --no-default-features --features arrow - name: Check compilation --all-targets --all-features + run: cargo check -p parquet --all-targets --all-features + - name: Check compilation --all-targets --no-default-features --features json + run: cargo check -p parquet --all-targets --no-default-features --features json + + # test the parquet crate builds against wasm32 in stable rust + wasm32-build: + name: Build 
wasm32 + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + target: wasm32-unknown-unknown,wasm32-wasi + - name: Install clang # Needed for zlib compilation + run: apt-get update && apt-get install -y clang gcc-multilib + - name: Build wasm32-unknown-unknown + run: cargo build -p parquet --target wasm32-unknown-unknown + - name: Build wasm32-wasi + run: cargo build -p parquet --target wasm32-wasi + + pyspark-integration-test: + name: PySpark Integration Test + runs-on: ubuntu-latest + strategy: + matrix: + rust: [ stable ] + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: "pip" + - name: Install Python dependencies run: | - cargo check -p parquet --all-targets --all-features + cd parquet/pytest + pip install -r requirements.txt + - name: Black check the test files + run: | + cd parquet/pytest + black --check *.py --verbose + - name: Setup Rust toolchain + run: | + rustup toolchain install ${{ matrix.rust }} + rustup default ${{ matrix.rust }} + - name: Install binary for checking + run: | + cargo install --path parquet --bin parquet-show-bloom-filter --features=cli + cargo install --path parquet --bin parquet-fromcsv --features=arrow,cli + - name: Run pytest + run: | + cd parquet/pytest + pytest -v clippy: name: Clippy @@ -118,14 +171,10 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy + run: rustup component add clippy - name: Run clippy - run: | - cargo clippy -p parquet --all-targets --all-features -- -D warnings + run: cargo clippy -p parquet --all-targets --all-features -- -D warnings diff --git 
a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml index bd70fc30d1c5..d8b02f73a8aa 100644 --- a/.github/workflows/parquet_derive.yml +++ b/.github/workflows/parquet_derive.yml @@ -19,6 +19,9 @@ # tests for parquet_derive crate name: parquet_derive +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true # trigger for all PRs that touch certain files and changes to master on: @@ -39,21 +42,14 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Test - run: | - cargo test -p parquet_derive + run: cargo test -p parquet_derive clippy: name: Clippy @@ -61,14 +57,10 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy + run: rustup component add clippy - name: Run clippy - run: | - cargo clippy -p parquet_derive --all-features -- -D warnings + run: cargo clippy -p parquet_derive --all-features -- -D warnings diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c04d5643b49a..1b65c5057de1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -18,6 +18,10 @@ # workspace wide tests name: rust +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs and changes to master on: push: @@ -33,12 +37,11 @@ jobs: name: Test on Mac runs-on: 
macos-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protoc with brew - run: | - brew install protobuf + run: brew install protobuf - name: Setup Rust toolchain run: | rustup toolchain install stable --no-self-update @@ -57,7 +60,7 @@ jobs: name: Test on Windows runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protobuf compiler in /d/protoc @@ -90,11 +93,53 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 - - name: Setup toolchain - run: | - rustup toolchain install stable - rustup default stable - rustup component add rustfmt - - name: Run + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Setup rustfmt + run: rustup component add rustfmt + - name: Format arrow + run: cargo fmt --all -- --check + - name: Format parquet + # Many modules in parquet are skipped, so check parquet separately. If this check fails, run: + # cargo fmt -p parquet -- --config skip_children=true `find ./parquet -name "*.rs" \! -name format.rs` + # from the top level arrow-rs directory and check in the result. + # https://github.com/apache/arrow-rs/issues/6179 + working-directory: parquet + run: cargo fmt -p parquet -- --check --config skip_children=true `find . -name "*.rs" \! 
-name format.rs` + - name: Format object_store + working-directory: object_store run: cargo fmt --all -- --check + + msrv: + name: Verify MSRV + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Install cargo-msrv + run: cargo install cargo-msrv + - name: Downgrade arrow dependencies + run: cargo update -p ahash --precise 0.8.7 + - name: Check arrow + working-directory: arrow + run: cargo msrv --log-target stdout verify + - name: Check parquet + working-directory: parquet + run: cargo msrv --log-target stdout verify + - name: Check arrow-flight + working-directory: arrow-flight + run: cargo msrv --log-target stdout verify + - name: Downgrade object_store dependencies + working-directory: object_store + # Necessary because tokio 1.30.0 updates MSRV to 1.63 + # and url 2.5.1, updates to 1.67 + run: | + cargo update -p tokio --precise 1.29.1 + cargo update -p url --precise 2.5.0 + - name: Check object_store + working-directory: object_store + run: cargo msrv --log-target stdout verify diff --git a/.github/workflows/take.yml b/.github/workflows/take.yml new file mode 100644 index 000000000000..dd21c794960e --- /dev/null +++ b/.github/workflows/take.yml @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Assign the issue via a `take` comment +on: + issue_comment: + types: created + +permissions: + issues: write + +jobs: + issue_assign: + if: (!github.event.issue.pull_request) && github.event.comment.body == 'take' + runs-on: ubuntu-latest + steps: + - uses: actions/github-script@v7 + with: + script: | + github.rest.issues.addAssignees({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + assignees: [context.payload.comment.user.login], + }) diff --git a/.github_changelog_generator b/.github_changelog_generator index 9a9a84344866..a8279702b3aa 100644 --- a/.github_changelog_generator +++ b/.github_changelog_generator @@ -24,5 +24,5 @@ add-sections={"documentation":{"prefix":"**Documentation updates:**","labels":[" #pull-requests=false # so that the component is shown associated with the issue issue-line-labels=arrow,parquet,arrow-flight -exclude-labels=development-process,invalid,object-store +exclude-labels=development-process,invalid,object-store,question breaking_labels=api-change diff --git a/.gitignore b/.gitignore index 2a21776aa545..0788daea0166 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ rusty-tags.vi .flatbuffers/ .idea/ .vscode +.devcontainer venv/* # created by doctests parquet/data.parquet @@ -14,13 +15,14 @@ parquet/data.parquet justfile .prettierignore .env +.editorconfig # local azurite file __azurite* __blobstorage__ # .bak files *.bak - +*.bak2 # OS-specific .gitignores # Mac .gitignore @@ -92,3 +94,6 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk +# Python virtual env in parquet crate +parquet/pytest/venv/ +__pycache__/ diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index 70322b5cfd1d..4808cde703ad 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -17,9 +17,2810 @@ under the License. 
--> - # Historical Changelog +## [52.2.0](https://github.com/apache/arrow-rs/tree/52.2.0) (2024-07-24) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/52.1.0...52.2.0) + +**Implemented enhancements:** + +- Faster min/max for string/binary view arrays [\#6088](https://github.com/apache/arrow-rs/issues/6088) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting to/from Utf8View [\#6076](https://github.com/apache/arrow-rs/issues/6076) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Min/max support for String/BinaryViewArray [\#6052](https://github.com/apache/arrow-rs/issues/6052) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of constructing `ByteView`s for small strings [\#6034](https://github.com/apache/arrow-rs/issues/6034) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fast UTF-8 validation when reading StringViewArray from Parquet [\#5995](https://github.com/apache/arrow-rs/issues/5995) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Optimize StringView row decoding [\#5945](https://github.com/apache/arrow-rs/issues/5945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implementing `deduplicate` / `intern` functionality for StringView [\#5910](https://github.com/apache/arrow-rs/issues/5910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `FlightSqlServiceClient::new_from_inner` [\#6003](https://github.com/apache/arrow-rs/pull/6003) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([lewiszlw](https://github.com/lewiszlw)) +- Complete `StringViewArray` and `BinaryViewArray` parquet decoder: [\#6004](https://github.com/apache/arrow-rs/pull/6004) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add 
begin/end\_transaction methods in FlightSqlServiceClient [\#6026](https://github.com/apache/arrow-rs/pull/6026) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([lewiszlw](https://github.com/lewiszlw)) +- Read Parquet statistics as arrow `Arrays` [\#6046](https://github.com/apache/arrow-rs/pull/6046) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([efredine](https://github.com/efredine)) + +**Fixed bugs:** + +- Panic in `ParquetMetadata::memory_size` if no min/max set [\#6091](https://github.com/apache/arrow-rs/issues/6091) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- BinaryViewArray doesn't roundtrip a single `Some(&[])` through parquet [\#6086](https://github.com/apache/arrow-rs/issues/6086) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet `ColumnIndex` for null columns is written even when statistics are disabled [\#6010](https://github.com/apache/arrow-rs/issues/6010) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Fix typo in GenericByteViewArray documentation [\#6054](https://github.com/apache/arrow-rs/pull/6054) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([progval](https://github.com/progval)) +- Minor: Improve parquet PageIndex documentation [\#6042](https://github.com/apache/arrow-rs/pull/6042) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Potential performance improvements for reading Parquet to StringViewArray/BinaryViewArray [\#5904](https://github.com/apache/arrow-rs/issues/5904) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Faster `GenericByteView` construction [\#6102](https://github.com/apache/arrow-rs/pull/6102) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add benchmark to track byte-view construction performance [\#6101](https://github.com/apache/arrow-rs/pull/6101) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Optimize `bool_or` using `max_boolean` [\#6100](https://github.com/apache/arrow-rs/pull/6100) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([simonvandel](https://github.com/simonvandel)) +- Optimize `max_boolean` by operating on u64 chunks [\#6098](https://github.com/apache/arrow-rs/pull/6098) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([simonvandel](https://github.com/simonvandel)) +- fix panic in `ParquetMetadata::memory_size`: check has\_min\_max\_set before invoking min\(\)/max\(\) [\#6092](https://github.com/apache/arrow-rs/pull/6092) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Fischer0522](https://github.com/Fischer0522)) +- Implement specialized min/max for `GenericBinaryView` \(`StringView` and `BinaryView`\) [\#6089](https://github.com/apache/arrow-rs/pull/6089) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add PartialEq to ParquetMetaData and FileMetadata [\#6082](https://github.com/apache/arrow-rs/pull/6082) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adriangb](https://github.com/adriangb)) +- Enable casting from Utf8View [\#6077](https://github.com/apache/arrow-rs/pull/6077) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([a10y](https://github.com/a10y)) +- StringView support in arrow-csv [\#6062](https://github.com/apache/arrow-rs/pull/6062) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([2010YOUY01](https://github.com/2010YOUY01)) +- Implement min max support for string/binary view types 
[\#6053](https://github.com/apache/arrow-rs/pull/6053) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Minor: clarify the relationship between `file::metadata` and `format` in docs [\#6049](https://github.com/apache/arrow-rs/pull/6049) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor API adjustments for StringViewBuilder [\#6047](https://github.com/apache/arrow-rs/pull/6047) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add parquet `StatisticsConverter` for arrow reader [\#6046](https://github.com/apache/arrow-rs/pull/6046) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([efredine](https://github.com/efredine)) +- Directly decode String/BinaryView types from arrow-row format [\#6044](https://github.com/apache/arrow-rs/pull/6044) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Clean up unused code for view types in offset buffer [\#6040](https://github.com/apache/arrow-rs/pull/6040) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Avoid using Buffer api that accidentally copies data [\#6039](https://github.com/apache/arrow-rs/pull/6039) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([XiangpengHao](https://github.com/XiangpengHao)) +- MINOR: Fix `hashbrown` version in `arrow-array`, remove from `arrow-row` [\#6035](https://github.com/apache/arrow-rs/pull/6035) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Improve performance reading `ByteViewArray` from parquet by removing an implicit copy 
[\#6031](https://github.com/apache/arrow-rs/pull/6031) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add begin/end\_transaction methods in FlightSqlServiceClient [\#6026](https://github.com/apache/arrow-rs/pull/6026) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([lewiszlw](https://github.com/lewiszlw)) +- Unsafe improvements: core `parquet` crate. [\#6024](https://github.com/apache/arrow-rs/pull/6024) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([veluca93](https://github.com/veluca93)) +- Additional tests for parquet reader utf8 validation [\#6023](https://github.com/apache/arrow-rs/pull/6023) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update zstd-sys requirement from \>=2.0.0, \<2.0.12 to \>=2.0.0, \<2.0.13 [\#6019](https://github.com/apache/arrow-rs/pull/6019) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix doc ci in latest rust nightly version [\#6012](https://github.com/apache/arrow-rs/pull/6012) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Rachelint](https://github.com/Rachelint)) +- Do not write `ColumnIndex` for null columns when not writing page statistics [\#6011](https://github.com/apache/arrow-rs/pull/6011) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fast utf8 validation when loading string view from parquet [\#6009](https://github.com/apache/arrow-rs/pull/6009) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Deduplicate strings/binarys when building view types [\#6005](https://github.com/apache/arrow-rs/pull/6005) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Complete `StringViewArray` and `BinaryViewArray` parquet decoder: implement delta byte array and delta length byte array encoding [\#6004](https://github.com/apache/arrow-rs/pull/6004) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add `FlightSqlServiceClient::new_from_inner` [\#6003](https://github.com/apache/arrow-rs/pull/6003) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([lewiszlw](https://github.com/lewiszlw)) +- Rename `Schema::all_fields` to `flattened_fields` [\#6001](https://github.com/apache/arrow-rs/pull/6001) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([lewiszlw](https://github.com/lewiszlw)) +- Refine documentation and examples for `DataType` [\#5997](https://github.com/apache/arrow-rs/pull/5997) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- implement `DataType::try_form(&str)` [\#5994](https://github.com/apache/arrow-rs/pull/5994) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samuelcolvin](https://github.com/samuelcolvin)) +- Implement dictionary support for reading ByteView from parquet [\#5973](https://github.com/apache/arrow-rs/pull/5973) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +## [52.1.0](https://github.com/apache/arrow-rs/tree/52.1.0) (2024-07-02) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/52.0.0...52.1.0) + + +**Implemented enhancements:** + +- Implement `eq` comparison for StructArray [\#5960](https://github.com/apache/arrow-rs/issues/5960) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- A new feature as a workaround hack to unavailable offset support in Arrow Java [\#5959](https://github.com/apache/arrow-rs/issues/5959) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `min_bytes` and `max_bytes` to `PageIndex` [\#5949](https://github.com/apache/arrow-rs/issues/5949) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Error message in ArrowNativeTypeOp::neg\_checked doesn't include the operation [\#5944](https://github.com/apache/arrow-rs/issues/5944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add object\_store\_opendal as related projects [\#5925](https://github.com/apache/arrow-rs/issues/5925) +- Opaque retry errors make debugging difficult [\#5923](https://github.com/apache/arrow-rs/issues/5923) +- Implement arrow-row en/decoding for GenericByteView types [\#5921](https://github.com/apache/arrow-rs/issues/5921) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- The arrow-rs repo is very large [\#5908](https://github.com/apache/arrow-rs/issues/5908) +- \[DISCUSS\] Release arrow-rs / parquet patch release `52.0.1` [\#5906](https://github.com/apache/arrow-rs/issues/5906) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `compare_op` for `GenericBinaryView` [\#5897](https://github.com/apache/arrow-rs/issues/5897) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- New null with view types are not supported [\#5893](https://github.com/apache/arrow-rs/issues/5893) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cleanup ByteView construction [\#5878](https://github.com/apache/arrow-rs/issues/5878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `cast` kernel support for `StringViewArray` and `BinaryViewArray` `<-->` `DictionaryArray` [\#5861](https://github.com/apache/arrow-rs/issues/5861) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet::ArrowWriter should allow writing Bloom filters before the end of the file [\#5859](https://github.com/apache/arrow-rs/issues/5859) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- API to get memory usage for parquet ArrowWriter [\#5851](https://github.com/apache/arrow-rs/issues/5851) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support writing `IntervalMonthDayNanoArray` to parquet via Arrow Writer [\#5849](https://github.com/apache/arrow-rs/issues/5849) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Write parquet statistics for `IntervalDayTimeArray`, `IntervalMonthDayNanoArray` and `IntervalYearMonthArray` [\#5847](https://github.com/apache/arrow-rs/issues/5847) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make `RowSelection::from_consecutive_ranges` public [\#5846](https://github.com/apache/arrow-rs/issues/5846) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `Schema::try_merge` should be able to merge List of any data type with List of Null data type [\#5843](https://github.com/apache/arrow-rs/issues/5843) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add a way to move `fields` out of parquet `Row` [\#5841](https://github.com/apache/arrow-rs/issues/5841) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make `TimeUnit` and `IntervalUnit` `Copy` [\#5839](https://github.com/apache/arrow-rs/issues/5839) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Limit Parquet Page Row Count By Default to reduce writer memory requirements with highly compressible columns [\#5797](https://github.com/apache/arrow-rs/issues/5797) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Report / blog on parquet metadata sizes for "large" \(1000+\) numbers of columns [\#5770](https://github.com/apache/arrow-rs/issues/5770) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Structured ByteView Access \(underlying StringView/BinaryView representation\) [\#5736](https://github.com/apache/arrow-rs/issues/5736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[parquet\_derive\] support OPTIONAL \(def\_level = 1\) columns by default [\#5716](https://github.com/apache/arrow-rs/issues/5716) +- Maps cast to other Maps with different Elements, Key and Value Names [\#5702](https://github.com/apache/arrow-rs/issues/5702) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Provide Arrow Schema Hint to Parquet Reader [\#5657](https://github.com/apache/arrow-rs/issues/5657) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Wrong error type in case of invalid amount in Interval components [\#5986](https://github.com/apache/arrow-rs/issues/5986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Empty and Null structarray fails to IPC roundtrip [\#5920](https://github.com/apache/arrow-rs/issues/5920) +- FixedSizeList got out of range when the total length of the underlying values over i32::MAX [\#5901](https://github.com/apache/arrow-rs/issues/5901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Out of range when extending on a slice of string array imported through FFI [\#5896](https://github.com/apache/arrow-rs/issues/5896) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- cargo msrv test is failing on main for `object_store` [\#5864](https://github.com/apache/arrow-rs/issues/5864) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- chore: update RunArray reference in run\_iterator.rs [\#5892](https://github.com/apache/arrow-rs/pull/5892) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([Weijun-H](https://github.com/Weijun-H)) +- Minor: Clarify when page index structures are read [\#5886](https://github.com/apache/arrow-rs/pull/5886) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve Parquet reader/writer properties docs [\#5863](https://github.com/apache/arrow-rs/pull/5863) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Refine documentation for `unary_mut` and `binary_mut` [\#5798](https://github.com/apache/arrow-rs/pull/5798) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Implement benchmarks for `compare_op` for `GenericBinaryView` [\#5903](https://github.com/apache/arrow-rs/issues/5903) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- fix: error in case of invalid amount interval component [\#5987](https://github.com/apache/arrow-rs/pull/5987) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([DDtKey](https://github.com/DDtKey)) +- Minor: fix clippy complaint in parquet\_derive [\#5984](https://github.com/apache/arrow-rs/pull/5984) ([alamb](https://github.com/alamb)) +- Reduce repo size by removing accumulative commits in CI job [\#5982](https://github.com/apache/arrow-rs/pull/5982) ([Owen-CH-Leung](https://github.com/Owen-CH-Leung)) +- Add operation in ArrowNativeTypeOp::neg\_check error message \(\#5944\) [\#5980](https://github.com/apache/arrow-rs/pull/5980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhao-gang](https://github.com/zhao-gang)) +- Implement directly build byte view array on top of parquet buffer [\#5972](https://github.com/apache/arrow-rs/pull/5972) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Handle flight dictionary ID assignment automatically [\#5971](https://github.com/apache/arrow-rs/pull/5971) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Add view buffer for parquet reader [\#5970](https://github.com/apache/arrow-rs/pull/5970) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add benchmark for reading binary/binary view from parquet [\#5968](https://github.com/apache/arrow-rs/pull/5968) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- feat\(5851\): ArrowWriter memory usage [\#5967](https://github.com/apache/arrow-rs/pull/5967) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([wiedld](https://github.com/wiedld)) +- Add ParquetMetadata::memory\_size size estimation [\#5965](https://github.com/apache/arrow-rs/pull/5965) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix FFI array offset handling [\#5964](https://github.com/apache/arrow-rs/pull/5964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement sort for String/BinaryViewArray [\#5963](https://github.com/apache/arrow-rs/pull/5963) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Improve error message for unsupported nested comparison [\#5961](https://github.com/apache/arrow-rs/pull/5961) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- chore\(5797\): change default parquet data\_page\_row\_limit to 20k [\#5957](https://github.com/apache/arrow-rs/pull/5957) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([wiedld](https://github.com/wiedld)) +- Document process for PRs with breaking changes 
[\#5953](https://github.com/apache/arrow-rs/pull/5953) ([alamb](https://github.com/alamb)) +- Minor: fixup contribution guide about clippy [\#5952](https://github.com/apache/arrow-rs/pull/5952) ([alamb](https://github.com/alamb)) +- feat: add max\_bytes and min\_bytes on PageIndex [\#5950](https://github.com/apache/arrow-rs/pull/5950) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tshauck](https://github.com/tshauck)) +- test: Add unit test for extending slice of list array [\#5948](https://github.com/apache/arrow-rs/pull/5948) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- minor: row format benches for bool & nullable int [\#5943](https://github.com/apache/arrow-rs/pull/5943) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([korowa](https://github.com/korowa)) +- Better document support for nested comparison [\#5942](https://github.com/apache/arrow-rs/pull/5942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Provide Arrow Schema Hint to Parquet Reader - Alternative 2 [\#5939](https://github.com/apache/arrow-rs/pull/5939) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([efredine](https://github.com/efredine)) +- `like` benchmark for StringView [\#5936](https://github.com/apache/arrow-rs/pull/5936) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix typo in benchmark name `egexp` --\> `regexp` [\#5935](https://github.com/apache/arrow-rs/pull/5935) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Revert "Write Bloom filters between row groups instead of the end " [\#5932](https://github.com/apache/arrow-rs/pull/5932) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Implement like/ilike etc for StringViewArray [\#5931](https://github.com/apache/arrow-rs/pull/5931) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- docs: Fix broken links of object\_store\_opendal README [\#5929](https://github.com/apache/arrow-rs/pull/5929) ([Xuanwo](https://github.com/Xuanwo)) +- Expose `IntervalMonthDayNano` and `IntervalDayTime` and update docs [\#5928](https://github.com/apache/arrow-rs/pull/5928) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Update proc-macro2 requirement from =1.0.85 to =1.0.86 [\#5927](https://github.com/apache/arrow-rs/pull/5927) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- docs: Add object\_store\_opendal as related projects [\#5926](https://github.com/apache/arrow-rs/pull/5926) ([Xuanwo](https://github.com/Xuanwo)) +- Add eq benchmark for StringArray/StringViewArray [\#5924](https://github.com/apache/arrow-rs/pull/5924) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Implement arrow-row encoding/decoding for view types [\#5922](https://github.com/apache/arrow-rs/pull/5922) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- fix\(ipc\): set correct row count when reading struct arrays with zero fields [\#5918](https://github.com/apache/arrow-rs/pull/5918) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Update zstd-sys requirement from \>=2.0.0, \<2.0.10 to \>=2.0.0, \<2.0.12 [\#5913](https://github.com/apache/arrow-rs/pull/5913) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: prevent potential out-of-range access in FixedSizeListArray [\#5902](https://github.com/apache/arrow-rs/pull/5902) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([BubbleCal](https://github.com/BubbleCal)) +- Implement compare operations for view types [\#5900](https://github.com/apache/arrow-rs/pull/5900) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- minor: use as\_primitive replace downcast\_ref [\#5898](https://github.com/apache/arrow-rs/pull/5898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kikkon](https://github.com/Kikkon)) +- fix: Adjust FFI\_ArrowArray offset based on the offset of offset buffer [\#5895](https://github.com/apache/arrow-rs/pull/5895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- implement `new_null_array` for view types [\#5894](https://github.com/apache/arrow-rs/pull/5894) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- chore: add view type single column tests [\#5891](https://github.com/apache/arrow-rs/pull/5891) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ariesdevil](https://github.com/ariesdevil)) +- Minor: expose timestamp\_tz\_format for csv writing [\#5890](https://github.com/apache/arrow-rs/pull/5890) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tmi](https://github.com/tmi)) +- chore: implement parquet error handling for object\_store [\#5889](https://github.com/apache/arrow-rs/pull/5889) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([abhiaagarwal](https://github.com/abhiaagarwal)) +- Document when the ParquetRecordBatchReader will re-read metadata [\#5887](https://github.com/apache/arrow-rs/pull/5887) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add simple GC for view array types [\#5885](https://github.com/apache/arrow-rs/pull/5885) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([XiangpengHao](https://github.com/XiangpengHao)) +- Update for new clippy rules [\#5881](https://github.com/apache/arrow-rs/pull/5881) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- clean up ByteView construction [\#5879](https://github.com/apache/arrow-rs/pull/5879) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Avoid copy/allocation when read view types from parquet [\#5877](https://github.com/apache/arrow-rs/pull/5877) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Document parquet ArrowWriter type limitations [\#5875](https://github.com/apache/arrow-rs/pull/5875) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Benchmark for casting view to dict arrays \(and the reverse\) [\#5874](https://github.com/apache/arrow-rs/pull/5874) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Implement Take for Dense UnionArray [\#5873](https://github.com/apache/arrow-rs/pull/5873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gstvg](https://github.com/gstvg)) +- Improve performance of casting `StringView`/`BinaryView` to `DictionaryArray` [\#5872](https://github.com/apache/arrow-rs/pull/5872) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Improve performance of casting `DictionaryArray` to `StringViewArray` [\#5871](https://github.com/apache/arrow-rs/pull/5871) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- fix: msrv CI for object\_store [\#5866](https://github.com/apache/arrow-rs/pull/5866) 
([korowa](https://github.com/korowa)) +- parquet: Fix warning about unused import [\#5865](https://github.com/apache/arrow-rs/pull/5865) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([progval](https://github.com/progval)) +- Preallocate for `FixedSizeList` in `concat` [\#5862](https://github.com/apache/arrow-rs/pull/5862) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([judahrand](https://github.com/judahrand)) +- Faster primitive arrays encoding into row format [\#5858](https://github.com/apache/arrow-rs/pull/5858) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([korowa](https://github.com/korowa)) +- Added panic message to docs. [\#5857](https://github.com/apache/arrow-rs/pull/5857) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([SeeRightThroughMe](https://github.com/SeeRightThroughMe)) +- feat: call try\_merge recursively for list field [\#5852](https://github.com/apache/arrow-rs/pull/5852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mnpw](https://github.com/mnpw)) +- Minor: refine row selection example more [\#5850](https://github.com/apache/arrow-rs/pull/5850) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Make RowSelection's from\_consecutive\_ranges public [\#5848](https://github.com/apache/arrow-rs/pull/5848) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([advancedxy](https://github.com/advancedxy)) +- Add exposing fields from parquet row [\#5842](https://github.com/apache/arrow-rs/pull/5842) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SHaaD94](https://github.com/SHaaD94)) +- Derive `Copy` for `TimeUnit` and `IntervalUnit` [\#5840](https://github.com/apache/arrow-rs/pull/5840) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- feat: support reading OPTIONAL column in parquet\_derive [\#5717](https://github.com/apache/arrow-rs/pull/5717) 
([double-free](https://github.com/double-free)) +- Add the ability for Maps to cast to another case where the field names are different [\#5703](https://github.com/apache/arrow-rs/pull/5703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HawaiianSpork](https://github.com/HawaiianSpork)) +## [52.0.0](https://github.com/apache/arrow-rs/tree/52.0.0) (2024-06-03) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/51.0.0...52.0.0) + +**Breaking changes:** + +- chore: Make binary\_mut kernel accept different type for second arg [\#5833](https://github.com/apache/arrow-rs/pull/5833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix\(flightsql\): remove Any encoding of `DoPutPreparedStatementResult` [\#5817](https://github.com/apache/arrow-rs/pull/5817) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([erratic-pattern](https://github.com/erratic-pattern)) +- Encode UUID as FixedLenByteArray in parquet\_derive [\#5773](https://github.com/apache/arrow-rs/pull/5773) ([conradludgate](https://github.com/conradludgate)) +- Structured interval types for `IntervalMonthDayNano` or `IntervalDayTime` \(\#3125\) \(\#5654\) [\#5769](https://github.com/apache/arrow-rs/pull/5769) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fallible stream for arrow-flight do\_exchange call \(\#3462\) [\#5698](https://github.com/apache/arrow-rs/pull/5698) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([opensourcegeek](https://github.com/opensourcegeek)) +- Update object\_store dependency in arrow to `0.10.0` [\#5675](https://github.com/apache/arrow-rs/pull/5675) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
([tustvold](https://github.com/tustvold)) +- Remove deprecated JSON writer [\#5651](https://github.com/apache/arrow-rs/pull/5651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Change `UnionArray` constructors [\#5623](https://github.com/apache/arrow-rs/pull/5623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mbrobbel](https://github.com/mbrobbel)) +- Update pyo3 from 0.20 to 0.21 [\#5566](https://github.com/apache/arrow-rs/pull/5566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Optionally require alignment when reading IPC, respect alignment when writing [\#5554](https://github.com/apache/arrow-rs/pull/5554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([hzuo](https://github.com/hzuo)) + +**Implemented enhancements:** + +- Serialize `Binary` and `LargeBinary` as HEX with JSON Writer [\#5783](https://github.com/apache/arrow-rs/issues/5783) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Some optimizations in arrow\_buffer::util::bit\_util do more harm than good [\#5771](https://github.com/apache/arrow-rs/issues/5771) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support skipping comments in CSV files [\#5758](https://github.com/apache/arrow-rs/issues/5758) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `parquet-derive` should be included in repository README. 
[\#5751](https://github.com/apache/arrow-rs/issues/5751) +- proposal: Make AsyncArrowWriter accepts AsyncFileWriter trait instead [\#5738](https://github.com/apache/arrow-rs/issues/5738) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Nested nullable fields do not get treated as nullable in data\_gen [\#5712](https://github.com/apache/arrow-rs/issues/5712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Optionally support flexible column lengths [\#5678](https://github.com/apache/arrow-rs/issues/5678) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Arrow Flight SQL example server: do\_handshake should include auth header [\#5665](https://github.com/apache/arrow-rs/issues/5665) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add support for the "r+" datatype in the C Data interface / `RunArray` [\#5631](https://github.com/apache/arrow-rs/issues/5631) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Serialize `FixedSizeBinary` as HEX with JSON Writer [\#5620](https://github.com/apache/arrow-rs/issues/5620) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cleanup UnionArray Constructors [\#5613](https://github.com/apache/arrow-rs/issues/5613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Zero Copy Support [\#5593](https://github.com/apache/arrow-rs/issues/5593) +- ObjectStore bulk delete [\#5591](https://github.com/apache/arrow-rs/issues/5591) +- Retry on Broken Connection [\#5589](https://github.com/apache/arrow-rs/issues/5589) +- `StreamReader` is not zero-copy [\#5584](https://github.com/apache/arrow-rs/issues/5584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Create `ArrowReaderMetadata` from externalized metadata [\#5582](https://github.com/apache/arrow-rs/issues/5582) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make `filter` in `filter_leaves` API propagate error [\#5574](https://github.com/apache/arrow-rs/issues/5574) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `List` in `compare_op` [\#5572](https://github.com/apache/arrow-rs/issues/5572) +- Make FixedSizeList JSON serializable [\#5568](https://github.com/apache/arrow-rs/issues/5568) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-ord: Support sorting StructArray [\#5559](https://github.com/apache/arrow-rs/issues/5559) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add scientific notation decimal parsing in `parse_decimal` [\#5549](https://github.com/apache/arrow-rs/issues/5549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `take` kernel support for `StringViewArray` and `BinaryViewArray` [\#5511](https://github.com/apache/arrow-rs/issues/5511) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `filter` kernel support for `StringViewArray` and `BinaryViewArray` [\#5510](https://github.com/apache/arrow-rs/issues/5510) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Display support for `StringViewArray` and `BinaryViewArray` [\#5509](https://github.com/apache/arrow-rs/issues/5509) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Arrow Flight format support for `StringViewArray` and `BinaryViewArray` [\#5507](https://github.com/apache/arrow-rs/issues/5507) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- IPC format support for `StringViewArray` and `BinaryViewArray` [\#5506](https://github.com/apache/arrow-rs/issues/5506) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- List Row Encoding Sorts Incorrectly [\#5807](https://github.com/apache/arrow-rs/issues/5807) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Schema Root Message Name Ignored by parquet-fromcsv [\#5804](https://github.com/apache/arrow-rs/issues/5804) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Compute data buffer length by using start and end values in offset buffer [\#5756](https://github.com/apache/arrow-rs/issues/5756) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: ByteArrayEncoder allocates large unused FallbackEncoder for Parquet 2 [\#5755](https://github.com/apache/arrow-rs/issues/5755) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- The CI pipeline `Archery test With other arrow` is broken [\#5742](https://github.com/apache/arrow-rs/issues/5742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Unable to parse scientific notation string to decimal when scale is 0 [\#5739](https://github.com/apache/arrow-rs/issues/5739) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Stateless prepared statements wrap `DoPutPreparedStatementResult` with `Any` which differs from the Go implementation [\#5731](https://github.com/apache/arrow-rs/issues/5731) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- "Rustdocs are clean \(amd64, nightly\)" CI check is failing [\#5725](https://github.com/apache/arrow-rs/issues/5725) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- "Archery test With other arrows" integration tests are failing [\#5719](https://github.com/apache/arrow-rs/issues/5719) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet\_derive: invalid examples/documentation [\#5687](https://github.com/apache/arrow-rs/issues/5687) +- Arrow Flight SQL: invalid location in get\_flight\_info\_prepared\_statement 
[\#5669](https://github.com/apache/arrow-rs/issues/5669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Rust Interval definition incorrect [\#5654](https://github.com/apache/arrow-rs/issues/5654) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- DECIMAL regex in csv reader does not accept positive exponent specifier [\#5648](https://github.com/apache/arrow-rs/issues/5648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- panic when casting `ListArray` to `FixedSizeList` [\#5642](https://github.com/apache/arrow-rs/issues/5642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FixedSizeListArray::try\_new Errors on Entirely Null Array With Size 0 [\#5614](https://github.com/apache/arrow-rs/issues/5614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `parquet / Build wasm32 (pull_request)` CI check failing on main [\#5565](https://github.com/apache/arrow-rs/issues/5565) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Documentation fix: example in parquet/src/column/mod.rs is incorrect [\#5560](https://github.com/apache/arrow-rs/issues/5560) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- IPC code writes data with insufficient alignment [\#5553](https://github.com/apache/arrow-rs/issues/5553) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Cannot access example Flight SQL Server from dbeaver [\#5540](https://github.com/apache/arrow-rs/issues/5540) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- parquet: "not yet implemented" error when codec is actually implemented but disabled 
[\#5520](https://github.com/apache/arrow-rs/issues/5520) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Minor: Improve arrow\_cast documentation [\#5825](https://github.com/apache/arrow-rs/pull/5825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: Improve `ArrowReaderBuilder::with_row_selection` docs [\#5824](https://github.com/apache/arrow-rs/pull/5824) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: Add examples for ColumnPath::from [\#5813](https://github.com/apache/arrow-rs/pull/5813) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: Clarify docs on `EnabledStatistics` [\#5812](https://github.com/apache/arrow-rs/pull/5812) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add parquet-derive to repository README [\#5795](https://github.com/apache/arrow-rs/pull/5795) ([konjac](https://github.com/konjac)) +- Refine ParquetRecordBatchReaderBuilder docs [\#5774](https://github.com/apache/arrow-rs/pull/5774) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- docs: add sizing explanation to bloom filter docs in parquet [\#5705](https://github.com/apache/arrow-rs/pull/5705) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hiltontj](https://github.com/hiltontj)) + +**Closed issues:** + +- `binary_mut` kernel requires both args to be the same type \(which is inconsistent with `binary`\) [\#5818](https://github.com/apache/arrow-rs/issues/5818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Panic when displaying debug the results via log::info in the browser. 
[\#5599](https://github.com/apache/arrow-rs/issues/5599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- feat: impl \*Assign ops for types in arrow-buffer [\#5832](https://github.com/apache/arrow-rs/pull/5832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waynexia](https://github.com/waynexia)) +- Relax zstd-sys Version Pin [\#5829](https://github.com/apache/arrow-rs/pull/5829) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([waynexia](https://github.com/waynexia)) +- Minor: Document timestamp with/without cast behavior [\#5826](https://github.com/apache/arrow-rs/pull/5826) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: invalid examples/documentation in parquet\_derive doc [\#5823](https://github.com/apache/arrow-rs/pull/5823) ([Weijun-H](https://github.com/Weijun-H)) +- Check length of `FIXED_LEN_BYTE_ARRAY` for `uuid` logical parquet type [\#5821](https://github.com/apache/arrow-rs/pull/5821) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mbrobbel](https://github.com/mbrobbel)) +- Allow overriding the inferred parquet schema root [\#5814](https://github.com/apache/arrow-rs/pull/5814) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Revisit List Row Encoding \(\#5807\) [\#5811](https://github.com/apache/arrow-rs/pull/5811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.83 to =1.0.84 [\#5805](https://github.com/apache/arrow-rs/pull/5805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix typo continuation maker -\> marker [\#5802](https://github.com/apache/arrow-rs/pull/5802) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([djanderson](https://github.com/djanderson)) +- fix: serialization of decimal [\#5801](https://github.com/apache/arrow-rs/pull/5801) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- Allow constructing ByteViewArray from existing blocks [\#5796](https://github.com/apache/arrow-rs/pull/5796) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Push SortOptions into DynComparator Allowing Nested Comparisons \(\#5426\) [\#5792](https://github.com/apache/arrow-rs/pull/5792) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix incorrect URL to Parquet CPP types.h [\#5790](https://github.com/apache/arrow-rs/pull/5790) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Update proc-macro2 requirement from =1.0.82 to =1.0.83 [\#5789](https://github.com/apache/arrow-rs/pull/5789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update prost-build requirement from =0.12.4 to =0.12.6 [\#5788](https://github.com/apache/arrow-rs/pull/5788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Refine parquet documentation on types and metadata [\#5786](https://github.com/apache/arrow-rs/pull/5786) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- feat\(arrow-json\): encode `Binary` and `LargeBinary` types as hex when writing JSON [\#5785](https://github.com/apache/arrow-rs/pull/5785) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([hiltontj](https://github.com/hiltontj)) +- fix 
broken link to ballista crate in README.md [\#5784](https://github.com/apache/arrow-rs/pull/5784) ([navicore](https://github.com/navicore)) +- feat\(arrow-csv\): support encoding of binary in CSV writer [\#5782](https://github.com/apache/arrow-rs/pull/5782) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([hiltontj](https://github.com/hiltontj)) +- Fix documentation for parquet `parse_metadata`, `decode_metadata` and `decode_footer` [\#5781](https://github.com/apache/arrow-rs/pull/5781) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Support casting a `FixedSizedList[1]` to `T` [\#5779](https://github.com/apache/arrow-rs/pull/5779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sadboy](https://github.com/sadboy)) +- \[parquet\] Set the default size of BitWriter in DeltaBitPackEncoder to 1MB [\#5776](https://github.com/apache/arrow-rs/pull/5776) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AdamGS](https://github.com/AdamGS)) +- Remove harmful table lookup optimization for bitmap operations [\#5772](https://github.com/apache/arrow-rs/pull/5772) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HadrienG2](https://github.com/HadrienG2)) +- Remove deprecated comparison kernels \(\#4733\) [\#5768](https://github.com/apache/arrow-rs/pull/5768) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add environment variable definitions to run the nanoarrow integration tests [\#5764](https://github.com/apache/arrow-rs/pull/5764) ([paleolimbot](https://github.com/paleolimbot)) +- Downgrade to Rust 1.77 in integration pipeline to fix CI \(\#5719\) [\#5761](https://github.com/apache/arrow-rs/pull/5761) ([tustvold](https://github.com/tustvold)) +- Expose boolean builder contents [\#5760](https://github.com/apache/arrow-rs/pull/5760) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([HadrienG2](https://github.com/HadrienG2)) +- Allow specifying comment character for CSV reader [\#5759](https://github.com/apache/arrow-rs/pull/5759) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([bbannier](https://github.com/bbannier)) +- Expose the null buffer of every builder that has one [\#5754](https://github.com/apache/arrow-rs/pull/5754) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HadrienG2](https://github.com/HadrienG2)) +- feat: Make AsyncArrowWriter accepts AsyncFileWriter [\#5753](https://github.com/apache/arrow-rs/pull/5753) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Xuanwo](https://github.com/Xuanwo)) +- Improve repository readme [\#5752](https://github.com/apache/arrow-rs/pull/5752) ([alamb](https://github.com/alamb)) +- Document object store release cadence [\#5750](https://github.com/apache/arrow-rs/pull/5750) ([alamb](https://github.com/alamb)) +- Compute data buffer length by using start and end values in offset buffer [\#5741](https://github.com/apache/arrow-rs/pull/5741) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix: parse string of scientific notation to decimal when the scale is 0 [\#5740](https://github.com/apache/arrow-rs/pull/5740) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- Minor: avoid \(likely unreachable\) panic in FlightClient [\#5734](https://github.com/apache/arrow-rs/pull/5734) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Update proc-macro2 requirement from =1.0.81 to =1.0.82 [\#5732](https://github.com/apache/arrow-rs/pull/5732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Improve error message for 
timestamp queries outside supported range [\#5730](https://github.com/apache/arrow-rs/pull/5730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Abdi-29](https://github.com/Abdi-29)) +- Refactor to share code between do\_put and do\_exchange calls [\#5728](https://github.com/apache/arrow-rs/pull/5728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([opensourcegeek](https://github.com/opensourcegeek)) +- Update brotli requirement from 5.0 to 6.0 [\#5726](https://github.com/apache/arrow-rs/pull/5726) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix `GenericListBuilder` test typo [\#5724](https://github.com/apache/arrow-rs/pull/5724) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kikkon](https://github.com/Kikkon)) +- Deprecate NullBuilder capacity, as it behaves in a surprising way [\#5721](https://github.com/apache/arrow-rs/pull/5721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HadrienG2](https://github.com/HadrienG2)) +- Fix nested nullability when randomly generating arrays [\#5713](https://github.com/apache/arrow-rs/pull/5713) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- Fix up clippy for Rust 1.78 [\#5710](https://github.com/apache/arrow-rs/pull/5710) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Support casting `StringView`/`BinaryView` --\> `StringArray`/`BinaryArray`. 
[\#5704](https://github.com/apache/arrow-rs/pull/5704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Fix documentation around handling of nulls in cmp kernels [\#5697](https://github.com/apache/arrow-rs/pull/5697) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Support casting `StringArray`/`BinaryArray` --\> `StringView` / `BinaryView` [\#5686](https://github.com/apache/arrow-rs/pull/5686) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Add support for flexible column lengths [\#5679](https://github.com/apache/arrow-rs/pull/5679) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Posnet](https://github.com/Posnet)) +- Move ffi stream and utils from arrow to arrow-array [\#5670](https://github.com/apache/arrow-rs/pull/5670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Arrow Flight SQL example JDBC driver incompatibility [\#5666](https://github.com/apache/arrow-rs/pull/5666) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([istvan-fodor](https://github.com/istvan-fodor)) +- Add `ListView` & `LargeListView` basic construction and validation [\#5664](https://github.com/apache/arrow-rs/pull/5664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kikkon](https://github.com/Kikkon)) +- Update proc-macro2 requirement from =1.0.80 to =1.0.81 [\#5659](https://github.com/apache/arrow-rs/pull/5659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Modify decimal regex to accept positive exponent specifier [\#5649](https://github.com/apache/arrow-rs/pull/5649) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jdcasale](https://github.com/jdcasale)) +- feat: JSON encoding of `FixedSizeList` [\#5646](https://github.com/apache/arrow-rs/pull/5646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([hiltontj](https://github.com/hiltontj)) +- Update proc-macro2 requirement from =1.0.79 to =1.0.80 [\#5644](https://github.com/apache/arrow-rs/pull/5644) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: panic when casting `ListArray` to `FixedSizeList` [\#5643](https://github.com/apache/arrow-rs/pull/5643) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonahgao](https://github.com/jonahgao)) +- Add more invalid utf8 parquet reader tests [\#5639](https://github.com/apache/arrow-rs/pull/5639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update brotli requirement from 4.0 to 5.0 [\#5637](https://github.com/apache/arrow-rs/pull/5637) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update flatbuffers requirement from 23.1.21 to 24.3.25 [\#5636](https://github.com/apache/arrow-rs/pull/5636) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Increase `BinaryViewArray` test coverage [\#5635](https://github.com/apache/arrow-rs/pull/5635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- PrettyPrint support for `StringViewArray` and `BinaryViewArray` [\#5634](https://github.com/apache/arrow-rs/pull/5634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat\(ffi\): add run end encoded arrays [\#5632](https://github.com/apache/arrow-rs/pull/5632) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([notfilippo](https://github.com/notfilippo)) +- Accept parquet schemas without explicitly required Map keys [\#5630](https://github.com/apache/arrow-rs/pull/5630) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jupiter](https://github.com/jupiter)) +- Implement `filter` kernel for byte view arrays. [\#5624](https://github.com/apache/arrow-rs/pull/5624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- feat: encode FixedSizeBinary in JSON as hex string [\#5622](https://github.com/apache/arrow-rs/pull/5622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([hiltontj](https://github.com/hiltontj)) +- Update Flight crate README version [\#5621](https://github.com/apache/arrow-rs/pull/5621) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([phillipleblanc](https://github.com/phillipleblanc)) +- feat: support reading and writing`StringView` and `BinaryView` in parquet \(part 1\) [\#5618](https://github.com/apache/arrow-rs/pull/5618) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Use FixedSizeListArray::new in FixedSizeListBuilder [\#5612](https://github.com/apache/arrow-rs/pull/5612) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- String to decimal conversion written using E/scientific notation [\#5611](https://github.com/apache/arrow-rs/pull/5611) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Nekit2217](https://github.com/Nekit2217)) +- Account for Timezone when Casting Timestamp to Date32 [\#5605](https://github.com/apache/arrow-rs/pull/5605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Lordworms](https://github.com/Lordworms)) +- Update 
prost-build requirement from =0.12.3 to =0.12.4 [\#5604](https://github.com/apache/arrow-rs/pull/5604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix panic when displaying dates on 32-bit platforms [\#5603](https://github.com/apache/arrow-rs/pull/5603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ivanceras](https://github.com/ivanceras)) +- Implement `take` kernel for byte view array. [\#5602](https://github.com/apache/arrow-rs/pull/5602) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Add tests for Arrow Flight support for `StringViewArray` and `BinaryViewArray` [\#5601](https://github.com/apache/arrow-rs/pull/5601) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([XiangpengHao](https://github.com/XiangpengHao)) +- test: Add a test for RowFilter with nested type [\#5600](https://github.com/apache/arrow-rs/pull/5600) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Minor: Add docs for GenericBinaryBuilder, links to `GenericStringBuilder` [\#5597](https://github.com/apache/arrow-rs/pull/5597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Bump chrono-tz from 0.8 to 0.9 [\#5596](https://github.com/apache/arrow-rs/pull/5596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Update brotli requirement from 3.3 to 4.0 [\#5586](https://github.com/apache/arrow-rs/pull/5586) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add `UnionArray::into_parts` [\#5585](https://github.com/apache/arrow-rs/pull/5585) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Expose ArrowReaderMetadata::try\_new [\#5583](https://github.com/apache/arrow-rs/pull/5583) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kylebarron](https://github.com/kylebarron)) +- Add `try_filter_leaves` to propagate error from filter closure [\#5575](https://github.com/apache/arrow-rs/pull/5575) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- filter for run end array [\#5573](https://github.com/apache/arrow-rs/pull/5573) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fabianmurariu](https://github.com/fabianmurariu)) +- Pin zstd-sys to `v2.0.9` in parquet [\#5567](https://github.com/apache/arrow-rs/pull/5567) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- Split arrow\_cast::cast::string into its own submodule [\#5563](https://github.com/apache/arrow-rs/pull/5563) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([monkwire](https://github.com/monkwire)) +- Correct example code for column \(\#5560\) [\#5561](https://github.com/apache/arrow-rs/pull/5561) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zgershkoff](https://github.com/zgershkoff)) +- Split arrow\_cast::cast::dictionary into its own submodule [\#5555](https://github.com/apache/arrow-rs/pull/5555) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([monkwire](https://github.com/monkwire)) +- Split arrow\_cast::cast::decimal into its own submodule [\#5552](https://github.com/apache/arrow-rs/pull/5552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([monkwire](https://github.com/monkwire)) +- Fix new clippy lints for Rust 1.77 [\#5544](https://github.com/apache/arrow-rs/pull/5544) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([alamb](https://github.com/alamb)) +- fix: correctly encode ticket [\#5543](https://github.com/apache/arrow-rs/pull/5543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([freddieptf](https://github.com/freddieptf)) +- feat: implemented with\_field\(\) for FixedSizeListBuilder [\#5541](https://github.com/apache/arrow-rs/pull/5541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([istvan-fodor](https://github.com/istvan-fodor)) +- Split arrow\_cast::cast::list into its own submodule [\#5537](https://github.com/apache/arrow-rs/pull/5537) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([monkwire](https://github.com/monkwire)) +- Bump black from 22.10.0 to 24.3.0 in /parquet/pytest [\#5535](https://github.com/apache/arrow-rs/pull/5535) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add OffsetBufferBuilder [\#5532](https://github.com/apache/arrow-rs/pull/5532) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add IPC StreamDecoder [\#5531](https://github.com/apache/arrow-rs/pull/5531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- IPC format support for StringViewArray and BinaryViewArray [\#5525](https://github.com/apache/arrow-rs/pull/5525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- parquet: Use specific error variant when codec is disabled [\#5521](https://github.com/apache/arrow-rs/pull/5521) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([progval](https://github.com/progval)) +- impl `From<ScalarBuffer<T>>` for `Vec<T>` [\#5518](https://github.com/apache/arrow-rs/pull/5518) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +## 
[51.0.0](https://github.com/apache/arrow-rs/tree/51.0.0) (2024-03-15) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/50.0.0...51.0.0) + +**Breaking changes:** + +- Remove internal buffering from AsyncArrowWriter \(\#5484\) [\#5485](https://github.com/apache/arrow-rs/pull/5485) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make ArrayBuilder also Sync [\#5353](https://github.com/apache/arrow-rs/pull/5353) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dvic](https://github.com/dvic)) +- Raw JSON writer \(~10x faster\) \(\#5314\) [\#5318](https://github.com/apache/arrow-rs/pull/5318) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Prototype Arrow over HTTP in Rust [\#5496](https://github.com/apache/arrow-rs/issues/5496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add DataType::ListView and DataType::LargeListView [\#5492](https://github.com/apache/arrow-rs/issues/5492) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve documentation around handling of dictionary arrays in arrow flight [\#5487](https://github.com/apache/arrow-rs/issues/5487) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Better memory limiting in parquet `ArrowWriter` [\#5484](https://github.com/apache/arrow-rs/issues/5484) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support Creating Non-Nullable Lists and Maps within a Struct [\#5482](https://github.com/apache/arrow-rs/issues/5482) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DISCUSSION\] Better borrow propagation \(e.g. 
`RecordBatch::schema()` to return `&SchemaRef` vs `SchemaRef`\) [\#5463](https://github.com/apache/arrow-rs/issues/5463) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Build Scalar with ArrayRef [\#5459](https://github.com/apache/arrow-rs/issues/5459) +- AsyncArrowWriter doesn't limit underlying ArrowWriter to respect buffer-size [\#5450](https://github.com/apache/arrow-rs/issues/5450) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Refine `Display` implementation for `FlightError` [\#5438](https://github.com/apache/arrow-rs/issues/5438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Better ergonomics for `FixedSizeList` and `LargeList` [\#5372](https://github.com/apache/arrow-rs/issues/5372) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update Flight proto [\#5367](https://github.com/apache/arrow-rs/issues/5367) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support check similar datatype but with different magnitudes [\#5358](https://github.com/apache/arrow-rs/issues/5358) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Buffer memory usage for custom allocations is reported as 0 [\#5346](https://github.com/apache/arrow-rs/issues/5346) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Can the ArrayBuilder trait be made Sync? 
[\#5344](https://github.com/apache/arrow-rs/issues/5344) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- support cast 'UTF8' to `FixedSizeList` [\#5339](https://github.com/apache/arrow-rs/issues/5339) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Creating Non-Nullable Lists with ListBuilder [\#5330](https://github.com/apache/arrow-rs/issues/5330) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `ParquetRecordBatchStreamBuilder::new()` panics instead of erroring out when opening a corrupted file [\#5315](https://github.com/apache/arrow-rs/issues/5315) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Raw JSON Writer [\#5314](https://github.com/apache/arrow-rs/issues/5314) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for more fused boolean operations [\#5297](https://github.com/apache/arrow-rs/issues/5297) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: Allow disabling embed `ARROW_SCHEMA_META_KEY` added by the `ArrowWriter` [\#5296](https://github.com/apache/arrow-rs/issues/5296) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support casting strings like '2001-01-01 01:01:01' to Date32 [\#5280](https://github.com/apache/arrow-rs/issues/5280) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Temporal Extract/Date Part Kernel [\#5266](https://github.com/apache/arrow-rs/issues/5266) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support for extracting hours/minutes/seconds/etc. 
from `Time32`/`Time64` type in temporal kernels [\#5261](https://github.com/apache/arrow-rs/issues/5261) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: add method to get both the inner writer and the file metadata when closing SerializedFileWriter [\#5253](https://github.com/apache/arrow-rs/issues/5253) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Release arrow-rs version 50.0.0 [\#5234](https://github.com/apache/arrow-rs/issues/5234) + +**Fixed bugs:** + +- Empty String Parses as Zero in Unreleased Arrow [\#5504](https://github.com/apache/arrow-rs/issues/5504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Unused import in nightly rust [\#5476](https://github.com/apache/arrow-rs/issues/5476) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Error `The data type type List .. has no natural order` when using `arrow::compute::lexsort_to_indices` with list and more than one column [\#5454](https://github.com/apache/arrow-rs/issues/5454) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Wrong size assertion in arrow\_buffer::builder::NullBufferBuilder::new\_from\_buffer [\#5445](https://github.com/apache/arrow-rs/issues/5445) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Inconsistency between comments and code implementation [\#5430](https://github.com/apache/arrow-rs/issues/5430) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- OOB access in `Buffer::from_iter` [\#5412](https://github.com/apache/arrow-rs/issues/5412) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast kernel doesn't return null for string to integral cases when overflowing under safe option enabled [\#5397](https://github.com/apache/arrow-rs/issues/5397) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make ffi consume variable 
layout arrays with empty offsets [\#5391](https://github.com/apache/arrow-rs/issues/5391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RecordBatch conversion from pyarrow loses Schema's metadata [\#5354](https://github.com/apache/arrow-rs/issues/5354) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Debug output of Time32/Time64 arrays with invalid values has confusing nulls [\#5336](https://github.com/apache/arrow-rs/issues/5336) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Removing a column from a `RecordBatch` drops schema metadata [\#5327](https://github.com/apache/arrow-rs/issues/5327) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Panic when read an empty parquet file [\#5304](https://github.com/apache/arrow-rs/issues/5304) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- How to enable statistics for string columns? [\#5270](https://github.com/apache/arrow-rs/issues/5270) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `concat::tests::test_string_dictionary_merge failure` fails on Mac / has different results in different platforms [\#5255](https://github.com/apache/arrow-rs/issues/5255) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Minor: Add doc comments to `GenericByteViewArray` [\#5512](https://github.com/apache/arrow-rs/pull/5512) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve docs for logical and physical nulls even more [\#5434](https://github.com/apache/arrow-rs/pull/5434) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add example of converting RecordBatches to JSON objects [\#5364](https://github.com/apache/arrow-rs/pull/5364) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- improve float to string cast by ~20%-40% 
[\#5401](https://github.com/apache/arrow-rs/pull/5401) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) + +**Closed issues:** + +- Add `StringViewArray` implementation and layout and basic construction + tests [\#5469](https://github.com/apache/arrow-rs/issues/5469) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `DataType::Utf8View` and `DataType::BinaryView` [\#5468](https://github.com/apache/arrow-rs/issues/5468) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Deprecate array\_to\_json\_array [\#5515](https://github.com/apache/arrow-rs/pull/5515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix integer parsing of empty strings \(\#5504\) [\#5505](https://github.com/apache/arrow-rs/pull/5505) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: clarifying comments in struct\_builder.rs \#5494 [\#5499](https://github.com/apache/arrow-rs/pull/5499) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([istvan-fodor](https://github.com/istvan-fodor)) +- Update proc-macro2 requirement from =1.0.78 to =1.0.79 [\#5498](https://github.com/apache/arrow-rs/pull/5498) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add DataType::ListView and DataType::LargeListView [\#5493](https://github.com/apache/arrow-rs/pull/5493) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kikkon](https://github.com/Kikkon)) +- Better document parquet pushdown [\#5491](https://github.com/apache/arrow-rs/pull/5491) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix NullBufferBuilder::new\_from\_buffer wrong size assertion [\#5489](https://github.com/apache/arrow-rs/pull/5489) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kikkon](https://github.com/Kikkon)) +- Support dictionary encoding in structures for `FlightDataEncoder`, add documentation for `arrow_flight::encode::Dictionary` [\#5488](https://github.com/apache/arrow-rs/pull/5488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Add MapBuilder::with\_values\_field to support non-nullable values \(\#5482\) [\#5483](https://github.com/apache/arrow-rs/pull/5483) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lasantosr](https://github.com/lasantosr)) +- feat: initial support string\_view and binary\_view, supports layout and basic construction + tests [\#5481](https://github.com/apache/arrow-rs/pull/5481) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ariesdevil](https://github.com/ariesdevil)) +- Add more comprehensive documentation on testing and benchmarking to CONTRIBUTING.md [\#5478](https://github.com/apache/arrow-rs/pull/5478) ([monkwire](https://github.com/monkwire)) +- Remove unused import detected by nightly rust [\#5477](https://github.com/apache/arrow-rs/pull/5477) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Add RecordBatch::schema\_ref [\#5474](https://github.com/apache/arrow-rs/pull/5474) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([monkwire](https://github.com/monkwire)) +- Provide access to inner Write for parquet writers [\#5471](https://github.com/apache/arrow-rs/pull/5471) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add DataType::Utf8View and DataType::BinaryView [\#5470](https://github.com/apache/arrow-rs/pull/5470) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Update base64 requirement from 0.21 to 0.22 [\#5467](https://github.com/apache/arrow-rs/pull/5467) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Minor: Fix formatting typo in `Field::new_list_field` [\#5464](https://github.com/apache/arrow-rs/pull/5464) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix test\_string\_dictionary\_merge \(\#5255\) [\#5461](https://github.com/apache/arrow-rs/pull/5461) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use Vec::from\_iter in Buffer::from\_iter [\#5460](https://github.com/apache/arrow-rs/pull/5460) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kikkon](https://github.com/Kikkon)) +- Document parquet writer memory limiting \(\#5450\) [\#5457](https://github.com/apache/arrow-rs/pull/5457) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Document UnionArray Panics [\#5456](https://github.com/apache/arrow-rs/pull/5456) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Kikkon](https://github.com/Kikkon)) +- fix: lexsort\_to\_indices 
unsupported mixed types with list [\#5455](https://github.com/apache/arrow-rs/pull/5455) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Refine `Display` and `Source` implementation for error types [\#5439](https://github.com/apache/arrow-rs/pull/5439) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([BugenZhao](https://github.com/BugenZhao)) +- Improve debug output of Time32/Time64 arrays [\#5428](https://github.com/apache/arrow-rs/pull/5428) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([monkwire](https://github.com/monkwire)) +- Miri fix: Rename invalid\_mut to without\_provenance\_mut [\#5418](https://github.com/apache/arrow-rs/pull/5418) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Ensure addition/multiplications when allocating buffers don't overflow [\#5417](https://github.com/apache/arrow-rs/pull/5417) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Update Flight proto: PollFlightInfo & expiration time [\#5413](https://github.com/apache/arrow-rs/pull/5413) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Jefffrey](https://github.com/Jefffrey)) +- Add tests for serializing lists of dictionary encoded values to json [\#5399](https://github.com/apache/arrow-rs/pull/5399) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Return null for overflow when casting string to integer under safe option enabled [\#5398](https://github.com/apache/arrow-rs/pull/5398) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Propagate error instead of panic for `take_bytes` [\#5395](https://github.com/apache/arrow-rs/pull/5395) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Improve like kernel by ~2% [\#5390](https://github.com/apache/arrow-rs/pull/5390) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Enable running arrow-array and arrow-arith with miri and avoid strict provenance warning [\#5387](https://github.com/apache/arrow-rs/pull/5387) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Update to chrono 0.4.34 [\#5385](https://github.com/apache/arrow-rs/pull/5385) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Return error instead of panic when reading invalid Parquet metadata [\#5382](https://github.com/apache/arrow-rs/pull/5382) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mmaitre314](https://github.com/mmaitre314)) +- Update tonic requirement from 0.10.0 to 0.11.0 [\#5380](https://github.com/apache/arrow-rs/pull/5380) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update tonic-build requirement from =0.10.2 to =0.11.0 [\#5379](https://github.com/apache/arrow-rs/pull/5379) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix latest clippy lints [\#5376](https://github.com/apache/arrow-rs/pull/5376) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: utility functions for creating `FixedSizeList` and `LargeList` dtypes [\#5373](https://github.com/apache/arrow-rs/pull/5373) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([universalmind303](https://github.com/universalmind303)) +- Minor\(docs\): update 
master to main for DataFusion/Ballista [\#5363](https://github.com/apache/arrow-rs/pull/5363) ([caicancai](https://github.com/caicancai)) +- Return an error instead of a panic when reading a corrupted Parquet file with mismatched column counts [\#5362](https://github.com/apache/arrow-rs/pull/5362) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mmaitre314](https://github.com/mmaitre314)) +- feat: support casting FixedSizeList with new child type [\#5360](https://github.com/apache/arrow-rs/pull/5360) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Add more debugging info to StructBuilder validate\_content [\#5357](https://github.com/apache/arrow-rs/pull/5357) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- pyarrow: Preserve RecordBatch's schema metadata [\#5355](https://github.com/apache/arrow-rs/pull/5355) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([atwam](https://github.com/atwam)) +- Mark Encoding::BIT\_PACKED as deprecated and document its compatibility issues [\#5348](https://github.com/apache/arrow-rs/pull/5348) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- Track the size of custom allocations for use via Array::get\_buffer\_memory\_size [\#5347](https://github.com/apache/arrow-rs/pull/5347) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- fix: Return an error on type mismatch rather than panic \(\#4995\) [\#5341](https://github.com/apache/arrow-rs/pull/5341) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carols10cents](https://github.com/carols10cents)) +- Minor: support cast values to fixedsizelist [\#5340](https://github.com/apache/arrow-rs/pull/5340) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Enhance Time32/Time64 
support in date\_part [\#5337](https://github.com/apache/arrow-rs/pull/5337) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- feat: add `take_record_batch`. [\#5333](https://github.com/apache/arrow-rs/pull/5333) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Add ListBuilder::with\_field to support non nullable list fields \(\#5330\) [\#5331](https://github.com/apache/arrow-rs/pull/5331) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Don't omit schema metadata when removing column [\#5328](https://github.com/apache/arrow-rs/pull/5328) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) +- Update proc-macro2 requirement from =1.0.76 to =1.0.78 [\#5324](https://github.com/apache/arrow-rs/pull/5324) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Enhance Date64 type documentation [\#5323](https://github.com/apache/arrow-rs/pull/5323) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- fix panic when decode a group with no child [\#5322](https://github.com/apache/arrow-rs/pull/5322) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Liyixin95](https://github.com/Liyixin95)) +- Minor/Doc Expand FlightSqlServiceClient::handshake doc [\#5321](https://github.com/apache/arrow-rs/pull/5321) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([devinjdangelo](https://github.com/devinjdangelo)) +- Refactor temporal extract date part kernels [\#5319](https://github.com/apache/arrow-rs/pull/5319) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([Jefffrey](https://github.com/Jefffrey)) +- Add JSON writer benchmarks \(\#5314\) [\#5317](https://github.com/apache/arrow-rs/pull/5317) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Bump actions/cache from 3 to 4 [\#5308](https://github.com/apache/arrow-rs/pull/5308) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Avro block decompression [\#5306](https://github.com/apache/arrow-rs/pull/5306) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Result into error in case of endianness mismatches [\#5301](https://github.com/apache/arrow-rs/pull/5301) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pangiole](https://github.com/pangiole)) +- parquet: Add ArrowWriterOptions to skip embedding the arrow metadata [\#5299](https://github.com/apache/arrow-rs/pull/5299) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([evenyag](https://github.com/evenyag)) +- Add support for more fused boolean operations [\#5298](https://github.com/apache/arrow-rs/pull/5298) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([RTEnzyme](https://github.com/RTEnzyme)) +- Support Parquet Byte Stream Split Encoding [\#5293](https://github.com/apache/arrow-rs/pull/5293) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mwlon](https://github.com/mwlon)) +- Extend string parsing support for Date32 [\#5282](https://github.com/apache/arrow-rs/pull/5282) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gruuya](https://github.com/gruuya)) +- Bring some methods over from ArrowWriter to the async version [\#5251](https://github.com/apache/arrow-rs/pull/5251) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AdamGS](https://github.com/AdamGS)) +## [50.0.0](https://github.com/apache/arrow-rs/tree/50.0.0) (2024-01-08) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/49.0.0...50.0.0) + +**Breaking 
changes:** + +- Make regexp\_match take scalar pattern and flag [\#5245](https://github.com/apache/arrow-rs/pull/5245) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Use Vec in ColumnReader \(\#5177\) [\#5193](https://github.com/apache/arrow-rs/pull/5193) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove SIMD Feature [\#5184](https://github.com/apache/arrow-rs/pull/5184) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use Total Ordering for Aggregates and Refactor for Better Auto-Vectorization [\#5100](https://github.com/apache/arrow-rs/pull/5100) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Allow the `zip` compute function to operate on `Scalar` values via `Datum` [\#5086](https://github.com/apache/arrow-rs/pull/5086) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Nathan-Fenner](https://github.com/Nathan-Fenner)) +- Improve C Data Interface and Add Integration Testing Entrypoints [\#5080](https://github.com/apache/arrow-rs/pull/5080) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pitrou](https://github.com/pitrou)) +- Parquet: read/write f16 for Arrow [\#5003](https://github.com/apache/arrow-rs/pull/5003) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) + +**Implemented enhancements:** + +- Support get offsets or blocks info from arrow file. 
[\#5252](https://github.com/apache/arrow-rs/issues/5252) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make regexp\_match take scalar pattern and flag [\#5246](https://github.com/apache/arrow-rs/issues/5246) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cannot access pen state website on arrow-row [\#5238](https://github.com/apache/arrow-rs/issues/5238) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RecordBatch with\_schema's error message is hard to read [\#5227](https://github.com/apache/arrow-rs/issues/5227) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support cast between StructArray. [\#5219](https://github.com/apache/arrow-rs/issues/5219) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove nightly-only simd feature and related code in ArrowNumericType [\#5185](https://github.com/apache/arrow-rs/issues/5185) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use Vec instead of Slice in ColumnReader [\#5177](https://github.com/apache/arrow-rs/issues/5177) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Request to Memmap Arrow IPC files on disk [\#5153](https://github.com/apache/arrow-rs/issues/5153) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- GenericColumnReader::read\_records Yields Truncated Records [\#5150](https://github.com/apache/arrow-rs/issues/5150) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Nested Schema Projection [\#5148](https://github.com/apache/arrow-rs/issues/5148) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support specifying `quote` and `escape` in Csv `WriterBuilder` [\#5146](https://github.com/apache/arrow-rs/issues/5146) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting of Float16 with other numeric types [\#5138](https://github.com/apache/arrow-rs/issues/5138) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet: read parquet metadata with page index in async and with size hints [\#5129](https://github.com/apache/arrow-rs/issues/5129) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Cast from floating/timestamp to timestamp/floating [\#5122](https://github.com/apache/arrow-rs/issues/5122) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Casting List To/From LargeList in Cast Kernel [\#5113](https://github.com/apache/arrow-rs/issues/5113) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Expose a path for converting `bytes::Bytes` into `arrow_buffer::Buffer` without copy [\#5104](https://github.com/apache/arrow-rs/issues/5104) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- API inconsistency of ListBuilder make it hard to use as nested builder [\#5098](https://github.com/apache/arrow-rs/issues/5098) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet: don't truncate min/max statistics for float16 and decimal when writing file [\#5075](https://github.com/apache/arrow-rs/issues/5075) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet: derive boundary order when writing columns [\#5074](https://github.com/apache/arrow-rs/issues/5074) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support new Arrow PyCapsule Interface for Python FFI [\#5067](https://github.com/apache/arrow-rs/issues/5067) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `48.0.1 ` arrow patch release [\#5050](https://github.com/apache/arrow-rs/issues/5050) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Binary columns do not receive truncated statistics [\#5037](https://github.com/apache/arrow-rs/issues/5037) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Re-evaluate Explicit SIMD Aggregations 
[\#5032](https://github.com/apache/arrow-rs/issues/5032) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Min/Max Kernels Should Use Total Ordering [\#5031](https://github.com/apache/arrow-rs/issues/5031) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow `zip` compute kernel to take `Scalar` / `Datum` [\#5011](https://github.com/apache/arrow-rs/issues/5011) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Float16/Half-float logical type to Parquet [\#4986](https://github.com/apache/arrow-rs/issues/4986) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- feat: cast \(Large\)List to FixedSizeList [\#5081](https://github.com/apache/arrow-rs/pull/5081) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Update Parquet Encoding Documentation [\#5051](https://github.com/apache/arrow-rs/issues/5051) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- json schema inference can't handle null field turned into object field in subsequent rows [\#5215](https://github.com/apache/arrow-rs/issues/5215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Invalid trailing content after `Z` in timezone is ignored [\#5182](https://github.com/apache/arrow-rs/issues/5182) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Take panics on a fixed size list array when given null indices [\#5169](https://github.com/apache/arrow-rs/issues/5169) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- EnabledStatistics::Page does not take effect on ByteArrayEncoder [\#5162](https://github.com/apache/arrow-rs/issues/5162) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet: ColumnOrder not being written when writing parquet files [\#5152](https://github.com/apache/arrow-rs/issues/5152) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet: Interval columns shouldn't write 
min/max stats [\#5145](https://github.com/apache/arrow-rs/issues/5145) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- cast `Utf8` to decimal failure [\#5127](https://github.com/apache/arrow-rs/issues/5127) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- coerce\_primitive not honored when decoding from serde object [\#5095](https://github.com/apache/arrow-rs/issues/5095) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Unsound MutableArrayData Constructor [\#5091](https://github.com/apache/arrow-rs/issues/5091) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RowGroupReader.get\_row\_iter\(\) fails with Path ColumnPath not found [\#5064](https://github.com/apache/arrow-rs/issues/5064) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- cast format 'yyyymmdd' to Date32 gives an error [\#5044](https://github.com/apache/arrow-rs/issues/5044) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Performance improvements:** + +- ArrowArrayStreamReader imports FFI\_ArrowSchema on each iteration [\#5103](https://github.com/apache/arrow-rs/issues/5103) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Working example of list\_flights with ObjectStore [\#5116](https://github.com/apache/arrow-rs/issues/5116) +- \(object\_store\) Error broken pipe on S3 multipart upload [\#5106](https://github.com/apache/arrow-rs/issues/5106) + +**Merged pull requests:** + +- Update parquet object\_store dependency to 0.9.0 [\#5290](https://github.com/apache/arrow-rs/pull/5290) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.75 to =1.0.76 [\#5289](https://github.com/apache/arrow-rs/pull/5289) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- 
Enable JS tests again [\#5287](https://github.com/apache/arrow-rs/pull/5287) ([domoritz](https://github.com/domoritz)) +- Update proc-macro2 requirement from =1.0.74 to =1.0.75 [\#5279](https://github.com/apache/arrow-rs/pull/5279) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.73 to =1.0.74 [\#5271](https://github.com/apache/arrow-rs/pull/5271) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.71 to =1.0.73 [\#5265](https://github.com/apache/arrow-rs/pull/5265) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update docs for datatypes [\#5260](https://github.com/apache/arrow-rs/pull/5260) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Don't suppress errors in ArrowArrayStreamReader [\#5256](https://github.com/apache/arrow-rs/pull/5256) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add IPC FileDecoder [\#5249](https://github.com/apache/arrow-rs/pull/5249) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- optimize the next function of ArrowArrayStreamReader [\#5248](https://github.com/apache/arrow-rs/pull/5248) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([doki23](https://github.com/doki23)) +- ci: Fail Miri CI on first failure [\#5243](https://github.com/apache/arrow-rs/pull/5243) ([Jefffrey](https://github.com/Jefffrey)) +- Remove 'unwrap' from Result 
[\#5241](https://github.com/apache/arrow-rs/pull/5241) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Update arrow-row docs URL [\#5239](https://github.com/apache/arrow-rs/pull/5239) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([thomas-k-cameron](https://github.com/thomas-k-cameron)) +- Improve regexp kernels performance by avoiding cloning Regex [\#5235](https://github.com/apache/arrow-rs/pull/5235) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update proc-macro2 requirement from =1.0.70 to =1.0.71 [\#5231](https://github.com/apache/arrow-rs/pull/5231) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Minor: Improve comments and errors for ArrowPredicate [\#5230](https://github.com/apache/arrow-rs/pull/5230) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Bump actions/upload-pages-artifact from 2 to 3 [\#5229](https://github.com/apache/arrow-rs/pull/5229) ([dependabot[bot]](https://github.com/apps/dependabot)) +- make with\_schema's error more readable [\#5228](https://github.com/apache/arrow-rs/pull/5228) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([shuoli84](https://github.com/shuoli84)) +- Use `try_new` when casting between structs to propagate error [\#5226](https://github.com/apache/arrow-rs/pull/5226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat\(cast\): support cast between struct [\#5221](https://github.com/apache/arrow-rs/pull/5221) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([my-vegetable-has-exploded](https://github.com/my-vegetable-has-exploded)) +- Add `entries` to `MapBuilder` to return both key and value array builders 
[\#5218](https://github.com/apache/arrow-rs/pull/5218) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix\(json\): fix inferring object after field was null [\#5216](https://github.com/apache/arrow-rs/pull/5216) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Support MapBuilder in make\_builder [\#5210](https://github.com/apache/arrow-rs/pull/5210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- impl `From>` for `ScalarBuffer` [\#5203](https://github.com/apache/arrow-rs/pull/5203) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- impl `From>` for `Buffer` [\#5202](https://github.com/apache/arrow-rs/pull/5202) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- impl `From>` for `ScalarBuffer` [\#5201](https://github.com/apache/arrow-rs/pull/5201) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- feat: Support quote and escape in Csv WriterBuilder [\#5196](https://github.com/apache/arrow-rs/pull/5196) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([my-vegetable-has-exploded](https://github.com/my-vegetable-has-exploded)) +- chore: simplify cast\_string\_to\_interval [\#5195](https://github.com/apache/arrow-rs/pull/5195) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Clarify interval comparison behavior with documentation and tests [\#5192](https://github.com/apache/arrow-rs/pull/5192) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add `BooleanArray::into_parts` method [\#5191](https://github.com/apache/arrow-rs/pull/5191) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Fix 
deprecated note for `Buffer::from_raw_parts` [\#5190](https://github.com/apache/arrow-rs/pull/5190) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Fix: Ensure Timestamp Parsing Rejects Characters After 'Z' [\#5189](https://github.com/apache/arrow-rs/pull/5189) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([razeghi71](https://github.com/razeghi71)) +- Simplify parquet statistics generation [\#5183](https://github.com/apache/arrow-rs/pull/5183) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Parquet: Ensure page statistics are written only when configured from the Arrow Writer [\#5181](https://github.com/apache/arrow-rs/pull/5181) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AdamGS](https://github.com/AdamGS)) +- Blockwise IO in IPC FileReader \(\#5153\) [\#5179](https://github.com/apache/arrow-rs/pull/5179) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Replace ScalarBuffer in Parquet with Vec \(\#1849\) \(\#5177\) [\#5178](https://github.com/apache/arrow-rs/pull/5178) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Bump actions/setup-python from 4 to 5 [\#5175](https://github.com/apache/arrow-rs/pull/5175) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add `LargeListBuilder` to `make_builder` [\#5171](https://github.com/apache/arrow-rs/pull/5171) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix: ensure take\_fixed\_size\_list can handle null indices [\#5170](https://github.com/apache/arrow-rs/pull/5170) ([westonpace](https://github.com/westonpace)) +- Removing redundant `as casts` in parquet [\#5168](https://github.com/apache/arrow-rs/pull/5168) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
([psvri](https://github.com/psvri)) +- Bump actions/labeler from 4.3.0 to 5.0.0 [\#5167](https://github.com/apache/arrow-rs/pull/5167) ([dependabot[bot]](https://github.com/apps/dependabot)) +- improve: make RunArray displayable [\#5166](https://github.com/apache/arrow-rs/pull/5166) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yukkit](https://github.com/yukkit)) +- ci: Add cargo audit CI action [\#5160](https://github.com/apache/arrow-rs/pull/5160) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Parquet: write column\_orders in FileMetaData [\#5158](https://github.com/apache/arrow-rs/pull/5158) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- Adding `is_null` datatype shortcut method [\#5157](https://github.com/apache/arrow-rs/pull/5157) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Parquet: don't truncate f16/decimal min/max stats [\#5154](https://github.com/apache/arrow-rs/pull/5154) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- Support nested schema projection \(\#5148\) [\#5149](https://github.com/apache/arrow-rs/pull/5149) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Parquet: omit min/max for interval columns when writing stats [\#5147](https://github.com/apache/arrow-rs/pull/5147) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- Deprecate Fields::remove and Schema::remove [\#5144](https://github.com/apache/arrow-rs/pull/5144) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support casting of Float16 with other numeric types [\#5139](https://github.com/apache/arrow-rs/pull/5139) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([viirya](https://github.com/viirya)) +- Parquet: Make `MetadataLoader` public [\#5137](https://github.com/apache/arrow-rs/pull/5137) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AdamGS](https://github.com/AdamGS)) +- Add FileReaderBuilder for arrow-ipc to allow reading large no. of column files [\#5136](https://github.com/apache/arrow-rs/pull/5136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Parquet: clear metadata and project fields of ParquetRecordBatchStream::schema [\#5135](https://github.com/apache/arrow-rs/pull/5135) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- JSON: write struct array nulls as null [\#5133](https://github.com/apache/arrow-rs/pull/5133) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Update proc-macro2 requirement from =1.0.69 to =1.0.70 [\#5131](https://github.com/apache/arrow-rs/pull/5131) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix negative decimal string [\#5128](https://github.com/apache/arrow-rs/pull/5128) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cleanup list casting and support nested lists \(\#5113\) [\#5124](https://github.com/apache/arrow-rs/pull/5124) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cast from numeric/timestamp to timestamp/numeric [\#5123](https://github.com/apache/arrow-rs/pull/5123) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Improve cast docs [\#5114](https://github.com/apache/arrow-rs/pull/5114) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- 
Update prost-build requirement from =0.12.2 to =0.12.3 [\#5112](https://github.com/apache/arrow-rs/pull/5112) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Parquet: derive boundary order when writing [\#5110](https://github.com/apache/arrow-rs/pull/5110) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- Implementing `ArrayBuilder` for `Box<dyn ArrayBuilder>` [\#5109](https://github.com/apache/arrow-rs/pull/5109) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix 'ColumnPath not found' error reading Parquet files with nested REPEATED fields [\#5102](https://github.com/apache/arrow-rs/pull/5102) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mmaitre314](https://github.com/mmaitre314)) +- fix: coerce\_primitive for serde decoded data [\#5101](https://github.com/apache/arrow-rs/pull/5101) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +- Extend aggregation benchmarks [\#5096](https://github.com/apache/arrow-rs/pull/5096) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Expand parquet crate overview doc [\#5093](https://github.com/apache/arrow-rs/pull/5093) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mmaitre314](https://github.com/mmaitre314)) +- Ensure arrays passed to MutableArrayData have same type \(\#5091\) [\#5092](https://github.com/apache/arrow-rs/pull/5092) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update prost-build requirement from =0.12.1 to =0.12.2 [\#5088](https://github.com/apache/arrow-rs/pull/5088) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add FFI from\_raw [\#5082](https://github.com/apache/arrow-rs/pull/5082) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- \[fix \#5044\] Support converting 'yyyymmdd' format to date [\#5078](https://github.com/apache/arrow-rs/pull/5078) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Tangruilin](https://github.com/Tangruilin)) +- Enable truncation of binary statistics columns [\#5076](https://github.com/apache/arrow-rs/pull/5076) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([emcake](https://github.com/emcake)) +## [49.0.0](https://github.com/apache/arrow-rs/tree/49.0.0) (2023-11-07) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/48.0.0...49.0.0) + +**Breaking changes:** + +- Return row count when inferring schema from JSON [\#5008](https://github.com/apache/arrow-rs/pull/5008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asayers](https://github.com/asayers)) +- Update object\_store 0.8.0 [\#5043](https://github.com/apache/arrow-rs/pull/5043) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Cast from integer/timestamp to timestamp/integer [\#5039](https://github.com/apache/arrow-rs/issues/5039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting from integer to binary [\#5014](https://github.com/apache/arrow-rs/issues/5014) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Return row count when inferring schema from JSON [\#5007](https://github.com/apache/arrow-rs/issues/5007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] Allow custom commands in get-flight-info [\#4996](https://github.com/apache/arrow-rs/issues/4996) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support `RecordBatch::remove_column()` and `Schema::remove_field()` [\#4952](https://github.com/apache/arrow-rs/issues/4952) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow_json`: support `binary` deserialization [\#4945](https://github.com/apache/arrow-rs/issues/4945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support StructArray in Cast Kernel [\#4908](https://github.com/apache/arrow-rs/issues/4908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- There exists a `ParquetRecordWriter` proc macro in `parquet_derive`, but `ParquetRecordReader` is missing [\#4772](https://github.com/apache/arrow-rs/issues/4772) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Regression when serializing large json numbers [\#5038](https://github.com/apache/arrow-rs/issues/5038) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RowSelection::intersection Produces Invalid RowSelection [\#5036](https://github.com/apache/arrow-rs/issues/5036) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Incorrect comment on arrow::compute::kernels::sort::sort\_to\_indices [\#5029](https://github.com/apache/arrow-rs/issues/5029) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- chore: Update docs to refer to non deprecated function \(`partition`\) [\#5027](https://github.com/apache/arrow-rs/pull/5027) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- Parquet f32/f64 handle signed zeros in statistics [\#5048](https://github.com/apache/arrow-rs/pull/5048) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- Fix serialization of large integers in JSON \(\#5038\) 
[\#5042](https://github.com/apache/arrow-rs/pull/5042) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix RowSelection::intersection \(\#5036\) [\#5041](https://github.com/apache/arrow-rs/pull/5041) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cast from integer/timestamp to timestamp/integer [\#5040](https://github.com/apache/arrow-rs/pull/5040) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- doc: update comment on sort\_to\_indices to reflect correct ordering [\#5033](https://github.com/apache/arrow-rs/pull/5033) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([westonpace](https://github.com/westonpace)) +- Support casting from integer to binary [\#5015](https://github.com/apache/arrow-rs/pull/5015) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update tracing-log requirement from 0.1 to 0.2 [\#4998](https://github.com/apache/arrow-rs/pull/4998) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat\(flight-sql\): Allow custom commands in get-flight-info [\#4997](https://github.com/apache/arrow-rs/pull/4997) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- \[MINOR\] No need to jump to web pages [\#4994](https://github.com/apache/arrow-rs/pull/4994) ([smallzhongfeng](https://github.com/smallzhongfeng)) +- Support metadata in SchemaBuilder [\#4987](https://github.com/apache/arrow-rs/pull/4987) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: support schema change by idx and reverse 
[\#4985](https://github.com/apache/arrow-rs/pull/4985) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +- Bump actions/setup-node from 3 to 4 [\#4982](https://github.com/apache/arrow-rs/pull/4982) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add arrow\_cast::base64 and document usage in arrow\_json [\#4975](https://github.com/apache/arrow-rs/pull/4975) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add SchemaBuilder::remove \(\#4952\) [\#4964](https://github.com/apache/arrow-rs/pull/4964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `Field::remove()`, `Schema::remove()`, and `RecordBatch::remove_column()` APIs [\#4959](https://github.com/apache/arrow-rs/pull/4959) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Folyd](https://github.com/Folyd)) +- Add `RecordReader` trait and proc macro to implement it for a struct [\#4773](https://github.com/apache/arrow-rs/pull/4773) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Joseph-Rance](https://github.com/Joseph-Rance)) +## [48.0.0](https://github.com/apache/arrow-rs/tree/48.0.0) (2023-10-18) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/47.0.0...48.0.0) + +**Breaking changes:** + +- Evaluate null\_regex for string type in csv \(now such values will be parsed as `Null` rather than `""`\) [\#4942](https://github.com/apache/arrow-rs/pull/4942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haohuaijin](https://github.com/haohuaijin)) +- fix\(csv\)!: infer null for empty column. 
[\#4910](https://github.com/apache/arrow-rs/pull/4910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- feat: log headers/trailers in flight CLI \(+ minor fixes\) [\#4898](https://github.com/apache/arrow-rs/pull/4898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- fix\(arrow-json\)!: include null fields in schema inference with a type of Null [\#4894](https://github.com/apache/arrow-rs/pull/4894) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Mark OnCloseRowGroup Send [\#4893](https://github.com/apache/arrow-rs/pull/4893) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) +- Specialize Thrift Decoding \(~40% Faster\) \(\#4891\) [\#4892](https://github.com/apache/arrow-rs/pull/4892) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make ArrowRowGroupWriter Public and SerializedRowGroupWriter Send [\#4850](https://github.com/apache/arrow-rs/pull/4850) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) + +**Implemented enhancements:** + +- Allow schema fields to merge with `Null` datatype [\#4901](https://github.com/apache/arrow-rs/issues/4901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add option to FlightDataEncoder to always send dictionaries [\#4895](https://github.com/apache/arrow-rs/issues/4895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Rework Thrift Encoding / Decoding of Parquet Metadata [\#4891](https://github.com/apache/arrow-rs/issues/4891) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Plans for 
supporting Extension Array to support Fixed shape tensor Array [\#4890](https://github.com/apache/arrow-rs/issues/4890) +- Implement Take for UnionArray [\#4882](https://github.com/apache/arrow-rs/issues/4882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Check precision overflow for casting floating to decimal [\#4865](https://github.com/apache/arrow-rs/issues/4865) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace lexical [\#4774](https://github.com/apache/arrow-rs/issues/4774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add read access to settings in `csv::WriterBuilder` [\#4735](https://github.com/apache/arrow-rs/issues/4735) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve the performance of "DictionaryValue" row encoding [\#4712](https://github.com/apache/arrow-rs/issues/4712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Should we make blank values and empty string to `None` in csv? 
[\#4939](https://github.com/apache/arrow-rs/issues/4939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] SubstraitPlan structure is not exported [\#4932](https://github.com/apache/arrow-rs/issues/4932) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Loading page index breaks skipping of pages with nested types [\#4921](https://github.com/apache/arrow-rs/issues/4921) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV schema inference assumes `Utf8` for empty columns [\#4903](https://github.com/apache/arrow-rs/issues/4903) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: Field Ids are not read from a Parquet file without serialized arrow schema [\#4877](https://github.com/apache/arrow-rs/issues/4877) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- make\_primitive\_scalar function loses DataType Internal information [\#4851](https://github.com/apache/arrow-rs/issues/4851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- StructBuilder doesn't handle nulls correctly for empty structs [\#4842](https://github.com/apache/arrow-rs/issues/4842) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `NullArray::is_null()` returns `false` incorrectly [\#4835](https://github.com/apache/arrow-rs/issues/4835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- cast\_string\_to\_decimal should check precision overflow [\#4829](https://github.com/apache/arrow-rs/issues/4829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Null fields are omitted by `infer_json_schema_from_seekable` [\#4814](https://github.com/apache/arrow-rs/issues/4814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Support for reading JSON Array to Arrow [\#4905](https://github.com/apache/arrow-rs/issues/4905) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Assume Pages Delimit Records When Offset Index Loaded \(\#4921\) [\#4943](https://github.com/apache/arrow-rs/pull/4943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update pyo3 requirement from 0.19 to 0.20 [\#4941](https://github.com/apache/arrow-rs/pull/4941) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add `FileWriter` schema getter [\#4940](https://github.com/apache/arrow-rs/pull/4940) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haixuanTao](https://github.com/haixuanTao)) +- feat: support parsing for parquet writer option [\#4938](https://github.com/apache/arrow-rs/pull/4938) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([fansehep](https://github.com/fansehep)) +- Export `SubstraitPlan` structure in arrow\_flight::sql \(\#4932\) [\#4933](https://github.com/apache/arrow-rs/pull/4933) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- Update zstd requirement from 0.12.0 to 0.13.0 [\#4923](https://github.com/apache/arrow-rs/pull/4923) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: add method for async read bloom filter [\#4917](https://github.com/apache/arrow-rs/pull/4917) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hengfeiyang](https://github.com/hengfeiyang)) +- Minor: Clarify rationale for `FlightDataEncoder` API, add examples [\#4916](https://github.com/apache/arrow-rs/pull/4916) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] 
([alamb](https://github.com/alamb)) +- Update regex-syntax requirement from 0.7.1 to 0.8.0 [\#4914](https://github.com/apache/arrow-rs/pull/4914) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: document & streamline flight SQL CLI [\#4912](https://github.com/apache/arrow-rs/pull/4912) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Arbitrary JSON values in JSON Reader \(\#4905\) [\#4911](https://github.com/apache/arrow-rs/pull/4911) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup CSV WriterBuilder, Default to AutoSI Second Precision \(\#4735\) [\#4909](https://github.com/apache/arrow-rs/pull/4909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.68 to =1.0.69 [\#4907](https://github.com/apache/arrow-rs/pull/4907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- chore: add csv example [\#4904](https://github.com/apache/arrow-rs/pull/4904) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +- feat\(schema\): allow null fields to be merged with other datatypes [\#4902](https://github.com/apache/arrow-rs/pull/4902) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Update proc-macro2 requirement from =1.0.67 to =1.0.68 [\#4900](https://github.com/apache/arrow-rs/pull/4900) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) 
+- Add option to `FlightDataEncoder` to always resend batch dictionaries [\#4896](https://github.com/apache/arrow-rs/pull/4896) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- Fix integration tests [\#4889](https://github.com/apache/arrow-rs/pull/4889) ([tustvold](https://github.com/tustvold)) +- Support Parsing Avro File Headers [\#4888](https://github.com/apache/arrow-rs/pull/4888) ([tustvold](https://github.com/tustvold)) +- Support parquet bloom filter length [\#4885](https://github.com/apache/arrow-rs/pull/4885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([letian-jiang](https://github.com/letian-jiang)) +- Replace lz4 with lz4\_flex Allowing Compilation for WASM [\#4884](https://github.com/apache/arrow-rs/pull/4884) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Take for UnionArray [\#4883](https://github.com/apache/arrow-rs/pull/4883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Update tonic-build requirement from =0.10.1 to =0.10.2 [\#4881](https://github.com/apache/arrow-rs/pull/4881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- parquet: Read field IDs from Parquet Schema [\#4878](https://github.com/apache/arrow-rs/pull/4878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Samrose-Ahmed](https://github.com/Samrose-Ahmed)) +- feat: improve flight CLI error handling [\#4873](https://github.com/apache/arrow-rs/pull/4873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Encoding Parquet Columns in Parallel [\#4871](https://github.com/apache/arrow-rs/pull/4871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Check precision overflow for casting floating to decimal [\#4866](https://github.com/apache/arrow-rs/pull/4866) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make align\_buffers as public API [\#4863](https://github.com/apache/arrow-rs/pull/4863) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable new integration tests \(\#4828\) [\#4862](https://github.com/apache/arrow-rs/pull/4862) ([tustvold](https://github.com/tustvold)) +- Faster Serde Integration \(~80% faster\) [\#4861](https://github.com/apache/arrow-rs/pull/4861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: make\_primitive\_scalar bug [\#4852](https://github.com/apache/arrow-rs/pull/4852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JasonLi-cn](https://github.com/JasonLi-cn)) +- Update tonic-build requirement from =0.10.0 to =0.10.1 [\#4846](https://github.com/apache/arrow-rs/pull/4846) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Allow Constructing Non-Empty StructArray with no Fields \(\#4842\) [\#4845](https://github.com/apache/arrow-rs/pull/4845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Refine documentation to `Array::is_null` [\#4838](https://github.com/apache/arrow-rs/pull/4838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: add 
missing precision overflow checking for `cast_string_to_decimal` [\#4830](https://github.com/apache/arrow-rs/pull/4830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonahgao](https://github.com/jonahgao)) +## [47.0.0](https://github.com/apache/arrow-rs/tree/47.0.0) (2023-09-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/46.0.0...47.0.0) + +**Breaking changes:** + +- Make FixedSizeBinaryArray value\_data return a reference [\#4820](https://github.com/apache/arrow-rs/issues/4820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update prost to v0.12.1 [\#4825](https://github.com/apache/arrow-rs/pull/4825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: FixedSizeBinaryArray::value\_data return reference [\#4821](https://github.com/apache/arrow-rs/pull/4821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Stateless Row Encoding / Don't Preserve Dictionaries in `RowConverter` \(\#4811\) [\#4819](https://github.com/apache/arrow-rs/pull/4819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- fix: entries field is non-nullable [\#4808](https://github.com/apache/arrow-rs/pull/4808) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Fix flight sql do put handling, add bind parameter support to FlightSQL cli client [\#4797](https://github.com/apache/arrow-rs/pull/4797) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([suremarc](https://github.com/suremarc)) +- Remove unused dyn\_cmp\_dict feature [\#4766](https://github.com/apache/arrow-rs/pull/4766) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add underlying `std::io::Error` to `IoError` and add `IpcError` variant [\#4726](https://github.com/apache/arrow-rs/pull/4726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexandreyc](https://github.com/alexandreyc)) + +**Implemented enhancements:** + +- Row Format Adaptive Block Size [\#4812](https://github.com/apache/arrow-rs/issues/4812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Stateless Row Conversion [\#4811](https://github.com/apache/arrow-rs/issues/4811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add option to specify custom null values for CSV reader [\#4794](https://github.com/apache/arrow-rs/issues/4794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet::record::RowIter cannot be customized with batch\_size and defaults to 1024 [\#4782](https://github.com/apache/arrow-rs/issues/4782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `DynScalar` abstraction \(something that makes it easy to create scalar `Datum`s\) [\#4781](https://github.com/apache/arrow-rs/issues/4781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Datum` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4780](https://github.com/apache/arrow-rs/issues/4780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Scalar` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4779](https://github.com/apache/arrow-rs/issues/4779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support IntoPyArrow for impl RecordBatchReader [\#4730](https://github.com/apache/arrow-rs/issues/4730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Datum Based String
Kernels [\#4595](https://github.com/apache/arrow-rs/issues/4595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- MapArray::new\_from\_strings creates nullable entries field [\#4807](https://github.com/apache/arrow-rs/issues/4807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- pyarrow module can't roundtrip tensor arrays [\#4805](https://github.com/apache/arrow-rs/issues/4805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `concat_batches` errors with "schema mismatch" error when only metadata differs [\#4799](https://github.com/apache/arrow-rs/issues/4799) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- panic in `cmp` kernels with DictionaryArrays: `Option::unwrap()` on a `None` value' [\#4788](https://github.com/apache/arrow-rs/issues/4788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- stream ffi panics if schema metadata values aren't valid utf8 [\#4750](https://github.com/apache/arrow-rs/issues/4750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression: Incorrect Sorting of `*ListArray` in 46.0.0 [\#4746](https://github.com/apache/arrow-rs/issues/4746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row is no longer comparable after reuse [\#4741](https://github.com/apache/arrow-rs/issues/4741) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- DoPut FlightSQL handler inadvertently consumes schema at start of Request\<Streaming\<FlightData\>\> [\#4658](https://github.com/apache/arrow-rs/issues/4658) +- Return error when converting schema [\#4752](https://github.com/apache/arrow-rs/pull/4752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Implement PyArrowType for `Box` [\#4751](https://github.com/apache/arrow-rs/pull/4751) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
([wjones127](https://github.com/wjones127)) + +**Closed issues:** + +- Building arrow-rust for target wasm32-wasi failed to compile packed\_simd\_2 [\#4717](https://github.com/apache/arrow-rs/issues/4717) + +**Merged pull requests:** + +- Respect FormatOption::nulls for NullArray [\#4836](https://github.com/apache/arrow-rs/pull/4836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix merge\_dictionary\_values in selection kernels [\#4833](https://github.com/apache/arrow-rs/pull/4833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix like scalar null [\#4832](https://github.com/apache/arrow-rs/pull/4832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More chrono deprecations [\#4822](https://github.com/apache/arrow-rs/pull/4822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Adaptive Row Block Size \(\#4812\) [\#4818](https://github.com/apache/arrow-rs/pull/4818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.66 to =1.0.67 [\#4816](https://github.com/apache/arrow-rs/pull/4816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Do not check schema for equality in concat\_batches [\#4815](https://github.com/apache/arrow-rs/pull/4815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: export record batch through stream [\#4806](https://github.com/apache/arrow-rs/pull/4806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Improve CSV Reader Benchmark Coverage of Small Primitives
[\#4803](https://github.com/apache/arrow-rs/pull/4803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- csv: Add option to specify custom null values [\#4795](https://github.com/apache/arrow-rs/pull/4795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vrongmeal](https://github.com/vrongmeal)) +- Expand docstring and add example to `Scalar` [\#4793](https://github.com/apache/arrow-rs/pull/4793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Re-export array crate root \(\#4780\) \(\#4779\) [\#4791](https://github.com/apache/arrow-rs/pull/4791) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix DictionaryArray::normalized\_keys \(\#4788\) [\#4789](https://github.com/apache/arrow-rs/pull/4789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Allow custom tree builder for parquet::record::RowIter [\#4783](https://github.com/apache/arrow-rs/pull/4783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([YuraKotov](https://github.com/YuraKotov)) +- Bump actions/checkout from 3 to 4 [\#4767](https://github.com/apache/arrow-rs/pull/4767) ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: avoid panic if offset index not exists. 
[\#4761](https://github.com/apache/arrow-rs/pull/4761) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Relax constraints on PyArrowType [\#4757](https://github.com/apache/arrow-rs/pull/4757) ([tustvold](https://github.com/tustvold)) +- Chrono deprecations [\#4748](https://github.com/apache/arrow-rs/pull/4748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix List Sorting, Revert Removal of Rank Kernels [\#4747](https://github.com/apache/arrow-rs/pull/4747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clear row buffer before reuse [\#4742](https://github.com/apache/arrow-rs/pull/4742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- Datum based like kernels \(\#4595\) [\#4732](https://github.com/apache/arrow-rs/pull/4732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: expose DoGet response headers & trailers [\#4727](https://github.com/apache/arrow-rs/pull/4727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Cleanup length and bit\_length kernels [\#4718](https://github.com/apache/arrow-rs/pull/4718) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [46.0.0](https://github.com/apache/arrow-rs/tree/46.0.0) (2023-08-21) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/45.0.0...46.0.0) + +**Breaking changes:** + +- API improvement: `batches_to_flight_data` forces clone [\#4656](https://github.com/apache/arrow-rs/issues/4656) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add AnyDictionary 
Abstraction and Take ArrayRef in DictionaryArray::with\_values [\#4707](https://github.com/apache/arrow-rs/pull/4707) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup parquet type builders [\#4706](https://github.com/apache/arrow-rs/pull/4706) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Take kernel dyn Array [\#4705](https://github.com/apache/arrow-rs/pull/4705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve ergonomics of Scalar [\#4704](https://github.com/apache/arrow-rs/pull/4704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Datum based comparison kernels \(\#4596\) [\#4701](https://github.com/apache/arrow-rs/pull/4701) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Improve `Array` Logical Nullability [\#4691](https://github.com/apache/arrow-rs/pull/4691) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Validate ArrayData Buffer Alignment and Automatically Align IPC buffers \(\#4255\) [\#4681](https://github.com/apache/arrow-rs/pull/4681) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More intuitive bool-to-string casting [\#4666](https://github.com/apache/arrow-rs/pull/4666) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fsdvh](https://github.com/fsdvh)) +- enhancement: batches\_to\_flight\_data use a schema ref as param. 
[\#4665](https://github.com/apache/arrow-rs/pull/4665) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([jackwener](https://github.com/jackwener)) +- fix: from\_thrift avoid panic when stats in invalid. [\#4642](https://github.com/apache/arrow-rs/pull/4642) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jackwener](https://github.com/jackwener)) +- bug: Add some missing field in row group metadata: ordinal, total co… [\#4636](https://github.com/apache/arrow-rs/pull/4636) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liurenjie1024](https://github.com/liurenjie1024)) +- Remove deprecated limit kernel [\#4597](https://github.com/apache/arrow-rs/pull/4597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- parquet: support setting the field\_id with an ArrowWriter [\#4702](https://github.com/apache/arrow-rs/issues/4702) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support references in i256 arithmetic ops [\#4694](https://github.com/apache/arrow-rs/issues/4694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Precision-Loss Decimal Arithmetic [\#4664](https://github.com/apache/arrow-rs/issues/4664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Faster i256 Division [\#4663](https://github.com/apache/arrow-rs/issues/4663) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `concat_batches` for 0 columns [\#4661](https://github.com/apache/arrow-rs/issues/4661) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `filter_record_batch` should support filtering record batch without columns [\#4647](https://github.com/apache/arrow-rs/issues/4647) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve speed of `lexicographical_partition_ranges` 
[\#4614](https://github.com/apache/arrow-rs/issues/4614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- object\_store: multipart ranges for HTTP [\#4612](https://github.com/apache/arrow-rs/issues/4612) +- Add Rank Function [\#4606](https://github.com/apache/arrow-rs/issues/4606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Datum Based Comparison Kernels [\#4596](https://github.com/apache/arrow-rs/issues/4596) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Convenience method to create `DataType::List` correctly [\#4544](https://github.com/apache/arrow-rs/issues/4544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove Deprecated Arithmetic Kernels [\#4481](https://github.com/apache/arrow-rs/issues/4481) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Equality kernel where null==null gives true [\#4438](https://github.com/apache/arrow-rs/issues/4438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Parquet ArrowWriter Ignores Nulls in Dictionary Values [\#4690](https://github.com/apache/arrow-rs/issues/4690) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Schema Nullability Validation Fails to Account for Dictionary Nulls [\#4689](https://github.com/apache/arrow-rs/issues/4689) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Comparison Kernels Ignore Nulls in Dictionary Values [\#4688](https://github.com/apache/arrow-rs/issues/4688) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Casting List to String Ignores Format Options [\#4669](https://github.com/apache/arrow-rs/issues/4669) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Double free in C Stream Interface [\#4659](https://github.com/apache/arrow-rs/issues/4659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CI Failing On Packed SIMD [\#4651](https://github.com/apache/arrow-rs/issues/4651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `RowInterner::size()` much too low for high cardinality dictionary columns [\#4645](https://github.com/apache/arrow-rs/issues/4645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Decimal PrimitiveArray change datatype after try\_unary [\#4644](https://github.com/apache/arrow-rs/issues/4644) +- Better explanation in docs for Dictionary field encoding using RowConverter [\#4639](https://github.com/apache/arrow-rs/issues/4639) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `List(FixedSizeBinary)` array equality check may return wrong result [\#4637](https://github.com/apache/arrow-rs/issues/4637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow::compute::nullif` panics if `NullArray` is provided [\#4634](https://github.com/apache/arrow-rs/issues/4634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Empty lists in FixedSizeListArray::try\_new is not handled [\#4623](https://github.com/apache/arrow-rs/issues/4623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Bounds checking in `MutableBuffer::set_null_bits` can be bypassed [\#4620](https://github.com/apache/arrow-rs/issues/4620) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- TypedDictionaryArray Misleading Null Behaviour [\#4616](https://github.com/apache/arrow-rs/issues/4616) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- bug: Parquet writer missing row group metadata fields such as `compressed_size`, `file offset`. 
[\#4610](https://github.com/apache/arrow-rs/issues/4610) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `new_null_array` generates an invalid union array [\#4600](https://github.com/apache/arrow-rs/issues/4600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Footer parsing fails for very large parquet file. [\#4592](https://github.com/apache/arrow-rs/issues/4592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- bug\(parquet\): Disabling global statistics but enabling for particular column breaks reading [\#4587](https://github.com/apache/arrow-rs/issues/4587) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `arrow::compute::concat` panics for dense union arrays with non-trivial type IDs [\#4578](https://github.com/apache/arrow-rs/issues/4578) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- \[object\_store\] when Create a AmazonS3 instance work with MinIO without set endpoint got error MissingRegion [\#4617](https://github.com/apache/arrow-rs/issues/4617) + +**Merged pull requests:** + +- Add distinct kernels \(\#960\) \(\#4438\) [\#4716](https://github.com/apache/arrow-rs/pull/4716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update parquet object\_store 0.7 [\#4715](https://github.com/apache/arrow-rs/pull/4715) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support Field ID in ArrowWriter \(\#4702\) [\#4710](https://github.com/apache/arrow-rs/pull/4710) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove rank kernels [\#4703](https://github.com/apache/arrow-rs/pull/4703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support references in i256 arithmetic ops [\#4692](https://github.com/apache/arrow-rs/pull/4692) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cleanup DynComparator \(\#2654\) [\#4687](https://github.com/apache/arrow-rs/pull/4687) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Separate metadata fetch from `ArrowReaderBuilder` construction \(\#4674\) [\#4676](https://github.com/apache/arrow-rs/pull/4676) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- cleanup some assert\(\) with error propagation [\#4673](https://github.com/apache/arrow-rs/pull/4673) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Faster i256 Division \(2-100x\) \(\#4663\) [\#4672](https://github.com/apache/arrow-rs/pull/4672) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix MSRV CI [\#4671](https://github.com/apache/arrow-rs/pull/4671) ([tustvold](https://github.com/tustvold)) +- Fix equality of nested nullable FixedSizeBinary \(\#4637\) [\#4670](https://github.com/apache/arrow-rs/pull/4670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use ArrayFormatter in cast kernel [\#4668](https://github.com/apache/arrow-rs/pull/4668) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Improve API docs for FlightSQL metadata builders [\#4667](https://github.com/apache/arrow-rs/pull/4667) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Support `concat_batches` for 0 columns [\#4662](https://github.com/apache/arrow-rs/pull/4662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- fix ownership of c stream error 
[\#4660](https://github.com/apache/arrow-rs/pull/4660) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Minor: Fix illustration for dict encoding [\#4657](https://github.com/apache/arrow-rs/pull/4657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JayjeetAtGithub](https://github.com/JayjeetAtGithub)) +- minor: move comment to the correct location [\#4655](https://github.com/apache/arrow-rs/pull/4655) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Update packed\_simd and run miri tests on simd code [\#4654](https://github.com/apache/arrow-rs/pull/4654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- impl `From<Vec<T>>` for `BufferBuilder` and `MutableBuffer` [\#4650](https://github.com/apache/arrow-rs/pull/4650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Filter record batch with 0 columns [\#4648](https://github.com/apache/arrow-rs/pull/4648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Account for child `Bucket` size in OrderPreservingInterner [\#4646](https://github.com/apache/arrow-rs/pull/4646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Implement `Default`,`Extend` and `FromIterator` for `BufferBuilder` [\#4638](https://github.com/apache/arrow-rs/pull/4638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- fix\(select\): handle `NullArray` in `nullif` [\#4635](https://github.com/apache/arrow-rs/pull/4635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Move `BufferBuilder` to `arrow-buffer` [\#4630](https://github.com/apache/arrow-rs/pull/4630)
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- allow zero sized empty fixed [\#4626](https://github.com/apache/arrow-rs/pull/4626) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- fix: compute\_dictionary\_mapping use wrong offsetSize [\#4625](https://github.com/apache/arrow-rs/pull/4625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- impl `FromIterator` for `MutableBuffer` [\#4624](https://github.com/apache/arrow-rs/pull/4624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- expand docs for FixedSizeListArray [\#4622](https://github.com/apache/arrow-rs/pull/4622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- fix\(buffer\): panic on end index overflow in `MutableBuffer::set_null_bits` [\#4621](https://github.com/apache/arrow-rs/pull/4621) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- impl `Default` for `arrow_buffer::buffer::MutableBuffer` [\#4619](https://github.com/apache/arrow-rs/pull/4619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Minor: improve docs and add example for lexicographical\_partition\_ranges [\#4615](https://github.com/apache/arrow-rs/pull/4615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Cleanup sort [\#4613](https://github.com/apache/arrow-rs/pull/4613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add rank function \(\#4606\) [\#4609](https://github.com/apache/arrow-rs/pull/4609) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add more docs and examples for ListArray and OffsetsBuffer 
[\#4607](https://github.com/apache/arrow-rs/pull/4607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Simplify dictionary sort [\#4605](https://github.com/apache/arrow-rs/pull/4605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Consolidate sort benchmarks [\#4604](https://github.com/apache/arrow-rs/pull/4604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Don't Reorder Nulls in sort\_to\_indices \(\#4545\) [\#4603](https://github.com/apache/arrow-rs/pull/4603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix\(data\): create child arrays of correct length when building a sparse union null array [\#4601](https://github.com/apache/arrow-rs/pull/4601) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Use u32 metadata\_len when parsing footer of parquet. 
[\#4599](https://github.com/apache/arrow-rs/pull/4599) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Berrysoft](https://github.com/Berrysoft)) +- fix\(data\): map type ID to child index before indexing a union child array [\#4598](https://github.com/apache/arrow-rs/pull/4598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Remove deprecated arithmetic kernels \(\#4481\) [\#4594](https://github.com/apache/arrow-rs/pull/4594) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Test Disabled Page Statistics \(\#4587\) [\#4589](https://github.com/apache/arrow-rs/pull/4589) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup ArrayData::buffers [\#4583](https://github.com/apache/arrow-rs/pull/4583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use contains\_nulls in ArrayData equality of byte arrays [\#4582](https://github.com/apache/arrow-rs/pull/4582) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Vectorized lexicographical\_partition\_ranges \(~80% faster\) [\#4575](https://github.com/apache/arrow-rs/pull/4575) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- chore: add datatype new\_list [\#4561](https://github.com/apache/arrow-rs/pull/4561) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +## [45.0.0](https://github.com/apache/arrow-rs/tree/45.0.0) (2023-07-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/44.0.0...45.0.0) + +**Breaking changes:** + +- Fix timezoned timestamp arithmetic [\#4546](https://github.com/apache/arrow-rs/pull/4546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([alexandreyc](https://github.com/alexandreyc)) + +**Implemented enhancements:** + +- Use FormatOptions in Const Contexts [\#4580](https://github.com/apache/arrow-rs/issues/4580) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Human Readable Duration Display [\#4554](https://github.com/apache/arrow-rs/issues/4554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `BooleanBuilder`: Add `validity_slice` method for accessing validity bits [\#4535](https://github.com/apache/arrow-rs/issues/4535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `FixedSizedListArray` for `length` kernel [\#4517](https://github.com/apache/arrow-rs/issues/4517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `RowConverter::convert` that targets an existing `Rows` [\#4479](https://github.com/apache/arrow-rs/issues/4479) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Panic `assertion failed: idx < self.len` when casting DictionaryArrays with nulls [\#4576](https://github.com/apache/arrow-rs/issues/4576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-arith is\_null is buggy with NullArray [\#4565](https://github.com/apache/arrow-rs/issues/4565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect Interval to Duration Casting [\#4553](https://github.com/apache/arrow-rs/issues/4553) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Too large validity buffer pre-allocation in `FixedSizeListBuilder::new` [\#4549](https://github.com/apache/arrow-rs/issues/4549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Like with wildcards fail to match fields with new lines.
[\#4547](https://github.com/apache/arrow-rs/issues/4547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Timestamp Interval Arithmetic Ignores Timezone [\#4457](https://github.com/apache/arrow-rs/issues/4457) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- refactor: simplify hour\_dyn\(\) with time\_fraction\_dyn\(\) [\#4588](https://github.com/apache/arrow-rs/pull/4588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Move from\_iter\_values to GenericByteArray [\#4586](https://github.com/apache/arrow-rs/pull/4586) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Mark GenericByteArray::new\_unchecked unsafe [\#4584](https://github.com/apache/arrow-rs/pull/4584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Configurable Duration Display [\#4581](https://github.com/apache/arrow-rs/pull/4581) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix take\_bytes Null and Overflow Handling \(\#4576\) [\#4579](https://github.com/apache/arrow-rs/pull/4579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move chrono-tz arithmetic tests to integration [\#4571](https://github.com/apache/arrow-rs/pull/4571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Write Page Offset Index For All-Nan Pages [\#4567](https://github.com/apache/arrow-rs/pull/4567) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([MachaelLee](https://github.com/MachaelLee)) +- support NullArray in arith/boolean kernel [\#4566](https://github.com/apache/arrow-rs/pull/4566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- Remove Sync from arrow-flight example
[\#4564](https://github.com/apache/arrow-rs/pull/4564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Fix interval to duration casting \(\#4553\) [\#4562](https://github.com/apache/arrow-rs/pull/4562) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- docs: fix wrong parameter name [\#4559](https://github.com/apache/arrow-rs/pull/4559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +- Fix FixedSizeListBuilder capacity \(\#4549\) [\#4552](https://github.com/apache/arrow-rs/pull/4552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- docs: fix wrong inline code snippet in parquet document [\#4550](https://github.com/apache/arrow-rs/pull/4550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +- fix multiline wildcard likes \(fixes \#4547\) [\#4548](https://github.com/apache/arrow-rs/pull/4548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nl5887](https://github.com/nl5887)) +- Provide default `is_empty` impl for `arrow::array::ArrayBuilder` [\#4543](https://github.com/apache/arrow-rs/pull/4543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Add RowConverter::append \(\#4479\) [\#4541](https://github.com/apache/arrow-rs/pull/4541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clarify GenericColumnReader::read\_records [\#4540](https://github.com/apache/arrow-rs/pull/4540) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Initial loongarch port [\#4538](https://github.com/apache/arrow-rs/pull/4538) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xiangzhai](https://github.com/xiangzhai)) +- Update proc-macro2 requirement from =1.0.64 to =1.0.66 [\#4537](https://github.com/apache/arrow-rs/pull/4537) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- add a validity slice access for boolean array builders [\#4536](https://github.com/apache/arrow-rs/pull/4536) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ChristianBeilschmidt](https://github.com/ChristianBeilschmidt)) +- use new num version instead of explicit num-complex dependency [\#4532](https://github.com/apache/arrow-rs/pull/4532) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mwlon](https://github.com/mwlon)) +- feat: Support `FixedSizedListArray` for `length` kernel [\#4520](https://github.com/apache/arrow-rs/pull/4520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +## [44.0.0](https://github.com/apache/arrow-rs/tree/44.0.0) (2023-07-14) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/43.0.0...44.0.0) + +**Breaking changes:** + +- Use Parser for cast kernel \(\#4512\) [\#4513](https://github.com/apache/arrow-rs/pull/4513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Datum based arithmetic kernels \(\#3999\) [\#4465](https://github.com/apache/arrow-rs/pull/4465) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- eq\_dyn\_binary\_scalar should support FixedSizeBinary types [\#4491](https://github.com/apache/arrow-rs/issues/4491) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Port Tests from Deprecated Arithmetic Kernels [\#4480](https://github.com/apache/arrow-rs/issues/4480) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement RecordBatchReader for Boxed trait object [\#4474](https://github.com/apache/arrow-rs/issues/4474) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Date` - `Date` kernel [\#4383](https://github.com/apache/arrow-rs/issues/4383) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Default FlightSqlService Implementations [\#4372](https://github.com/apache/arrow-rs/issues/4372) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Parquet: `AsyncArrowWriter` to a file corrupts the footer for large columns [\#4526](https://github.com/apache/arrow-rs/issues/4526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[object\_store\] Failure to send bytes to azure [\#4522](https://github.com/apache/arrow-rs/issues/4522) +- Cannot cast string '2021-01-02' to value of Date64 type [\#4512](https://github.com/apache/arrow-rs/issues/4512) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect Interval Subtraction [\#4489](https://github.com/apache/arrow-rs/issues/4489) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Interval Negation Incorrect [\#4488](https://github.com/apache/arrow-rs/issues/4488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet: AsyncArrowWriter inner buffer is not correctly limited and causes OOM [\#4477](https://github.com/apache/arrow-rs/issues/4477) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Fix AsyncArrowWriter flush for large buffer sizes \(\#4526\) [\#4527](https://github.com/apache/arrow-rs/pull/4527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup cast\_primitive\_to\_list [\#4511](https://github.com/apache/arrow-rs/pull/4511) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Bump actions/upload-pages-artifact from 1 to 2 [\#4508](https://github.com/apache/arrow-rs/pull/4508) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support Date - Date \(\#4383\) [\#4504](https://github.com/apache/arrow-rs/pull/4504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Bump actions/labeler from 4.2.0 to 4.3.0 [\#4501](https://github.com/apache/arrow-rs/pull/4501) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.63 to =1.0.64 [\#4500](https://github.com/apache/arrow-rs/pull/4500) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add negate kernels \(\#4488\) [\#4494](https://github.com/apache/arrow-rs/pull/4494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Datum Arithmetic tests, Fix Interval Subtraction \(\#4480\) [\#4493](https://github.com/apache/arrow-rs/pull/4493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- support FixedSizeBinary types in eq\_dyn\_binary\_scalar/neq\_dyn\_binary\_scalar [\#4492](https://github.com/apache/arrow-rs/pull/4492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +- Add default implementations to the FlightSqlService trait [\#4485](https://github.com/apache/arrow-rs/pull/4485) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([rossjones](https://github.com/rossjones)) +- add num-complex requirement [\#4482](https://github.com/apache/arrow-rs/pull/4482) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([mwlon](https://github.com/mwlon)) +- fix incorrect buffer size limiting in parquet async writer [\#4478](https://github.com/apache/arrow-rs/pull/4478) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([richox](https://github.com/richox)) +- feat: support RecordBatchReader on boxed trait objects [\#4475](https://github.com/apache/arrow-rs/pull/4475) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Improve in-place primitive sorts by 13-67% [\#4473](https://github.com/apache/arrow-rs/pull/4473) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Add Scalar/Datum abstraction \(\#1047\) [\#4393](https://github.com/apache/arrow-rs/pull/4393) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [43.0.0](https://github.com/apache/arrow-rs/tree/43.0.0) (2023-06-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/42.0.0...43.0.0) + +**Breaking changes:** + +- Simplify ffi import/export [\#4447](https://github.com/apache/arrow-rs/pull/4447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Virgiel](https://github.com/Virgiel)) +- Return Result from Parquet Row APIs [\#4428](https://github.com/apache/arrow-rs/pull/4428) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Remove Binary Dictionary Arithmetic Support [\#4407](https://github.com/apache/arrow-rs/pull/4407) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Request: a way to copy a `Row` to `Rows` [\#4466](https://github.com/apache/arrow-rs/issues/4466) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Reuse schema when importing from FFI [\#4444](https://github.com/apache/arrow-rs/issues/4444) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] Allow 
implementations of `FlightSqlService` to handle custom actions and commands [\#4439](https://github.com/apache/arrow-rs/issues/4439) +- Support `NullBuilder` [\#4429](https://github.com/apache/arrow-rs/issues/4429) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Regression in parquet `42.0.0` : Bad parquet column indexes for All Null Columns, resulting in `Parquet error: StructArrayReader out of sync` on read [\#4459](https://github.com/apache/arrow-rs/issues/4459) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Regression in 42.0.0: Parsing fractional intervals without leading 0 is not supported [\#4424](https://github.com/apache/arrow-rs/issues/4424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- doc: deploy crate docs to GitHub pages [\#4436](https://github.com/apache/arrow-rs/pull/4436) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) + +**Merged pull requests:** + +- Append Row to Rows \(\#4466\) [\#4470](https://github.com/apache/arrow-rs/pull/4470) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat\(flight-sql\): Allow implementations of FlightSqlService to handle custom actions and commands [\#4463](https://github.com/apache/arrow-rs/pull/4463) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- Docs: Add clearer API doc links [\#4461](https://github.com/apache/arrow-rs/pull/4461) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Fix empty offset index for all null columns \(\#4459\) 
[\#4460](https://github.com/apache/arrow-rs/pull/4460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Bump peaceiris/actions-gh-pages from 3.9.2 to 3.9.3 [\#4455](https://github.com/apache/arrow-rs/pull/4455) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Convince the compiler to auto-vectorize the range check in parquet DictionaryBuffer [\#4453](https://github.com/apache/arrow-rs/pull/4453) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- fix docs deployment [\#4452](https://github.com/apache/arrow-rs/pull/4452) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) +- Update indexmap requirement from 1.9 to 2.0 [\#4451](https://github.com/apache/arrow-rs/pull/4451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.60 to =1.0.63 [\#4450](https://github.com/apache/arrow-rs/pull/4450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/deploy-pages from 1 to 2 [\#4449](https://github.com/apache/arrow-rs/pull/4449) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Revise error message in From\<Buffer\> for ScalarBuffer [\#4446](https://github.com/apache/arrow-rs/pull/4446) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- minor: remove useless mut [\#4443](https://github.com/apache/arrow-rs/pull/4443) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- unify substring for binary&utf8 
[\#4442](https://github.com/apache/arrow-rs/pull/4442) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Casting fixedsizelist to list/largelist [\#4433](https://github.com/apache/arrow-rs/pull/4433) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jayzhan211](https://github.com/jayzhan211)) +- feat: support `NullBuilder` [\#4430](https://github.com/apache/arrow-rs/pull/4430) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Remove Float64 -\> Float32 cast in IPC Reader [\#4427](https://github.com/apache/arrow-rs/pull/4427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) +- Parse intervals like `.5` the same as `0.5` [\#4425](https://github.com/apache/arrow-rs/pull/4425) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add strict mode to json reader [\#4421](https://github.com/apache/arrow-rs/pull/4421) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([blinkseb](https://github.com/blinkseb)) +- Add DictionaryArray::occupancy [\#4415](https://github.com/apache/arrow-rs/pull/4415) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [42.0.0](https://github.com/apache/arrow-rs/tree/42.0.0) (2023-06-16) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/41.0.0...42.0.0) + +**Breaking changes:** + +- Remove 64-bit to 32-bit Cast from IPC Reader [\#4412](https://github.com/apache/arrow-rs/pull/4412) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) +- Truncate Min/Max values in the Column Index [\#4389](https://github.com/apache/arrow-rs/pull/4389) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AdamGS](https://github.com/AdamGS)) +- feat\(flight\): harmonize server metadata APIs 
[\#4384](https://github.com/apache/arrow-rs/pull/4384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Move record delimiting into ColumnReader \(\#4365\) [\#4376](https://github.com/apache/arrow-rs/pull/4376) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Changed array\_to\_json\_array to take &dyn Array [\#4370](https://github.com/apache/arrow-rs/pull/4370) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dadepo](https://github.com/dadepo)) +- Make PrimitiveArray::with\_timezone consuming [\#4366](https://github.com/apache/arrow-rs/pull/4366) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Add doc example of constructing a MapArray [\#4385](https://github.com/apache/arrow-rs/issues/4385) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `millisecond` and `microsecond` functions [\#4374](https://github.com/apache/arrow-rs/issues/4374) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Changed array\_to\_json\_array to take &dyn Array [\#4369](https://github.com/apache/arrow-rs/issues/4369) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- compute::ord kernel for getting min and max of two scalar/array values [\#4347](https://github.com/apache/arrow-rs/issues/4347) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release 41.0.0 of arrow/arrow-flight/parquet/parquet-derive [\#4346](https://github.com/apache/arrow-rs/issues/4346) +- Refactor CAST tests to use new cast array syntax [\#4336](https://github.com/apache/arrow-rs/issues/4336) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- pass bytes directly to parquet's KeyValue 
[\#4317](https://github.com/apache/arrow-rs/issues/4317) +- PyArrow conversions could return TypeError if provided incorrect Python type [\#4312](https://github.com/apache/arrow-rs/issues/4312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Have array\_to\_json\_array support Map [\#4297](https://github.com/apache/arrow-rs/issues/4297) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FlightSQL: Add helpers to create `CommandGetXdbcTypeInfo` responses \(`XdbcInfoValue` and builders\) [\#4257](https://github.com/apache/arrow-rs/issues/4257) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Have array\_to\_json\_array support FixedSizeList [\#4248](https://github.com/apache/arrow-rs/issues/4248) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Truncate ColumnIndex ByteArray Statistics [\#4126](https://github.com/apache/arrow-rs/issues/4126) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Arrow compute kernel regards selection vector [\#4095](https://github.com/apache/arrow-rs/issues/4095) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Wrongly calculated data compressed length in IPC writer [\#4410](https://github.com/apache/arrow-rs/issues/4410) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Take Kernel Handles Nullable Indices Incorrectly [\#4404](https://github.com/apache/arrow-rs/issues/4404) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- StructBuilder::new Doesn't Validate Builder DataTypes [\#4397](https://github.com/apache/arrow-rs/issues/4397) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet error: Not all children array length are the same! 
when using RowSelection to read a parquet file [\#4396](https://github.com/apache/arrow-rs/issues/4396) +- RecordReader::skip\_records Is Incorrect for Repeated Columns [\#4368](https://github.com/apache/arrow-rs/issues/4368) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- List-of-String Array panics in the presence of row filters [\#4365](https://github.com/apache/arrow-rs/issues/4365) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Fail to read block compressed gzip files with parquet-fromcsv [\#4173](https://github.com/apache/arrow-rs/issues/4173) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Closed issues:** + +- Have a parquet file not able to be deduped via arrow-rs, complains about Decimal precision? [\#4356](https://github.com/apache/arrow-rs/issues/4356) +- Question: Could we move `dict_id, dict_is_ordered` into DataType? [\#4325](https://github.com/apache/arrow-rs/issues/4325) + +**Merged pull requests:** + +- Fix reading gzip file with multiple gzip headers in parquet-fromcsv. 
[\#4419](https://github.com/apache/arrow-rs/pull/4419) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ghuls](https://github.com/ghuls)) +- Cleanup nullif kernel [\#4416](https://github.com/apache/arrow-rs/pull/4416) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix bug in IPC logic that determines if the buffer should be compressed or not [\#4411](https://github.com/apache/arrow-rs/pull/4411) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lwpyr](https://github.com/lwpyr)) +- Faster unpacking of Int32Type dictionary [\#4406](https://github.com/apache/arrow-rs/pull/4406) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve `take` kernel performance on primitive arrays, fix bad null index handling \(\#4404\) [\#4405](https://github.com/apache/arrow-rs/pull/4405) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More take benchmarks [\#4403](https://github.com/apache/arrow-rs/pull/4403) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `BooleanBuffer::new_unset` and `BooleanBuffer::new_set` and `BooleanArray::new_null` constructors [\#4402](https://github.com/apache/arrow-rs/pull/4402) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add PrimitiveBuilder type constructors [\#4401](https://github.com/apache/arrow-rs/pull/4401) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- StructBuilder Validate Child Data \(\#4397\) [\#4400](https://github.com/apache/arrow-rs/pull/4400) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster UTF-8 truncation [\#4399](https://github.com/apache/arrow-rs/pull/4399) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Minor: Derive `Hash` impls for `CastOptions` and `FormatOptions` [\#4395](https://github.com/apache/arrow-rs/pull/4395) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix typo in README [\#4394](https://github.com/apache/arrow-rs/pull/4394) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([okue](https://github.com/okue)) +- Improve parquet `WriterProperites` and `ReaderProperties` docs [\#4392](https://github.com/apache/arrow-rs/pull/4392) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Cleanup downcast macros [\#4391](https://github.com/apache/arrow-rs/pull/4391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.59 to =1.0.60 [\#4388](https://github.com/apache/arrow-rs/pull/4388) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Consolidate ByteArray::from\_iterator [\#4386](https://github.com/apache/arrow-rs/pull/4386) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add MapArray constructors and doc example [\#4382](https://github.com/apache/arrow-rs/pull/4382) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Documentation Improvements [\#4381](https://github.com/apache/arrow-rs/pull/4381) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add NullBuffer and BooleanBuffer From conversions [\#4380](https://github.com/apache/arrow-rs/pull/4380) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add more examples of constructing Boolean, Primitive, String, and Decimal Arrays, and From impl for i256 [\#4379](https://github.com/apache/arrow-rs/pull/4379) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add ListArrayReader benchmarks [\#4378](https://github.com/apache/arrow-rs/pull/4378) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update comfy-table requirement from 6.0 to 7.0 [\#4377](https://github.com/apache/arrow-rs/pull/4377) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: Add`microsecond` and `millisecond` kernels [\#4375](https://github.com/apache/arrow-rs/pull/4375) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Update hashbrown requirement from 0.13 to 0.14 [\#4373](https://github.com/apache/arrow-rs/pull/4373) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- minor: use as\_boolean to resolve TODO [\#4367](https://github.com/apache/arrow-rs/pull/4367) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Have array\_to\_json\_array support MapArray [\#4364](https://github.com/apache/arrow-rs/pull/4364) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dadepo](https://github.com/dadepo)) +- deprecate: as\_decimal\_array [\#4363](https://github.com/apache/arrow-rs/pull/4363) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add support for FixedSizeList in array\_to\_json\_array [\#4361](https://github.com/apache/arrow-rs/pull/4361) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dadepo](https://github.com/dadepo)) +- refact: use as\_primitive in cast.rs test [\#4360](https://github.com/apache/arrow-rs/pull/4360) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- feat\(flight\): add xdbc type info helpers [\#4359](https://github.com/apache/arrow-rs/pull/4359) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Minor: float16 to json [\#4358](https://github.com/apache/arrow-rs/pull/4358) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Raise TypeError on PyArrow import [\#4316](https://github.com/apache/arrow-rs/pull/4316) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Arrow Cast: Fixed Point Arithmetic for Interval Parsing [\#4291](https://github.com/apache/arrow-rs/pull/4291) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mr-brobot](https://github.com/mr-brobot)) +## [41.0.0](https://github.com/apache/arrow-rs/tree/41.0.0) (2023-06-02) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/40.0.0...41.0.0) + +**Breaking changes:** + +- Rename list contains kernels to in\_list \(\#4289\) [\#4342](https://github.com/apache/arrow-rs/pull/4342) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move BooleanBufferBuilder and NullBufferBuilder to arrow\_buffer [\#4338](https://github.com/apache/arrow-rs/pull/4338) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add separate row\_count and level\_count to PageMetadata \(\#4321\) [\#4326](https://github.com/apache/arrow-rs/pull/4326) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Treat legacy TIMESTAMP\_X converted types as UTC [\#4309](https://github.com/apache/arrow-rs/pull/4309) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sergiimk](https://github.com/sergiimk)) +- Simplify parquet PageIterator [\#4306](https://github.com/apache/arrow-rs/pull/4306) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add Builder style APIs and docs for `FlightData`, `FlightInfo`, `FlightEndpoint`, `Location` and `Ticket` [\#4294](https://github.com/apache/arrow-rs/pull/4294) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Make GenericColumnWriter Send [\#4287](https://github.com/apache/arrow-rs/pull/4287) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- feat: update flight-sql to latest specs [\#4250](https://github.com/apache/arrow-rs/pull/4250) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- feat\(api!\): make ArrowArrayStreamReader Send [\#4232](https://github.com/apache/arrow-rs/pull/4232) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) + +**Implemented enhancements:** + +- Make SerializedRowGroupReader::new\(\) Public [\#4330](https://github.com/apache/arrow-rs/issues/4330) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Speed up i256 division and remainder operations [\#4302](https://github.com/apache/arrow-rs/issues/4302) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- export function parquet\_to\_array\_schema\_and\_fields [\#4298](https://github.com/apache/arrow-rs/issues/4298) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- FLightSQL: add helpers to create `CommandGetCatalogs`, `CommandGetSchemas`, and `CommandGetTables` requests [\#4295](https://github.com/apache/arrow-rs/issues/4295) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Make ColumnWriter Send [\#4286](https://github.com/apache/arrow-rs/issues/4286) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add Builder for `FlightInfo` to make it easier to create new requests [\#4281](https://github.com/apache/arrow-rs/issues/4281) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support Writing/Reading Decimal256 to/from Parquet [\#4264](https://github.com/apache/arrow-rs/issues/4264) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- FlightSQL: Add helpers to create `CommandGetSqlInfo` responses \(`SqlInfoValue` and builders\) [\#4256](https://github.com/apache/arrow-rs/issues/4256) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Update flight-sql implementation to latest specs [\#4249](https://github.com/apache/arrow-rs/issues/4249) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Make ArrowArrayStreamReader Send [\#4222](https://github.com/apache/arrow-rs/issues/4222) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support writing FixedSizeList to Parquet [\#4214](https://github.com/apache/arrow-rs/issues/4214) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Cast between `Intervals` [\#4181](https://github.com/apache/arrow-rs/issues/4181) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Splice Parquet Data 
[\#4155](https://github.com/apache/arrow-rs/issues/4155) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV Schema More Flexible Timestamp Inference [\#4131](https://github.com/apache/arrow-rs/issues/4131) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Doc for arrow\_flight::sql is missing enums that are Xdbc related [\#4339](https://github.com/apache/arrow-rs/issues/4339) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- concat\_batches panics with total\_len \<= bit\_len assertion for records with lists [\#4324](https://github.com/apache/arrow-rs/issues/4324) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect PageMetadata Row Count returned for V1 DataPage [\#4321](https://github.com/apache/arrow-rs/issues/4321) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[parquet\] Not following the spec for TIMESTAMP\_MILLIS legacy converted types [\#4308](https://github.com/apache/arrow-rs/issues/4308) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- ambiguous glob re-exports of contains\_utf8 [\#4289](https://github.com/apache/arrow-rs/issues/4289) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- flight\_sql\_client --header "key: value" yields a value with a leading whitespace [\#4270](https://github.com/apache/arrow-rs/issues/4270) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Casting Timestamp to date is off by one day for dates before 1970-01-01 [\#4211](https://github.com/apache/arrow-rs/issues/4211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Don't infer 16-byte decimal as decimal256 [\#4349](https://github.com/apache/arrow-rs/pull/4349) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix MutableArrayData::extend\_nulls \(\#1230\) [\#4343](https://github.com/apache/arrow-rs/pull/4343) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update FlightSQL metadata locations, names and docs [\#4341](https://github.com/apache/arrow-rs/pull/4341) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- chore: expose Xdbc related FlightSQL enums [\#4340](https://github.com/apache/arrow-rs/pull/4340) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([appletreeisyellow](https://github.com/appletreeisyellow)) +- Update pyo3 requirement from 0.18 to 0.19 [\#4335](https://github.com/apache/arrow-rs/pull/4335) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Skip unnecessary null checks in MutableArrayData [\#4333](https://github.com/apache/arrow-rs/pull/4333) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add read parquet by custom rowgroup examples [\#4332](https://github.com/apache/arrow-rs/pull/4332) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sundy-li](https://github.com/sundy-li)) +- Make SerializedRowGroupReader::new\(\) public [\#4331](https://github.com/apache/arrow-rs/pull/4331) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([burmecia](https://github.com/burmecia)) +- Don't split record across pages \(\#3680\) [\#4327](https://github.com/apache/arrow-rs/pull/4327) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- fix date conversion if timestamp below unixtimestamp 
[\#4323](https://github.com/apache/arrow-rs/pull/4323) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Short-circuit on exhausted page in skip\_records [\#4320](https://github.com/apache/arrow-rs/pull/4320) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Handle trailing padding when skipping repetition levels \(\#3911\) [\#4319](https://github.com/apache/arrow-rs/pull/4319) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use `page_size` consistently, deprecate `pagesize` in parquet WriterProperties [\#4313](https://github.com/apache/arrow-rs/pull/4313) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add roundtrip tests for Decimal256 and fix issues \(\#4264\) [\#4311](https://github.com/apache/arrow-rs/pull/4311) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Expose page-level arrow reader API \(\#4298\) [\#4307](https://github.com/apache/arrow-rs/pull/4307) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Speed up i256 division and remainder operations [\#4303](https://github.com/apache/arrow-rs/pull/4303) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat\(flight\): support int32\_to\_int32\_list\_map in sql infos [\#4300](https://github.com/apache/arrow-rs/pull/4300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- feat\(flight\): add helpers to handle `CommandGetCatalogs`, `CommandGetSchemas`, and `CommandGetTables` requests [\#4296](https://github.com/apache/arrow-rs/pull/4296) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Improve docs and tests for `SqlInfoList` [\#4293](https://github.com/apache/arrow-rs/pull/4293) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- minor: fix arrow\_row docs.rs links [\#4292](https://github.com/apache/arrow-rs/pull/4292) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([roeap](https://github.com/roeap)) +- Update proc-macro2 requirement from =1.0.58 to =1.0.59 [\#4290](https://github.com/apache/arrow-rs/pull/4290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Improve `ArrowWriter` memory usage: Buffer Pages in ArrowWriter instead of RecordBatch \(\#3871\) [\#4280](https://github.com/apache/arrow-rs/pull/4280) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Minor: Add more docstrings in arrow-flight [\#4279](https://github.com/apache/arrow-rs/pull/4279) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Add `Debug` impls for `ArrowWriter` and `SerializedFileWriter` [\#4278](https://github.com/apache/arrow-rs/pull/4278) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Expose `RecordBatchWriter` to `arrow` crate [\#4277](https://github.com/apache/arrow-rs/pull/4277) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Update criterion requirement from 0.4 to 0.5 [\#4275](https://github.com/apache/arrow-rs/pull/4275) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add parquet-concat [\#4274](https://github.com/apache/arrow-rs/pull/4274) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Convert FixedSizeListArray to GenericListArray [\#4273](https://github.com/apache/arrow-rs/pull/4273) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: support 'Decimal256' for parquet [\#4272](https://github.com/apache/arrow-rs/pull/4272) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- Strip leading whitespace from flight\_sql\_client custom header values [\#4271](https://github.com/apache/arrow-rs/pull/4271) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mkmik](https://github.com/mkmik)) +- Add Append Column API \(\#4155\) [\#4269](https://github.com/apache/arrow-rs/pull/4269) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Derive Default for WriterProperties [\#4268](https://github.com/apache/arrow-rs/pull/4268) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Parquet Reader/writer for fixed-size list arrays [\#4267](https://github.com/apache/arrow-rs/pull/4267) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dexterduck](https://github.com/dexterduck)) +- feat\(flight\): add sql-info helpers [\#4266](https://github.com/apache/arrow-rs/pull/4266) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Convert parquet metadata back to builders [\#4265](https://github.com/apache/arrow-rs/pull/4265) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add constructors for FixedSize array types \(\#3879\) [\#4263](https://github.com/apache/arrow-rs/pull/4263) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Extract IPC ArrayReader struct [\#4259](https://github.com/apache/arrow-rs/pull/4259) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update object\_store requirement from 0.5 to 0.6 [\#4258](https://github.com/apache/arrow-rs/pull/4258) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support Absolute Timestamps in CSV Schema Inference \(\#4131\) [\#4217](https://github.com/apache/arrow-rs/pull/4217) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: cast between `Intervals` [\#4182](https://github.com/apache/arrow-rs/pull/4182) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +## [40.0.0](https://github.com/apache/arrow-rs/tree/40.0.0) (2023-05-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/39.0.0...40.0.0) + +**Breaking changes:** + +- Prefetch page index \(\#4090\) [\#4216](https://github.com/apache/arrow-rs/pull/4216) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add RecordBatchWriter trait and implement it for CSV, JSON, IPC and P… [\#4206](https://github.com/apache/arrow-rs/pull/4206) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Remove powf\_scalar kernel [\#4187](https://github.com/apache/arrow-rs/pull/4187) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Allow format specification in cast [\#4169](https://github.com/apache/arrow-rs/pull/4169) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([parthchandra](https://github.com/parthchandra)) + +**Implemented enhancements:** + +- ObjectStore with\_url Should Handle Path [\#4199](https://github.com/apache/arrow-rs/issues/4199) +- Support `Interval` +/- `Interval` [\#4178](https://github.com/apache/arrow-rs/issues/4178) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[parquet\] add compression info to `print_column_chunk_metadata()` [\#4172](https://github.com/apache/arrow-rs/issues/4172) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Allow cast to take in a format specification [\#4168](https://github.com/apache/arrow-rs/issues/4168) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support extended pow arithmetic [\#4166](https://github.com/apache/arrow-rs/issues/4166) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Preload page index for async ParquetObjectReader [\#4090](https://github.com/apache/arrow-rs/issues/4090) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Subtracting `Timestamp` from `Timestamp` should produce a `Duration` \(not `Timestamp`\) [\#3964](https://github.com/apache/arrow-rs/issues/3964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Arrow Arithmetic: Subtract timestamps [\#4244](https://github.com/apache/arrow-rs/pull/4244) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mr-brobot](https://github.com/mr-brobot)) +- Update proc-macro2 requirement from =1.0.57 to =1.0.58 [\#4236](https://github.com/apache/arrow-rs/pull/4236) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix Nightly Clippy Lints 
[\#4233](https://github.com/apache/arrow-rs/pull/4233) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: use all primitive types in test\_layouts [\#4229](https://github.com/apache/arrow-rs/pull/4229) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add close method to RecordBatchWriter trait [\#4228](https://github.com/apache/arrow-rs/pull/4228) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Update proc-macro2 requirement from =1.0.56 to =1.0.57 [\#4219](https://github.com/apache/arrow-rs/pull/4219) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Feat docs [\#4215](https://github.com/apache/arrow-rs/pull/4215) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Folyd](https://github.com/Folyd)) +- feat: Support bitwise and boolean aggregate functions [\#4210](https://github.com/apache/arrow-rs/pull/4210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Document how to sort a RecordBatch [\#4204](https://github.com/apache/arrow-rs/pull/4204) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix incorrect cast Timestamp with Timezone [\#4201](https://github.com/apache/arrow-rs/pull/4201) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([aprimadi](https://github.com/aprimadi)) +- Add implementation of `RecordBatchReader` for CSV reader [\#4195](https://github.com/apache/arrow-rs/pull/4195) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Add 
Sliced ListArray test \(\#3748\) [\#4186](https://github.com/apache/arrow-rs/pull/4186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- refactor: simplify can\_cast\_types code. [\#4185](https://github.com/apache/arrow-rs/pull/4185) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Minor: support new types in struct\_builder.rs [\#4177](https://github.com/apache/arrow-rs/pull/4177) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- feat: add compression info to print\_column\_chunk\_metadata\(\) [\#4176](https://github.com/apache/arrow-rs/pull/4176) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +## [39.0.0](https://github.com/apache/arrow-rs/tree/39.0.0) (2023-05-05) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/38.0.0...39.0.0) + +**Breaking changes:** + +- Allow creating unbuffered streamreader [\#4165](https://github.com/apache/arrow-rs/pull/4165) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) +- Cleanup ChunkReader \(\#4118\) [\#4156](https://github.com/apache/arrow-rs/pull/4156) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove Type from NativeIndex [\#4146](https://github.com/apache/arrow-rs/pull/4146) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Don't Duplicate Offset Index on RowGroupMetadata [\#4142](https://github.com/apache/arrow-rs/pull/4142) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Return BooleanBuffer from BooleanBufferBuilder [\#4140](https://github.com/apache/arrow-rs/pull/4140) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup CSV schema inference \(\#4129\) \(\#4130\) [\#4133](https://github.com/apache/arrow-rs/pull/4133) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove deprecated parquet ArrowReader [\#4125](https://github.com/apache/arrow-rs/pull/4125) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- refactor: construct `StructArray` w/ `FieldRef` [\#4116](https://github.com/apache/arrow-rs/pull/4116) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Ignore Field Metadata in equals\_datatype for Dictionary, RunEndEncoded, Map and Union [\#4111](https://github.com/apache/arrow-rs/pull/4111) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add StructArray Constructors \(\#3879\) [\#4064](https://github.com/apache/arrow-rs/pull/4064) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Release 39.0.0 of arrow/arrow-flight/parquet/parquet-derive \(next release after 38.0.0\) [\#4170](https://github.com/apache/arrow-rs/issues/4170) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Fixed point decimal multiplication for DictionaryArray [\#4135](https://github.com/apache/arrow-rs/issues/4135) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove Seek Requirement from CSV ReaderBuilder [\#4130](https://github.com/apache/arrow-rs/issues/4130) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Inconsistent CSV Inference and Parsing DateTime Handling [\#4129](https://github.com/apache/arrow-rs/issues/4129) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support accessing ipc Reader/Writer inner by reference [\#4121](https://github.com/apache/arrow-rs/issues/4121) +- Add Type Declarations for All Primitive Tensors and Buffer Builders [\#4112](https://github.com/apache/arrow-rs/issues/4112) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Interval + Timestamp` and `Interval + Date` in addition to `Timestamp + Interval` and `Date + Interval` [\#4094](https://github.com/apache/arrow-rs/issues/4094) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Enable setting FlightDescriptor on FlightDataEncoderBuilder [\#3855](https://github.com/apache/arrow-rs/issues/3855) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Parquet Page Index Reader Assumes Consecutive Offsets [\#4149](https://github.com/apache/arrow-rs/issues/4149) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Equality of nested data types [\#4110](https://github.com/apache/arrow-rs/issues/4110) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Improve Documentation of Parquet ChunkReader [\#4118](https://github.com/apache/arrow-rs/issues/4118) + +**Closed issues:** + +- add specific error log for empty JSON array [\#4105](https://github.com/apache/arrow-rs/issues/4105) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Prep for 39.0.0 [\#4171](https://github.com/apache/arrow-rs/pull/4171) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] 
([iajoiner](https://github.com/iajoiner)) +- Support Compression in parquet-fromcsv [\#4160](https://github.com/apache/arrow-rs/pull/4160) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([suxiaogang223](https://github.com/suxiaogang223)) +- feat: support bitwise shift left/right with scalars [\#4159](https://github.com/apache/arrow-rs/pull/4159) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Cleanup reading page index \(\#4149\) \(\#4090\) [\#4151](https://github.com/apache/arrow-rs/pull/4151) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- feat: support `bitwise` shift left/right [\#4148](https://github.com/apache/arrow-rs/pull/4148) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Don't hardcode port in FlightSQL tests [\#4145](https://github.com/apache/arrow-rs/pull/4145) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Better flight SQL example codes [\#4144](https://github.com/apache/arrow-rs/pull/4144) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sundy-li](https://github.com/sundy-li)) +- chore: clean the code by using `as_primitive` [\#4143](https://github.com/apache/arrow-rs/pull/4143) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- docs: fix the wrong ln command in CONTRIBUTING.md [\#4139](https://github.com/apache/arrow-rs/pull/4139) ([SteveLauC](https://github.com/SteveLauC)) +- Infer Float64 for JSON Numerics Beyond Bounds of i64 [\#4138](https://github.com/apache/arrow-rs/pull/4138) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([SteveLauC](https://github.com/SteveLauC)) +- Support fixed 
point multiplication for DictionaryArray of Decimals [\#4136](https://github.com/apache/arrow-rs/pull/4136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make arrow\_json::ReaderBuilder method names consistent [\#4128](https://github.com/apache/arrow-rs/pull/4128) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add get\_{ref, mut} to arrow\_ipc Reader and Writer [\#4122](https://github.com/apache/arrow-rs/pull/4122) ([sticnarf](https://github.com/sticnarf)) +- feat: support `Interval` + `Timestamp` and `Interval` + `Date` [\#4117](https://github.com/apache/arrow-rs/pull/4117) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Support NullArray in JSON Reader [\#4114](https://github.com/apache/arrow-rs/pull/4114) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jiangzhx](https://github.com/jiangzhx)) +- Add Type Declarations for All Primitive Tensors and Buffer Builders [\#4113](https://github.com/apache/arrow-rs/pull/4113) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Update regex-syntax requirement from 0.6.27 to 0.7.1 [\#4107](https://github.com/apache/arrow-rs/pull/4107) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: set FlightDescriptor on FlightDataEncoderBuilder [\#4101](https://github.com/apache/arrow-rs/pull/4101) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Weijun-H](https://github.com/Weijun-H)) +- optimize cast for same decimal type and same scale [\#4088](https://github.com/apache/arrow-rs/pull/4088) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) + +## 
[38.0.0](https://github.com/apache/arrow-rs/tree/38.0.0) (2023-04-21) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/37.0.0...38.0.0) + +**Breaking changes:** + +- Remove DataType from PrimitiveArray constructors [\#4098](https://github.com/apache/arrow-rs/pull/4098) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use Into\<Arc\<str\>\> for PrimitiveArray::with\_timezone [\#4097](https://github.com/apache/arrow-rs/pull/4097) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Store StructArray entries in MapArray [\#4085](https://github.com/apache/arrow-rs/pull/4085) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add DictionaryArray Constructors \(\#3879\) [\#4068](https://github.com/apache/arrow-rs/pull/4068) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Relax JSON schema inference generics [\#4063](https://github.com/apache/arrow-rs/pull/4063) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove ArrayData from Array \(\#3880\) [\#4061](https://github.com/apache/arrow-rs/pull/4061) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add CommandGetXdbcTypeInfo to Flight SQL Server [\#4055](https://github.com/apache/arrow-rs/pull/4055) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([c-thiel](https://github.com/c-thiel)) +- Remove old JSON Reader and Decoder \(\#3610\) [\#4052](https://github.com/apache/arrow-rs/pull/4052) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use BufRead for JSON Schema Inference [\#4041](https://github.com/apache/arrow-rs/pull/4041) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([WenyXu](https://github.com/WenyXu)) + +**Implemented enhancements:** + +- Support dyn\_compare\_scalar for Decimal256 [\#4083](https://github.com/apache/arrow-rs/issues/4083) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Better JSON Reader Error Messages [\#4076](https://github.com/apache/arrow-rs/issues/4076) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Additional data type groups [\#4056](https://github.com/apache/arrow-rs/issues/4056) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Async JSON reader [\#4043](https://github.com/apache/arrow-rs/issues/4043) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Field::contains Should Recurse into DataType [\#4029](https://github.com/apache/arrow-rs/issues/4029) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Prevent UnionArray with Repeated Type IDs [\#3982](https://github.com/apache/arrow-rs/issues/3982) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Timestamp` `+`/`-` `Interval` types [\#3963](https://github.com/apache/arrow-rs/issues/3963) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- First-Class Array Abstractions [\#3880](https://github.com/apache/arrow-rs/issues/3880) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Update readme to remove reference to Jira [\#4091](https://github.com/apache/arrow-rs/issues/4091) +- OffsetBuffer::new Rejects 0 Offsets [\#4066](https://github.com/apache/arrow-rs/issues/4066) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet AsyncArrowWriter not shutting down inner async writer. [\#4058](https://github.com/apache/arrow-rs/issues/4058) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Flight SQL Server missing command type.googleapis.com/arrow.flight.protocol.sql.CommandGetXdbcTypeInfo [\#4054](https://github.com/apache/arrow-rs/issues/4054) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- RawJsonReader Errors with Empty Schema [\#4053](https://github.com/apache/arrow-rs/issues/4053) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RawJsonReader Integer Truncation [\#4049](https://github.com/apache/arrow-rs/issues/4049) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Sparse UnionArray Equality Incorrect Offset Handling [\#4044](https://github.com/apache/arrow-rs/issues/4044) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Write blog about improvements in JSON and CSV processing [\#4062](https://github.com/apache/arrow-rs/issues/4062) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Parquet reader of Int96 columns and coercion to timestamps [\#4075](https://github.com/apache/arrow-rs/issues/4075) +- Serializing timestamp from int \(json raw decoder\) [\#4069](https://github.com/apache/arrow-rs/issues/4069) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting to/from Interval and Duration [\#3998](https://github.com/apache/arrow-rs/issues/3998) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix Docs Typos [\#4100](https://github.com/apache/arrow-rs/pull/4100) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rnarkk](https://github.com/rnarkk)) +- Update tonic-build 
requirement from =0.9.1 to =0.9.2 [\#4099](https://github.com/apache/arrow-rs/pull/4099) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Increase minimum chrono version to 0.4.24 [\#4093](https://github.com/apache/arrow-rs/pull/4093) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Simplify reference to GitHub issues [\#4092](https://github.com/apache/arrow-rs/pull/4092) ([bkmgit](https://github.com/bkmgit)) +- \[Minor\]: Add `Hash` trait to SortOptions. [\#4089](https://github.com/apache/arrow-rs/pull/4089) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mustafasrepo](https://github.com/mustafasrepo)) +- Include byte offsets in parquet-layout [\#4086](https://github.com/apache/arrow-rs/pull/4086) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- feat: Support dyn\_compare\_scalar for Decimal256 [\#4084](https://github.com/apache/arrow-rs/pull/4084) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add ByteArray constructors \(\#3879\) [\#4081](https://github.com/apache/arrow-rs/pull/4081) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update prost-build requirement from =0.11.8 to =0.11.9 [\#4080](https://github.com/apache/arrow-rs/pull/4080) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Improve JSON decoder errors \(\#4076\) [\#4079](https://github.com/apache/arrow-rs/pull/4079) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix Timestamp Numeric Truncation in JSON Reader 
[\#4074](https://github.com/apache/arrow-rs/pull/4074) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Serialize numeric to tape \(\#4069\) [\#4073](https://github.com/apache/arrow-rs/pull/4073) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: Prevent UnionArray with Repeated Type IDs [\#4070](https://github.com/apache/arrow-rs/pull/4070) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Add PrimitiveArray::try\_new \(\#3879\) [\#4067](https://github.com/apache/arrow-rs/pull/4067) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add ListArray Constructors \(\#3879\) [\#4065](https://github.com/apache/arrow-rs/pull/4065) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Shutdown parquet async writer [\#4059](https://github.com/apache/arrow-rs/pull/4059) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kindly](https://github.com/kindly)) +- feat: additional data type groups [\#4057](https://github.com/apache/arrow-rs/pull/4057) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Fix precision loss in Raw JSON decoder \(\#4049\) [\#4051](https://github.com/apache/arrow-rs/pull/4051) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use lexical\_core in CSV and JSON parser \(~25% faster\) [\#4050](https://github.com/apache/arrow-rs/pull/4050) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add offsets accessors to variable length arrays \(\#3879\) [\#4048](https://github.com/apache/arrow-rs/pull/4048) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document Async 
decoder usage \(\#4043\) \(\#78\) [\#4046](https://github.com/apache/arrow-rs/pull/4046) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix sparse union array equality \(\#4044\) [\#4045](https://github.com/apache/arrow-rs/pull/4045) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: DataType::contains support nested type [\#4042](https://github.com/apache/arrow-rs/pull/4042) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- feat: Support Timestamp +/- Interval types [\#4038](https://github.com/apache/arrow-rs/pull/4038) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Fix object\_store CI [\#4037](https://github.com/apache/arrow-rs/pull/4037) ([tustvold](https://github.com/tustvold)) +- feat: cast from/to interval and duration [\#4020](https://github.com/apache/arrow-rs/pull/4020) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) + +## [37.0.0](https://github.com/apache/arrow-rs/tree/37.0.0) (2023-04-07) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/36.0.0...37.0.0) + +**Breaking changes:** + +- Fix timestamp handling in cast kernel \(\#1936\) \(\#4033\) [\#4034](https://github.com/apache/arrow-rs/pull/4034) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update tonic 0.9.1 [\#4011](https://github.com/apache/arrow-rs/pull/4011) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Use FieldRef in DataType \(\#3955\) [\#3983](https://github.com/apache/arrow-rs/pull/3983) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Store Timezone as Arc\<str\> [\#3976](https://github.com/apache/arrow-rs/pull/3976) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Panic instead of discarding nulls converting StructArray to RecordBatch - \(\#3951\) [\#3953](https://github.com/apache/arrow-rs/pull/3953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix\(flight\_sql\): PreparedStatement has no token for auth. [\#3948](https://github.com/apache/arrow-rs/pull/3948) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([youngsofun](https://github.com/youngsofun)) +- Add Strongly Typed Array Slice \(\#3929\) [\#3930](https://github.com/apache/arrow-rs/pull/3930) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Zero-Copy Conversion between Vec and MutableBuffer [\#3920](https://github.com/apache/arrow-rs/pull/3920) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Support Decimals cast to Utf8/LargeUtf [\#3991](https://github.com/apache/arrow-rs/issues/3991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Date32/Date64 minus Interval [\#3962](https://github.com/apache/arrow-rs/issues/3962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Reduce Cloning of Field [\#3955](https://github.com/apache/arrow-rs/issues/3955) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- 
Support Deserializing Serde DataTypes to Arrow [\#3949](https://github.com/apache/arrow-rs/issues/3949) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add multiply\_fixed\_point [\#3946](https://github.com/apache/arrow-rs/issues/3946) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Strongly Typed Array Slicing [\#3929](https://github.com/apache/arrow-rs/issues/3929) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make it easier to match FlightSQL messages [\#3874](https://github.com/apache/arrow-rs/issues/3874) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support Casting Between Binary / LargeBinary and FixedSizeBinary [\#3826](https://github.com/apache/arrow-rs/issues/3826) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Incorrect Overflow Casting String to Timestamp [\#4033](https://github.com/apache/arrow-rs/issues/4033) +- f16::ZERO and f16::ONE are mixed up [\#4016](https://github.com/apache/arrow-rs/issues/4016) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Handle overflow precision when casting from integer to decimal [\#3995](https://github.com/apache/arrow-rs/issues/3995) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PrimitiveDictionaryBuilder.finish should use actual value type [\#3971](https://github.com/apache/arrow-rs/issues/3971) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RecordBatch From StructArray Silently Discards Nulls [\#3952](https://github.com/apache/arrow-rs/issues/3952) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- I256 Checked Subtraction Overflows for i256::MINUS\_ONE [\#3942](https://github.com/apache/arrow-rs/issues/3942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- 
I256 Checked Multiply Overflows for i256::MIN [\#3941](https://github.com/apache/arrow-rs/issues/3941) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Remove non-existent `js` feature from README [\#4000](https://github.com/apache/arrow-rs/issues/4000) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support take on MapArray [\#3875](https://github.com/apache/arrow-rs/issues/3875) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Prep for 37.0.0 [\#4031](https://github.com/apache/arrow-rs/pull/4031) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Add RecordBatch::with\_schema [\#4028](https://github.com/apache/arrow-rs/pull/4028) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Only require compatible batch schema in ArrowWriter [\#4027](https://github.com/apache/arrow-rs/pull/4027) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add Fields::contains [\#4026](https://github.com/apache/arrow-rs/pull/4026) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: add methods "is\_positive" and "signum" to i256 [\#4024](https://github.com/apache/arrow-rs/pull/4024) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Deprecate Array::data \(\#3880\) [\#4019](https://github.com/apache/arrow-rs/pull/4019) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add tests for ArrowNativeTypeOp [\#4018](https://github.com/apache/arrow-rs/pull/4018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- fix: f16::ZERO and f16::ONE are 
mixed up [\#4017](https://github.com/apache/arrow-rs/pull/4017) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Minor: Float16Tensor [\#4013](https://github.com/apache/arrow-rs/pull/4013) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add FlightSQL module docs and links to `arrow-flight` crates [\#4012](https://github.com/apache/arrow-rs/pull/4012) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Update proc-macro2 requirement from =1.0.54 to =1.0.56 [\#4008](https://github.com/apache/arrow-rs/pull/4008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Cleanup Primitive take [\#4006](https://github.com/apache/arrow-rs/pull/4006) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate combine\_option\_bitmap [\#4005](https://github.com/apache/arrow-rs/pull/4005) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: add tests for BooleanBuffer [\#4004](https://github.com/apache/arrow-rs/pull/4004) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- feat: support to read/write customized metadata in ipc files [\#4003](https://github.com/apache/arrow-rs/pull/4003) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([framlog](https://github.com/framlog)) +- Cleanup more uses of Array::data \(\#3880\) [\#4002](https://github.com/apache/arrow-rs/pull/4002) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove js 
feature from README [\#4001](https://github.com/apache/arrow-rs/pull/4001) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([akazukin5151](https://github.com/akazukin5151)) +- feat: add the implementation BitXor to BooleanBuffer [\#3997](https://github.com/apache/arrow-rs/pull/3997) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Handle precision overflow when casting from integer to decimal [\#3996](https://github.com/apache/arrow-rs/pull/3996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support CAST from Decimal datatype to String [\#3994](https://github.com/apache/arrow-rs/pull/3994) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Add Field Constructors for Complex Fields [\#3992](https://github.com/apache/arrow-rs/pull/3992) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- fix: remove unused type parameters. 
[\#3986](https://github.com/apache/arrow-rs/pull/3986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([youngsofun](https://github.com/youngsofun)) +- Add UnionFields \(\#3955\) [\#3981](https://github.com/apache/arrow-rs/pull/3981) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup Fields Serde [\#3980](https://github.com/apache/arrow-rs/pull/3980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support Rust structures --\> `RecordBatch` by adding `Serde` support to `RawDecoder` \(\#3949\) [\#3979](https://github.com/apache/arrow-rs/pull/3979) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Convert string\_to\_timestamp\_nanos to doctest [\#3978](https://github.com/apache/arrow-rs/pull/3978) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix documentation of string\_to\_timestamp\_nanos [\#3977](https://github.com/apache/arrow-rs/pull/3977) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([byteink](https://github.com/byteink)) +- add Date32/Date64 support to subtract\_dyn [\#3974](https://github.com/apache/arrow-rs/pull/3974) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([SinanGncgl](https://github.com/SinanGncgl)) +- PrimitiveDictionaryBuilder.finish should use actual value type [\#3972](https://github.com/apache/arrow-rs/pull/3972) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update proc-macro2 requirement from =1.0.53 to =1.0.54 [\#3968](https://github.com/apache/arrow-rs/pull/3968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Async writer 
tweaks [\#3967](https://github.com/apache/arrow-rs/pull/3967) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix reading ipc files with unordered projections [\#3966](https://github.com/apache/arrow-rs/pull/3966) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([framlog](https://github.com/framlog)) +- Add Fields abstraction \(\#3955\) [\#3965](https://github.com/apache/arrow-rs/pull/3965) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: cast between `Binary`/`LargeBinary` and `FixedSizeBinary` [\#3961](https://github.com/apache/arrow-rs/pull/3961) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- feat: support async writer \(\#1269\) [\#3957](https://github.com/apache/arrow-rs/pull/3957) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ShiKaiWi](https://github.com/ShiKaiWi)) +- Add ListBuilder::append\_value \(\#3949\) [\#3954](https://github.com/apache/arrow-rs/pull/3954) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve array builder documentation \(\#3949\) [\#3951](https://github.com/apache/arrow-rs/pull/3951) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster i256 parsing [\#3950](https://github.com/apache/arrow-rs/pull/3950) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add multiply\_fixed\_point [\#3945](https://github.com/apache/arrow-rs/pull/3945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat: enable metadata import/export through C data interface 
[\#3944](https://github.com/apache/arrow-rs/pull/3944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Fix checked i256 arithmetic \(\#3942\) \(\#3941\) [\#3943](https://github.com/apache/arrow-rs/pull/3943) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Avoid memory copies in take\_list [\#3940](https://github.com/apache/arrow-rs/pull/3940) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster decimal parsing \(30-60%\) [\#3939](https://github.com/apache/arrow-rs/pull/3939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Fix: FlightSqlClient panic when execute\_update. [\#3938](https://github.com/apache/arrow-rs/pull/3938) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([youngsofun](https://github.com/youngsofun)) +- Cleanup row count handling in JSON writer [\#3934](https://github.com/apache/arrow-rs/pull/3934) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add typed buffers to UnionArray \(\#3880\) [\#3933](https://github.com/apache/arrow-rs/pull/3933) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add take for MapArray [\#3925](https://github.com/apache/arrow-rs/pull/3925) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Deprecate Array::data\_ref \(\#3880\) [\#3923](https://github.com/apache/arrow-rs/pull/3923) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Zero-copy conversion from Vec to PrimitiveArray [\#3917](https://github.com/apache/arrow-rs/pull/3917) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- feat: Add Commands enum to decode prost messages to strong type [\#3887](https://github.com/apache/arrow-rs/pull/3887) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([stuartcarnie](https://github.com/stuartcarnie)) +## [36.0.0](https://github.com/apache/arrow-rs/tree/36.0.0) (2023-03-24) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/35.0.0...36.0.0) + +**Breaking changes:** + +- Use dyn Array in sort kernels [\#3931](https://github.com/apache/arrow-rs/pull/3931) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Enforce struct nullability in JSON raw reader \(\#3900\) \(\#3904\) [\#3906](https://github.com/apache/arrow-rs/pull/3906) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Return ScalarBuffer from PrimitiveArray::values \(\#3879\) [\#3896](https://github.com/apache/arrow-rs/pull/3896) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use BooleanBuffer in BooleanArray \(\#3879\) [\#3895](https://github.com/apache/arrow-rs/pull/3895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Seal ArrowPrimitiveType [\#3882](https://github.com/apache/arrow-rs/pull/3882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support compression levels [\#3847](https://github.com/apache/arrow-rs/pull/3847) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([spebern](https://github.com/spebern)) + +**Implemented enhancements:** + +- Improve speed of parsing string to Times [\#3919](https://github.com/apache/arrow-rs/issues/3919) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- feat: add comparison/sort support for Float16 
[\#3914](https://github.com/apache/arrow-rs/issues/3914) +- Pinned version in arrow-flight's build-dependencies are causing conflicts [\#3876](https://github.com/apache/arrow-rs/issues/3876) +- Add compression options \(levels\) [\#3844](https://github.com/apache/arrow-rs/issues/3844) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use Unsigned Integer for Fixed Size DataType [\#3815](https://github.com/apache/arrow-rs/issues/3815) +- Common trait for RecordBatch and StructArray [\#3764](https://github.com/apache/arrow-rs/issues/3764) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow precision loss on multiplying decimal arrays [\#3689](https://github.com/apache/arrow-rs/issues/3689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Raw JSON Reader Allows Non-Nullable Struct Children to Contain Nulls [\#3904](https://github.com/apache/arrow-rs/issues/3904) +- Nullable field with nested not nullable map in json [\#3900](https://github.com/apache/arrow-rs/issues/3900) +- parquet\_derive doesn't support Vec\<u8\> [\#3864](https://github.com/apache/arrow-rs/issues/3864) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[REGRESSION\] Parsing timestamps with lower case time separator [\#3863](https://github.com/apache/arrow-rs/issues/3863) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[REGRESSION\] Parsing timestamps with leap seconds [\#3861](https://github.com/apache/arrow-rs/issues/3861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[REGRESSION\] Parsing timestamps with fractional seconds / microseconds / milliseconds / nanoseconds [\#3859](https://github.com/apache/arrow-rs/issues/3859) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CSV Reader Doesn't set Timezone [\#3841](https://github.com/apache/arrow-rs/issues/3841) +- PyArrowConvert Leaks Memory 
[\#3683](https://github.com/apache/arrow-rs/issues/3683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Derive RunArray Clone [\#3932](https://github.com/apache/arrow-rs/pull/3932) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move protoc generation to binary crate, unpin prost/tonic build \(\#3876\) [\#3927](https://github.com/apache/arrow-rs/pull/3927) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Fix JSON Temporal Encoding of Multiple Batches [\#3924](https://github.com/apache/arrow-rs/pull/3924) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([doki23](https://github.com/doki23)) +- Cleanup uses of Array::data\_ref \(\#3880\) [\#3918](https://github.com/apache/arrow-rs/pull/3918) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support microsecond and nanosecond in interval parsing [\#3916](https://github.com/apache/arrow-rs/pull/3916) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add comparison/sort support for Float16 [\#3915](https://github.com/apache/arrow-rs/pull/3915) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add AsArray trait for more ergonomic downcasting [\#3912](https://github.com/apache/arrow-rs/pull/3912) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add OffsetBuffer::new [\#3910](https://github.com/apache/arrow-rs/pull/3910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add PrimitiveArray::new 
\(\#3879\) [\#3909](https://github.com/apache/arrow-rs/pull/3909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support timezones in CSV reader \(\#3841\) [\#3908](https://github.com/apache/arrow-rs/pull/3908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve ScalarBuffer debug output [\#3907](https://github.com/apache/arrow-rs/pull/3907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.52 to =1.0.53 [\#3905](https://github.com/apache/arrow-rs/pull/3905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Re-export parquet compression level structs [\#3903](https://github.com/apache/arrow-rs/pull/3903) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix parsing timestamps of exactly 32 characters [\#3902](https://github.com/apache/arrow-rs/pull/3902) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add iterators to BooleanBuffer and NullBuffer [\#3901](https://github.com/apache/arrow-rs/pull/3901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Array equality for &dyn Array \(\#3880\) [\#3899](https://github.com/apache/arrow-rs/pull/3899) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BooleanArray::new \(\#3879\) [\#3898](https://github.com/apache/arrow-rs/pull/3898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Revert structured ArrayData \(\#3877\) [\#3894](https://github.com/apache/arrow-rs/pull/3894) 
([tustvold](https://github.com/tustvold)) +- Fix pyarrow memory leak \(\#3683\) [\#3893](https://github.com/apache/arrow-rs/pull/3893) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: add examples for `ListBuilder` and `GenericListBuilder` [\#3891](https://github.com/apache/arrow-rs/pull/3891) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Update syn requirement from 1.0 to 2.0 [\#3890](https://github.com/apache/arrow-rs/pull/3890) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Use of `mul_checked` to avoid silent overflow in interval arithmetic [\#3886](https://github.com/apache/arrow-rs/pull/3886) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Flesh out NullBuffer abstraction \(\#3880\) [\#3885](https://github.com/apache/arrow-rs/pull/3885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Bit Operations for i256 [\#3884](https://github.com/apache/arrow-rs/pull/3884) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Flatten arrow\_buffer [\#3883](https://github.com/apache/arrow-rs/pull/3883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Array::to\_data and Array::nulls \(\#3880\) [\#3881](https://github.com/apache/arrow-rs/pull/3881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Added support for byte vectors and slices to parquet\_derive \(\#3864\) [\#3878](https://github.com/apache/arrow-rs/pull/3878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([waymost](https://github.com/waymost)) +- chore: remove LevelDecoder [\#3872](https://github.com/apache/arrow-rs/pull/3872) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- Parse timestamps with leap seconds \(\#3861\) [\#3862](https://github.com/apache/arrow-rs/pull/3862) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster time parsing \(~93% faster\) [\#3860](https://github.com/apache/arrow-rs/pull/3860) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Parse timestamps with arbitrary seconds fraction [\#3858](https://github.com/apache/arrow-rs/pull/3858) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BitIterator [\#3856](https://github.com/apache/arrow-rs/pull/3856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve decimal parsing performance [\#3854](https://github.com/apache/arrow-rs/pull/3854) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Update proc-macro2 requirement from =1.0.51 to =1.0.52 [\#3853](https://github.com/apache/arrow-rs/pull/3853) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update bitflags requirement from 1.2.1 to 2.0.0 [\#3852](https://github.com/apache/arrow-rs/pull/3852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add offset pushdown to parquet [\#3848](https://github.com/apache/arrow-rs/pull/3848) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add timezone support to JSON reader [\#3845](https://github.com/apache/arrow-rs/pull/3845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- 
Allow precision loss on multiplying decimal arrays [\#3690](https://github.com/apache/arrow-rs/pull/3690) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +## [35.0.0](https://github.com/apache/arrow-rs/tree/35.0.0) (2023-03-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/34.0.0...35.0.0) + +**Breaking changes:** + +- Add RunEndBuffer \(\#1799\) [\#3817](https://github.com/apache/arrow-rs/pull/3817) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Restrict DictionaryArray to ArrowDictionaryKeyType [\#3813](https://github.com/apache/arrow-rs/pull/3813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- refactor: assorted `FlightSqlServiceClient` improvements [\#3788](https://github.com/apache/arrow-rs/pull/3788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- minor: make Parquet CLI input args consistent [\#3786](https://github.com/apache/arrow-rs/pull/3786) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XinyuZeng](https://github.com/XinyuZeng)) +- Return Buffers from ArrayData::buffers instead of slice \(\#1799\) [\#3783](https://github.com/apache/arrow-rs/pull/3783) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use NullBuffer in ArrayData \(\#3775\) [\#3778](https://github.com/apache/arrow-rs/pull/3778) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Support timestamp/time and date types in json decoder [\#3834](https://github.com/apache/arrow-rs/issues/3834) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support 
decoding decimals in new raw json decoder [\#3819](https://github.com/apache/arrow-rs/issues/3819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Timezone Aware Timestamp Parsing [\#3794](https://github.com/apache/arrow-rs/issues/3794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Preallocate buffers for FixedSizeBinary array creation [\#3792](https://github.com/apache/arrow-rs/issues/3792) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make Parquet CLI args consistent [\#3785](https://github.com/apache/arrow-rs/issues/3785) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Creates PrimitiveDictionaryBuilder from provided keys and values builders [\#3776](https://github.com/apache/arrow-rs/issues/3776) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use NullBuffer in ArrayData [\#3775](https://github.com/apache/arrow-rs/issues/3775) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support unary\_dict\_mut in arth [\#3710](https://github.com/apache/arrow-rs/issues/3710) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support cast \<\> String to interval [\#3643](https://github.com/apache/arrow-rs/issues/3643) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Zero-Copy Conversion from Vec to/from MutableBuffer [\#3516](https://github.com/apache/arrow-rs/issues/3516) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Timestamp Unit Casts are Unchecked [\#3833](https://github.com/apache/arrow-rs/issues/3833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- regexp\_match skips first match when returning match [\#3803](https://github.com/apache/arrow-rs/issues/3803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast to timestamp with time zone returns timestamp [\#3800](https://github.com/apache/arrow-rs/issues/3800) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Schema-level metadata is not encoded in Flight responses [\#3779](https://github.com/apache/arrow-rs/issues/3779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Closed issues:** + +- FlightSQL CLI client: simple test [\#3814](https://github.com/apache/arrow-rs/issues/3814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Merged pull requests:** + +- refactor: timestamp overflow check [\#3840](https://github.com/apache/arrow-rs/pull/3840) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Prep for 35.0.0 [\#3836](https://github.com/apache/arrow-rs/pull/3836) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Support timestamp/time and date json decoding [\#3835](https://github.com/apache/arrow-rs/pull/3835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Make dictionary preservation optional in row encoding [\#3831](https://github.com/apache/arrow-rs/pull/3831) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move prettyprint to arrow-cast [\#3828](https://github.com/apache/arrow-rs/pull/3828) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Support decoding decimals in raw decoder [\#3820](https://github.com/apache/arrow-rs/pull/3820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Add ArrayDataLayout, port 
validation \(\#1799\) [\#3818](https://github.com/apache/arrow-rs/pull/3818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- test: add test for FlightSQL CLI client [\#3816](https://github.com/apache/arrow-rs/pull/3816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Add regexp\_match docs [\#3812](https://github.com/apache/arrow-rs/pull/3812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: Ensure Flight schema includes parent metadata [\#3811](https://github.com/apache/arrow-rs/pull/3811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([stuartcarnie](https://github.com/stuartcarnie)) +- fix: regexp\_match skips first match [\#3807](https://github.com/apache/arrow-rs/pull/3807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- fix: change uft8 to timestamp with timezone [\#3806](https://github.com/apache/arrow-rs/pull/3806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Support reading decimal arrays from json [\#3805](https://github.com/apache/arrow-rs/pull/3805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Add unary\_dict\_mut [\#3804](https://github.com/apache/arrow-rs/pull/3804) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Faster timestamp parsing \(~70-90% faster\) [\#3801](https://github.com/apache/arrow-rs/pull/3801) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add concat\_elements\_bytes [\#3798](https://github.com/apache/arrow-rs/pull/3798) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Timezone aware timestamp parsing \(\#3794\) [\#3795](https://github.com/apache/arrow-rs/pull/3795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Preallocate buffers for FixedSizeBinary array creation [\#3793](https://github.com/apache/arrow-rs/pull/3793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +- feat: simple flight sql CLI client [\#3789](https://github.com/apache/arrow-rs/pull/3789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Creates PrimitiveDictionaryBuilder from provided keys and values builders [\#3777](https://github.com/apache/arrow-rs/pull/3777) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- ArrayData Enumeration for Remaining Layouts [\#3769](https://github.com/apache/arrow-rs/pull/3769) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update prost-build requirement from =0.11.7 to =0.11.8 [\#3767](https://github.com/apache/arrow-rs/pull/3767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Implement concat\_elements\_dyn kernel [\#3763](https://github.com/apache/arrow-rs/pull/3763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Support for casting `Utf8` and `LargeUtf8` --\> `Interval` [\#3762](https://github.com/apache/arrow-rs/pull/3762) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([doki23](https://github.com/doki23)) +- into\_inner\(\) for CSV Writer 
[\#3759](https://github.com/apache/arrow-rs/pull/3759) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Zero-copy Vec conversion \(\#3516\) \(\#1176\) [\#3756](https://github.com/apache/arrow-rs/pull/3756) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- ArrayData Enumeration for Primitive, Binary and UTF8 [\#3749](https://github.com/apache/arrow-rs/pull/3749) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `into_primitive_dict_builder` to `DictionaryArray` [\#3715](https://github.com/apache/arrow-rs/pull/3715) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +## [34.0.0](https://github.com/apache/arrow-rs/tree/34.0.0) (2023-02-24) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/33.0.0...34.0.0) + +**Breaking changes:** + +- Infer 2020-03-19 00:00:00 as timestamp not Date64 in CSV \(\#3744\) [\#3746](https://github.com/apache/arrow-rs/pull/3746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement fallible streams for `FlightClient::do_put` [\#3464](https://github.com/apache/arrow-rs/pull/3464) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) + +**Implemented enhancements:** + +- Support casting string to timestamp with microsecond resolution [\#3751](https://github.com/apache/arrow-rs/issues/3751) +- Add datatime/interval/duration into comparison kernels [\#3729](https://github.com/apache/arrow-rs/issues/3729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ! 
\(not\) operator overload for SortOptions [\#3726](https://github.com/apache/arrow-rs/issues/3726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: convert Bytes to ByteArray directly [\#3719](https://github.com/apache/arrow-rs/issues/3719) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Implement simple RecordBatchReader [\#3704](https://github.com/apache/arrow-rs/issues/3704) +- Is possible to implement GenericListArray::from\_iter ? [\#3702](https://github.com/apache/arrow-rs/issues/3702) +- `take_run` improvements [\#3701](https://github.com/apache/arrow-rs/issues/3701) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `as_mut_any` in Array trait [\#3655](https://github.com/apache/arrow-rs/issues/3655) +- `Array` --\> `Display` formatter that supports more options and is configurable [\#3638](https://github.com/apache/arrow-rs/issues/3638) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-csv: support decimal256 [\#3474](https://github.com/apache/arrow-rs/issues/3474) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- CSV reader infers Date64 type for fields like "2020-03-19 00:00:00" that it can't parse to Date64 [\#3744](https://github.com/apache/arrow-rs/issues/3744) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Update to 34.0.0 and update changelog [\#3757](https://github.com/apache/arrow-rs/pull/3757) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Update MIRI for split crates \(\#2594\) [\#3754](https://github.com/apache/arrow-rs/pull/3754) ([tustvold](https://github.com/tustvold)) +- Update prost-build requirement from =0.11.6 to =0.11.7 
[\#3753](https://github.com/apache/arrow-rs/pull/3753) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Enable casting of string to timestamp with microsecond resolution [\#3752](https://github.com/apache/arrow-rs/pull/3752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gruuya](https://github.com/gruuya)) +- Use Typed Buffers in Arrays \(\#1811\) \(\#1176\) [\#3743](https://github.com/apache/arrow-rs/pull/3743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup arithmetic kernel type constraints [\#3739](https://github.com/apache/arrow-rs/pull/3739) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Make dictionary kernels optional for comparison benchmark [\#3738](https://github.com/apache/arrow-rs/pull/3738) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support String Coercion in Raw JSON Reader [\#3736](https://github.com/apache/arrow-rs/pull/3736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rguerreiromsft](https://github.com/rguerreiromsft)) +- replace for loop by try\_for\_each [\#3734](https://github.com/apache/arrow-rs/pull/3734) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([suxiaogang223](https://github.com/suxiaogang223)) +- feat: implement generic record batch reader [\#3733](https://github.com/apache/arrow-rs/pull/3733) ([wjones127](https://github.com/wjones127)) +- \[minor\] fix doc test fail [\#3732](https://github.com/apache/arrow-rs/pull/3732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Add datetime/interval/duration into dyn scalar comparison [\#3730](https://github.com/apache/arrow-rs/pull/3730) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Using Borrow\<Value\> on infer\_json\_schema\_from\_iterator [\#3728](https://github.com/apache/arrow-rs/pull/3728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rguerreiromsft](https://github.com/rguerreiromsft)) +- Not operator overload for SortOptions [\#3727](https://github.com/apache/arrow-rs/pull/3727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([berkaysynnada](https://github.com/berkaysynnada)) +- fix: encoding batch with no columns [\#3724](https://github.com/apache/arrow-rs/pull/3724) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([wangrunji0408](https://github.com/wangrunji0408)) +- feat: impl `Ord`/`PartialOrd` for `SortOptions` [\#3723](https://github.com/apache/arrow-rs/pull/3723) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add From\<Bytes\> for ByteArray [\#3720](https://github.com/apache/arrow-rs/pull/3720) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Deprecate old JSON reader \(\#3610\) [\#3718](https://github.com/apache/arrow-rs/pull/3718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add pretty format with options [\#3717](https://github.com/apache/arrow-rs/pull/3717) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove unreachable decimal take [\#3716](https://github.com/apache/arrow-rs/pull/3716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Feat: arrow csv decimal256 [\#3711](https://github.com/apache/arrow-rs/pull/3711) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
([suxiaogang223](https://github.com/suxiaogang223)) +- perf: `take_run` improvements [\#3705](https://github.com/apache/arrow-rs/pull/3705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Add raw MapArrayReader [\#3703](https://github.com/apache/arrow-rs/pull/3703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: Sort kernel for `RunArray` [\#3695](https://github.com/apache/arrow-rs/pull/3695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- perf: Remove sorting to yield sorted\_rank [\#3693](https://github.com/apache/arrow-rs/pull/3693) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- fix: Handle sliced array in run array iterator [\#3681](https://github.com/apache/arrow-rs/pull/3681) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +## [33.0.0](https://github.com/apache/arrow-rs/tree/33.0.0) (2023-02-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/32.0.0...33.0.0) + +**Breaking changes:** + +- Use ArrayFormatter in Cast Kernel [\#3668](https://github.com/apache/arrow-rs/pull/3668) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use dyn Array in cast kernels [\#3667](https://github.com/apache/arrow-rs/pull/3667) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Return references from FixedSizeListArray and MapArray [\#3652](https://github.com/apache/arrow-rs/pull/3652) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Lazy array display \(\#3638\) [\#3647](https://github.com/apache/arrow-rs/pull/3647) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use array\_value\_to\_string in arrow-csv [\#3514](https://github.com/apache/arrow-rs/pull/3514) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JayjeetAtGithub](https://github.com/JayjeetAtGithub)) + +**Implemented enhancements:** + +- Support UTF8 cast to Timestamp with timezone [\#3664](https://github.com/apache/arrow-rs/issues/3664) +- Add modulus\_dyn and modulus\_scalar\_dyn [\#3648](https://github.com/apache/arrow-rs/issues/3648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- A trait for append\_value and append\_null on ArrayBuilders [\#3644](https://github.com/apache/arrow-rs/issues/3644) +- Improve error message "batches\[0\] schema is different with argument schema" [\#3628](https://github.com/apache/arrow-rs/issues/3628) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Specified version of helper function to cast binary to string [\#3623](https://github.com/apache/arrow-rs/issues/3623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Casting generic binary to generic string [\#3606](https://github.com/apache/arrow-rs/issues/3606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `array_value_to_string` in `arrow-csv` [\#3483](https://github.com/apache/arrow-rs/issues/3483) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- ArrowArray::try\_from\_raw Misleading Signature [\#3684](https://github.com/apache/arrow-rs/issues/3684) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PyArrowConvert Leaks Memory [\#3683](https://github.com/apache/arrow-rs/issues/3683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Arrow-csv reader cannot produce RecordBatch even if the bytes are necessary [\#3674](https://github.com/apache/arrow-rs/issues/3674) +- FFI Fails to Account For Offsets 
[\#3671](https://github.com/apache/arrow-rs/issues/3671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression in CSV reader error handling [\#3656](https://github.com/apache/arrow-rs/issues/3656) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- UnionArray Child and Value Fail to Account for non-contiguous Type IDs [\#3653](https://github.com/apache/arrow-rs/issues/3653) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Panic when accessing RecordBatch from pyarrow [\#3646](https://github.com/apache/arrow-rs/issues/3646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Multiplication for decimals is incorrect [\#3645](https://github.com/apache/arrow-rs/issues/3645) +- Inconsistent output between pretty print and CSV writer for Arrow [\#3513](https://github.com/apache/arrow-rs/issues/3513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Release 33.0.0 of arrow/arrow-flight/parquet/parquet-derive \(next release after 32.0.0\) [\#3682](https://github.com/apache/arrow-rs/issues/3682) +- Release `32.0.0` of `arrow`/`arrow-flight`/`parquet`/`parquet-derive` \(next release after `31.0.0`\) [\#3584](https://github.com/apache/arrow-rs/issues/3584) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Merged pull requests:** + +- Move FFI to sub-crates [\#3687](https://github.com/apache/arrow-rs/pull/3687) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update to 33.0.0 and update changelog [\#3686](https://github.com/apache/arrow-rs/pull/3686) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- 
Cleanup FFI interface \(\#3684\) \(\#3683\) [\#3685](https://github.com/apache/arrow-rs/pull/3685) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: take\_run benchmark parameter [\#3679](https://github.com/apache/arrow-rs/pull/3679) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Minor: Add some examples to Date\*Array and Time\*Array [\#3678](https://github.com/apache/arrow-rs/pull/3678) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add CSV Decoder::capacity \(\#3674\) [\#3677](https://github.com/apache/arrow-rs/pull/3677) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add ArrayData::new\_null and DataType::primitive\_width [\#3676](https://github.com/apache/arrow-rs/pull/3676) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix FFI which fails to account for offsets [\#3675](https://github.com/apache/arrow-rs/pull/3675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support UTF8 cast to Timestamp with timezone [\#3673](https://github.com/apache/arrow-rs/pull/3673) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Fix Date64Array docs [\#3670](https://github.com/apache/arrow-rs/pull/3670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.50 to =1.0.51 [\#3669](https://github.com/apache/arrow-rs/pull/3669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add timezone accessor for Timestamp\*Array [\#3666](https://github.com/apache/arrow-rs/pull/3666) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster timezone cast [\#3665](https://github.com/apache/arrow-rs/pull/3665) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat + fix: IPC support for run encoded array. [\#3662](https://github.com/apache/arrow-rs/pull/3662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Implement std::fmt::Write for StringBuilder \(\#3638\) [\#3659](https://github.com/apache/arrow-rs/pull/3659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Include line and field number in CSV UTF-8 error \(\#3656\) [\#3657](https://github.com/apache/arrow-rs/pull/3657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Handle non-contiguous type\_ids in UnionArray \(\#3653\) [\#3654](https://github.com/apache/arrow-rs/pull/3654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add modulus\_dyn and modulus\_scalar\_dyn [\#3649](https://github.com/apache/arrow-rs/pull/3649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Improve error message with detailed schema [\#3637](https://github.com/apache/arrow-rs/pull/3637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Veeupup](https://github.com/Veeupup)) +- Add limit to ArrowReaderBuilder to push limit down to parquet reader [\#3633](https://github.com/apache/arrow-rs/pull/3633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- chore: delete wrong comment and refactor set\_metadata in `Field` [\#3630](https://github.com/apache/arrow-rs/pull/3630) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chunshao90](https://github.com/chunshao90)) 
+- Fix typo in comment [\#3627](https://github.com/apache/arrow-rs/pull/3627) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kjschiroo](https://github.com/kjschiroo)) +- Minor: Update doc strings about Page Index / Column Index [\#3625](https://github.com/apache/arrow-rs/pull/3625) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Specified version of helper function to cast binary to string [\#3624](https://github.com/apache/arrow-rs/pull/3624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat: take kernel for RunArray [\#3622](https://github.com/apache/arrow-rs/pull/3622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Remove BitSliceIterator specialization from try\_for\_each\_valid\_idx [\#3621](https://github.com/apache/arrow-rs/pull/3621) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Reduce PrimitiveArray::try\_unary codegen [\#3619](https://github.com/apache/arrow-rs/pull/3619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Reduce Dictionary Builder Codegen [\#3616](https://github.com/apache/arrow-rs/pull/3616) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Add test for dictionary encoding of batches [\#3608](https://github.com/apache/arrow-rs/pull/3608) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Casting generic binary to generic string [\#3607](https://github.com/apache/arrow-rs/pull/3607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add ArrayAccessor, Iterator, Extend and benchmarks for RunArray [\#3603](https://github.com/apache/arrow-rs/pull/3603) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +## [32.0.0](https://github.com/apache/arrow-rs/tree/32.0.0) (2023-01-27) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/31.0.0...32.0.0) + +**Breaking changes:** + +- Allow `StringArray` construction with `Vec<Option<String>>` [\#3602](https://github.com/apache/arrow-rs/pull/3602) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sinistersnare](https://github.com/sinistersnare)) +- Use native types in PageIndex \(\#3575\) [\#3578](https://github.com/apache/arrow-rs/pull/3578) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add external variant to ParquetError \(\#3285\) [\#3574](https://github.com/apache/arrow-rs/pull/3574) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Return reference from ListArray::values [\#3561](https://github.com/apache/arrow-rs/pull/3561) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: Add `RunEndEncodedArray` [\#3553](https://github.com/apache/arrow-rs/pull/3553) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +**Implemented enhancements:** + +- There should be a `From<Vec<Option<String>>>` impl for `GenericStringArray` [\#3599](https://github.com/apache/arrow-rs/issues/3599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FlightDataEncoder Optionally send Schema even when no record batches [\#3591](https://github.com/apache/arrow-rs/issues/3591) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Use Native Types in PageIndex [\#3575](https://github.com/apache/arrow-rs/issues/3575) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Packing array into dictionary of generic byte array
[\#3571](https://github.com/apache/arrow-rs/issues/3571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `Error::Source` for ArrowError and FlightError [\#3566](https://github.com/apache/arrow-rs/issues/3566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[FlightSQL\] Allow access to underlying FlightClient [\#3551](https://github.com/apache/arrow-rs/issues/3551) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Arrow CSV writer should not fail when cannot cast the value [\#3547](https://github.com/apache/arrow-rs/issues/3547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Write Deprecated Min Max Statistics When ColumnOrder Signed [\#3526](https://github.com/apache/arrow-rs/issues/3526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Improve Performance of JSON Reader [\#3441](https://github.com/apache/arrow-rs/issues/3441) +- Support footer kv metadata for IPC file [\#3432](https://github.com/apache/arrow-rs/issues/3432) +- Add `External` variant to ParquetError [\#3285](https://github.com/apache/arrow-rs/issues/3285) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Nullif of NULL Predicate is not NULL [\#3589](https://github.com/apache/arrow-rs/issues/3589) +- BooleanBufferBuilder Fails to Clear Set Bits On Truncate [\#3587](https://github.com/apache/arrow-rs/issues/3587) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `nullif` incorrectly calculates `null_count`, sometimes panics with subtraction overflow error [\#3579](https://github.com/apache/arrow-rs/issues/3579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Meet warning when use pyarrow [\#3543](https://github.com/apache/arrow-rs/issues/3543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect row group total\_byte\_size written to parquet file 
[\#3530](https://github.com/apache/arrow-rs/issues/3530) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Overflow when casting timestamps prior to the epoch [\#3512](https://github.com/apache/arrow-rs/issues/3512) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Panic on Key Overflow in Dictionary Builders [\#3562](https://github.com/apache/arrow-rs/issues/3562) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Bumping version gives compilation error \(arrow-array\) [\#3525](https://github.com/apache/arrow-rs/issues/3525) + +**Merged pull requests:** + +- Add Push-Based CSV Decoder [\#3604](https://github.com/apache/arrow-rs/pull/3604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update to flatbuffers 23.1.21 [\#3597](https://github.com/apache/arrow-rs/pull/3597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster BooleanBufferBuilder::append\_n for true values [\#3596](https://github.com/apache/arrow-rs/pull/3596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support sending schemas for empty streams [\#3594](https://github.com/apache/arrow-rs/pull/3594) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Faster ListArray to StringArray conversion [\#3593](https://github.com/apache/arrow-rs/pull/3593) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add conversion from StringArray to BinaryArray [\#3592](https://github.com/apache/arrow-rs/pull/3592) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix nullif null count \(\#3579\) [\#3590](https://github.com/apache/arrow-rs/pull/3590) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clear bits in BooleanBufferBuilder \(\#3587\) [\#3588](https://github.com/apache/arrow-rs/pull/3588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Iterate all dictionary key types in cast test [\#3585](https://github.com/apache/arrow-rs/pull/3585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Propagate EOF Error from AsyncRead [\#3576](https://github.com/apache/arrow-rs/pull/3576) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Sach1nAgarwal](https://github.com/Sach1nAgarwal)) +- Show row\_counts also for \(FixedLen\)ByteArray [\#3573](https://github.com/apache/arrow-rs/pull/3573) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([bmmeijers](https://github.com/bmmeijers)) +- Packing array into dictionary of generic byte array [\#3572](https://github.com/apache/arrow-rs/pull/3572) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove unwrap on datetime cast for CSV writer [\#3570](https://github.com/apache/arrow-rs/pull/3570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Implement `std::error::Error::source` for `ArrowError` and `FlightError` [\#3567](https://github.com/apache/arrow-rs/pull/3567) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Improve GenericBytesBuilder offset overflow panic message \(\#139\) [\#3564](https://github.com/apache/arrow-rs/pull/3564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Extend for ArrayBuilder \(\#1841\) [\#3563](https://github.com/apache/arrow-rs/pull/3563) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update pyarrow method call with kwargs [\#3560](https://github.com/apache/arrow-rs/pull/3560) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Frankonly](https://github.com/Frankonly)) +- Update pyo3 requirement from 0.17 to 0.18 [\#3557](https://github.com/apache/arrow-rs/pull/3557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Expose Inner FlightServiceClient on FlightSqlServiceClient \(\#3551\) [\#3556](https://github.com/apache/arrow-rs/pull/3556) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Fix final page row count in parquet-index binary [\#3554](https://github.com/apache/arrow-rs/pull/3554) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Parquet Avoid Reading 8 Byte Footer Twice from AsyncRead [\#3550](https://github.com/apache/arrow-rs/pull/3550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Sach1nAgarwal](https://github.com/Sach1nAgarwal)) +- Improve concat kernel capacity estimation [\#3546](https://github.com/apache/arrow-rs/pull/3546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.49 to =1.0.50 [\#3545](https://github.com/apache/arrow-rs/pull/3545) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update pyarrow method call to avoid warning [\#3544](https://github.com/apache/arrow-rs/pull/3544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Frankonly](https://github.com/Frankonly)) +- Enable casting between Utf8/LargeUtf8 and Binary/LargeBinary [\#3542](https://github.com/apache/arrow-rs/pull/3542) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Use GHA concurrency groups \(\#3495\) [\#3538](https://github.com/apache/arrow-rs/pull/3538) ([tustvold](https://github.com/tustvold)) +- set sum of uncompressed column size as row group size for parquet files [\#3531](https://github.com/apache/arrow-rs/pull/3531) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sidred](https://github.com/sidred)) +- Minor: Add documentation about memory use for ArrayData [\#3529](https://github.com/apache/arrow-rs/pull/3529) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Upgrade to clap 4.1 + fix test [\#3528](https://github.com/apache/arrow-rs/pull/3528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Write backwards compatible row group statistics \(\#3526\) [\#3527](https://github.com/apache/arrow-rs/pull/3527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- No panic on timestamp buffer overflow [\#3519](https://github.com/apache/arrow-rs/pull/3519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Support casting from binary to dictionary of binary [\#3482](https://github.com/apache/arrow-rs/pull/3482) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add Raw JSON Reader \(~2.5x faster\) [\#3479](https://github.com/apache/arrow-rs/pull/3479) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +## [31.0.0](https://github.com/apache/arrow-rs/tree/31.0.0) (2023-01-13) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/30.0.1...31.0.0) + +**Breaking changes:** + +- support RFC3339 style timestamps in `arrow-json` [\#3449](https://github.com/apache/arrow-rs/pull/3449) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JayjeetAtGithub](https://github.com/JayjeetAtGithub)) +- Improve arrow flight batch splitting and naming [\#3444](https://github.com/apache/arrow-rs/pull/3444) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Parquet record API: timestamp as signed integer [\#3437](https://github.com/apache/arrow-rs/pull/3437) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ByteBaker](https://github.com/ByteBaker)) +- Support decimal int32/64 for writer [\#3431](https://github.com/apache/arrow-rs/pull/3431) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) + +**Implemented enhancements:** + +- Support casting Date32 to timestamp [\#3504](https://github.com/apache/arrow-rs/issues/3504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting strings like `'2001-01-01'` to timestamp [\#3492](https://github.com/apache/arrow-rs/issues/3492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CLI to "rewrite" parquet files [\#3476](https://github.com/apache/arrow-rs/issues/3476) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add more dictionary value type support to `build_compare` [\#3465](https://github.com/apache/arrow-rs/issues/3465) +- Allow `concat_batches` to take non owned RecordBatch [\#3456](https://github.com/apache/arrow-rs/issues/3456) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release Arrow `30.0.1` \(maintenance release for `30.0.0`\) [\#3455](https://github.com/apache/arrow-rs/issues/3455) +- Add string comparisons \(starts\_with, ends\_with, and contains\) to kernel [\#3442](https://github.com/apache/arrow-rs/issues/3442) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- make\_builder Loses Timezone and Decimal Scale Information [\#3435](https://github.com/apache/arrow-rs/issues/3435) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use RFC3339 style timestamps in arrow-json [\#3416](https://github.com/apache/arrow-rs/issues/3416) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ArrayData`get_slice_memory_size` or similar [\#3407](https://github.com/apache/arrow-rs/issues/3407) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Unable to read CSV with null boolean value [\#3521](https://github.com/apache/arrow-rs/issues/3521) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make consistent behavior on zeros equality on floating point types [\#3509](https://github.com/apache/arrow-rs/issues/3509) +- Sliced batch w/ bool column doesn't roundtrip through IPC [\#3496](https://github.com/apache/arrow-rs/issues/3496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- take kernel on List array introduces nulls instead of empty lists [\#3471](https://github.com/apache/arrow-rs/issues/3471) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Infinite Loop If Skipping More CSV Lines than Present [\#3469](https://github.com/apache/arrow-rs/issues/3469) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix reading null booleans from CSV [\#3523](https://github.com/apache/arrow-rs/pull/3523) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- minor fix: use the unified decimal type builder [\#3522](https://github.com/apache/arrow-rs/pull/3522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Update version to `31.0.0` and add changelog [\#3518](https://github.com/apache/arrow-rs/pull/3518) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Additional nullif re-export [\#3515](https://github.com/apache/arrow-rs/pull/3515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Make consistent behavior on zeros equality on floating point types [\#3510](https://github.com/apache/arrow-rs/pull/3510) ([viirya](https://github.com/viirya)) +- Enable cast Date32 to Timestamp [\#3508](https://github.com/apache/arrow-rs/pull/3508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Update prost-build requirement from =0.11.5 to =0.11.6 [\#3507](https://github.com/apache/arrow-rs/pull/3507) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- minor fix for the comments [\#3505](https://github.com/apache/arrow-rs/pull/3505) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Fix DataTypeLayout for LargeList [\#3503](https://github.com/apache/arrow-rs/pull/3503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add string comparisons \(starts\_with, ends\_with, and contains\) to kernel [\#3502](https://github.com/apache/arrow-rs/pull/3502) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([snmvaughan](https://github.com/snmvaughan)) +- Add a function to get memory size of array slice [\#3501](https://github.com/apache/arrow-rs/pull/3501) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Fix IPCWriter for Sliced BooleanArray [\#3498](https://github.com/apache/arrow-rs/pull/3498) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Fix: Added support 
to cast string without time [\#3494](https://github.com/apache/arrow-rs/pull/3494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gaelwjl](https://github.com/gaelwjl)) +- Fix negative interval prettyprint [\#3491](https://github.com/apache/arrow-rs/pull/3491) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Fixes a broken link in the arrow lib.rs rustdoc [\#3487](https://github.com/apache/arrow-rs/pull/3487) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- Refactoring build\_compare for decimal and using downcast\_primitive [\#3484](https://github.com/apache/arrow-rs/pull/3484) ([viirya](https://github.com/viirya)) +- Add tests for record batch size splitting logic in FlightClient [\#3481](https://github.com/apache/arrow-rs/pull/3481) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- change `concat_batches` parameter to non owned reference [\#3480](https://github.com/apache/arrow-rs/pull/3480) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- feat: add `parquet-rewrite` CLI [\#3477](https://github.com/apache/arrow-rs/pull/3477) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([crepererum](https://github.com/crepererum)) +- Preserve empty list array elements in take kernel [\#3473](https://github.com/apache/arrow-rs/pull/3473) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonmmease](https://github.com/jonmmease)) +- Add a test for stream writer for writing sliced array [\#3472](https://github.com/apache/arrow-rs/pull/3472) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix CSV infinite loop and improve error messages [\#3470](https://github.com/apache/arrow-rs/pull/3470) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Add more dictionary value type support to `build_compare` [\#3466](https://github.com/apache/arrow-rs/pull/3466) ([viirya](https://github.com/viirya)) +- Add tests for `FlightClient::{list_flights, list_actions, do_action, get_schema}` [\#3463](https://github.com/apache/arrow-rs/pull/3463) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Minor: add ticket links to failing ipc integration tests [\#3461](https://github.com/apache/arrow-rs/pull/3461) ([alamb](https://github.com/alamb)) +- feat: `column_name` based index access for `RecordBatch` and `StructArray` [\#3458](https://github.com/apache/arrow-rs/pull/3458) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Support Decimal256 in FFI [\#3453](https://github.com/apache/arrow-rs/pull/3453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove multiversion dependency [\#3452](https://github.com/apache/arrow-rs/pull/3452) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Re-export nullif kernel [\#3451](https://github.com/apache/arrow-rs/pull/3451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Meaningful error message for map builder with null keys [\#3450](https://github.com/apache/arrow-rs/pull/3450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Parquet writer v2: clear buffer after page flush [\#3447](https://github.com/apache/arrow-rs/pull/3447) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([askoa](https://github.com/askoa)) +- Verify ArrayData::data\_type compatible in PrimitiveArray::from [\#3440](https://github.com/apache/arrow-rs/pull/3440) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Preserve DataType metadata in make\_builder [\#3438](https://github.com/apache/arrow-rs/pull/3438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Consolidate arrow ipc tests and increase coverage [\#3427](https://github.com/apache/arrow-rs/pull/3427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Generic bytes dictionary builder [\#3426](https://github.com/apache/arrow-rs/pull/3426) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Minor: Improve docs for arrow-ipc, remove clippy ignore [\#3421](https://github.com/apache/arrow-rs/pull/3421) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- refactor: convert `*like_dyn`, `*like_utf8_scalar_dyn` and `*like_dict` functions to macros [\#3411](https://github.com/apache/arrow-rs/pull/3411) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Add parquet-index binary [\#3405](https://github.com/apache/arrow-rs/pull/3405) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Complete mid-level `FlightClient` [\#3402](https://github.com/apache/arrow-rs/pull/3402) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Implement `RecordBatch` \<--\> `FlightData` encode/decode + tests [\#3391](https://github.com/apache/arrow-rs/pull/3391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Provide `into_builder` for bytearray [\#3326](https://github.com/apache/arrow-rs/pull/3326) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +## 
[30.0.1](https://github.com/apache/arrow-rs/tree/30.0.1) (2023-01-04) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/30.0.0...30.0.1) + +**Implemented enhancements:** + +- Generic bytes dictionary builder [\#3425](https://github.com/apache/arrow-rs/issues/3425) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Derive Clone for the builders in object-store. [\#3419](https://github.com/apache/arrow-rs/issues/3419) +- Mid-level `ArrowFlight` Client [\#3371](https://github.com/apache/arrow-rs/issues/3371) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Improve performance of the CSV parser [\#3338](https://github.com/apache/arrow-rs/issues/3338) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- `nullif` kernel no longer exported [\#3454](https://github.com/apache/arrow-rs/issues/3454) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PrimitiveArray from ArrayData Unsound For IntervalArray [\#3439](https://github.com/apache/arrow-rs/issues/3439) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- LZ4-compressed PQ files unreadable by Pandas and ClickHouse [\#3433](https://github.com/apache/arrow-rs/issues/3433) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet Record API: Cannot convert date before Unix epoch to json [\#3430](https://github.com/apache/arrow-rs/issues/3430) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- parquet-fromcsv with writer version v2 does not stop [\#3408](https://github.com/apache/arrow-rs/issues/3408) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +## [30.0.0](https://github.com/apache/arrow-rs/tree/30.0.0) (2022-12-29) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/29.0.0...30.0.0) + +**Breaking changes:** + +- Infer Parquet JSON Logical and Converted Type as UTF-8 [\#3376](https://github.com/apache/arrow-rs/pull/3376) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use custom Any instead of prost\_types [\#3360](https://github.com/apache/arrow-rs/pull/3360) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Use bytes in arrow-flight [\#3359](https://github.com/apache/arrow-rs/pull/3359) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Add derived implementations of Clone and Debug for `ParquetObjectReader` [\#3381](https://github.com/apache/arrow-rs/issues/3381) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Speed up TrackedWrite [\#3366](https://github.com/apache/arrow-rs/issues/3366) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Is it possible for ArrowWriter to write key\_value\_metadata after write all records [\#3356](https://github.com/apache/arrow-rs/issues/3356) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add UnionArray test to arrow-pyarrow integration test [\#3346](https://github.com/apache/arrow-rs/issues/3346) +- Document / Deprecate arrow\_flight::utils::flight\_data\_from\_arrow\_batch [\#3312](https://github.com/apache/arrow-rs/issues/3312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[FlightSQL\] Support HTTPs [\#3309](https://github.com/apache/arrow-rs/issues/3309) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support UnionArray in ffi [\#3304](https://github.com/apache/arrow-rs/issues/3304) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for Azure Data Lake Storage Gen2 \(aka: ADLS Gen2\) in Object Store library [\#3283](https://github.com/apache/arrow-rs/issues/3283) +- Support casting from String to Decimal 
[\#3280](https://github.com/apache/arrow-rs/issues/3280) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow ArrowCSV writer to control the display of NULL values [\#3268](https://github.com/apache/arrow-rs/issues/3268) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- FlightSQL example is broken [\#3386](https://github.com/apache/arrow-rs/issues/3386) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- CSV Reader Bounds Incorrectly Handles Header [\#3364](https://github.com/apache/arrow-rs/issues/3364) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect output string from `try_to_type` [\#3350](https://github.com/apache/arrow-rs/issues/3350) +- Decimal arithmetic computation fails to run because decimal type equality [\#3344](https://github.com/apache/arrow-rs/issues/3344) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Pretty print not implemented for Map [\#3322](https://github.com/apache/arrow-rs/issues/3322) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ILIKE Kernels Inconsistent Case Folding [\#3311](https://github.com/apache/arrow-rs/issues/3311) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- minor: Improve arrow-flight docs [\#3372](https://github.com/apache/arrow-rs/pull/3372) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- Version 30.0.0 release notes and changelog [\#3406](https://github.com/apache/arrow-rs/pull/3406) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Ends ParquetRecordBatchStream when polling on StreamState::Error 
[\#3404](https://github.com/apache/arrow-rs/pull/3404) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- fix clippy issues [\#3398](https://github.com/apache/arrow-rs/pull/3398) ([Jimexist](https://github.com/Jimexist)) +- Upgrade multiversion to 0.7.1 [\#3396](https://github.com/apache/arrow-rs/pull/3396) ([viirya](https://github.com/viirya)) +- Make FlightSQL Support HTTPs [\#3388](https://github.com/apache/arrow-rs/pull/3388) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Fix broken FlightSQL example [\#3387](https://github.com/apache/arrow-rs/pull/3387) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Update prost-build [\#3385](https://github.com/apache/arrow-rs/pull/3385) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-arith \(\#2594\) [\#3384](https://github.com/apache/arrow-rs/pull/3384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add derive for Clone and Debug for `ParquetObjectReader` [\#3382](https://github.com/apache/arrow-rs/pull/3382) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kszlim](https://github.com/kszlim)) +- Initial Mid-level `FlightClient` [\#3378](https://github.com/apache/arrow-rs/pull/3378) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Document all features on docs.rs [\#3377](https://github.com/apache/arrow-rs/pull/3377) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-row \(\#2594\) [\#3375](https://github.com/apache/arrow-rs/pull/3375) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove unnecessary flush calls on TrackedWrite [\#3374](https://github.com/apache/arrow-rs/pull/3374) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Update proc-macro2 requirement from =1.0.47 to =1.0.49 [\#3369](https://github.com/apache/arrow-rs/pull/3369) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add CSV build\_buffered \(\#3338\) [\#3368](https://github.com/apache/arrow-rs/pull/3368) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add append\_key\_value\_metadata [\#3367](https://github.com/apache/arrow-rs/pull/3367) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jiacai2050](https://github.com/jiacai2050)) +- Add csv-core based reader \(\#3338\) [\#3365](https://github.com/apache/arrow-rs/pull/3365) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Put BufWriter into TrackedWrite [\#3361](https://github.com/apache/arrow-rs/pull/3361) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add CSV reader benchmark \(\#3338\) [\#3357](https://github.com/apache/arrow-rs/pull/3357) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use ArrayData::ptr\_eq in DictionaryTracker [\#3354](https://github.com/apache/arrow-rs/pull/3354) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate flight\_data\_from\_arrow\_batch [\#3353](https://github.com/apache/arrow-rs/pull/3353) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] 
([Dandandan](https://github.com/Dandandan)) +- Fix incorrect output string from try\_to\_type [\#3351](https://github.com/apache/arrow-rs/pull/3351) ([viirya](https://github.com/viirya)) +- Fix unary\_dyn for decimal scalar arithmetic computation [\#3345](https://github.com/apache/arrow-rs/pull/3345) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add UnionArray test to arrow-pyarrow integration test [\#3343](https://github.com/apache/arrow-rs/pull/3343) ([viirya](https://github.com/viirya)) +- feat: configure null value in arrow csv writer [\#3342](https://github.com/apache/arrow-rs/pull/3342) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Optimize bulk writing of all blocks of bloom filter [\#3340](https://github.com/apache/arrow-rs/pull/3340) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add MapArray to pretty print [\#3339](https://github.com/apache/arrow-rs/pull/3339) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Update prost-build 0.11.4 [\#3334](https://github.com/apache/arrow-rs/pull/3334) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Faster Parquet Bloom Writer [\#3333](https://github.com/apache/arrow-rs/pull/3333) ([tustvold](https://github.com/tustvold)) +- Add bloom filter benchmark for parquet writer [\#3323](https://github.com/apache/arrow-rs/pull/3323) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add ASCII fast path for ILIKE scalar \(90% faster\) [\#3306](https://github.com/apache/arrow-rs/pull/3306) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support UnionArray in ffi [\#3305](https://github.com/apache/arrow-rs/pull/3305) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support casting from String to Decimal [\#3281](https://github.com/apache/arrow-rs/pull/3281) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- add more integration test for parquet bloom filter round trip tests [\#3210](https://github.com/apache/arrow-rs/pull/3210) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +## [29.0.0](https://github.com/apache/arrow-rs/tree/29.0.0) (2022-12-09) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/28.0.0...29.0.0) + +**Breaking changes:** + +- Minor: Allow `Field::new` and `Field::new_with_dict` to take existing `String` as well as `&str` [\#3288](https://github.com/apache/arrow-rs/pull/3288) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- update `&Option` to `Option<&T>` [\#3249](https://github.com/apache/arrow-rs/pull/3249) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- Hide `*_dict_scalar` kernels behind `*_dyn` kernels [\#3202](https://github.com/apache/arrow-rs/pull/3202) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +**Implemented enhancements:** + +- Support writing BloomFilter in arrow\_writer [\#3275](https://github.com/apache/arrow-rs/issues/3275) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support casting from unsigned numeric to Decimal256 [\#3272](https://github.com/apache/arrow-rs/issues/3272) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting from Decimal256 to float types [\#3266](https://github.com/apache/arrow-rs/issues/3266) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make arithmetic kernels supports 
DictionaryArray of DecimalType [\#3254](https://github.com/apache/arrow-rs/issues/3254) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Casting from Decimal256 to unsigned numeric [\#3239](https://github.com/apache/arrow-rs/issues/3239) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- precision is not considered when cast value to decimal [\#3223](https://github.com/apache/arrow-rs/issues/3223) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use RegexSet in arrow\_csv::infer\_field\_schema [\#3211](https://github.com/apache/arrow-rs/issues/3211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement FlightSQL Client [\#3206](https://github.com/apache/arrow-rs/issues/3206) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add binary\_mut and try\_binary\_mut [\#3143](https://github.com/apache/arrow-rs/issues/3143) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add try\_unary\_mut [\#3133](https://github.com/apache/arrow-rs/issues/3133) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Skip null buffer when importing FFI ArrowArray struct if no null buffer in the spec [\#3290](https://github.com/apache/arrow-rs/issues/3290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- using ahash `compile-time-rng` kills reproducible builds [\#3271](https://github.com/apache/arrow-rs/issues/3271) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Decimal128 to Decimal256 Overflows [\#3265](https://github.com/apache/arrow-rs/issues/3265) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `nullif` panics on empty array [\#3261](https://github.com/apache/arrow-rs/issues/3261) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Some more inconsistency between can\_cast\_types and cast\_with\_options [\#3250](https://github.com/apache/arrow-rs/issues/3250) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Enable casting between Dictionary of DecimalArray and DecimalArray [\#3237](https://github.com/apache/arrow-rs/issues/3237) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- new\_null\_array Panics creating StructArray with non-nullable fields [\#3226](https://github.com/apache/arrow-rs/issues/3226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- bool should cast from/to Float16Type as `can_cast_types` returns true [\#3221](https://github.com/apache/arrow-rs/issues/3221) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Utf8 and LargeUtf8 cannot cast from/to Float16 but can\_cast\_types returns true [\#3220](https://github.com/apache/arrow-rs/issues/3220) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Re-enable some tests in `arrow-cast` crate [\#3219](https://github.com/apache/arrow-rs/issues/3219) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Off-by-one buffer size error triggers Panic when constructing RecordBatch from IPC bytes \(should return an Error\) [\#3215](https://github.com/apache/arrow-rs/issues/3215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow to and from pyarrow conversion results in changes in schema [\#3136](https://github.com/apache/arrow-rs/issues/3136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- better document when we need `LargeUtf8` instead of `Utf8` [\#3228](https://github.com/apache/arrow-rs/issues/3228) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Use BufWriter when writing bloom filters and limit tests \(\#3318\) [\#3319](https://github.com/apache/arrow-rs/pull/3319) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use take for dictionary like comparisons [\#3313](https://github.com/apache/arrow-rs/pull/3313) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update versions to 29.0.0 and update CHANGELOG [\#3315](https://github.com/apache/arrow-rs/pull/3315) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- refactor: Merge similar functions `ilike_scalar` and `nilike_scalar` [\#3303](https://github.com/apache/arrow-rs/pull/3303) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Split out arrow-ord \(\#2594\) [\#3299](https://github.com/apache/arrow-rs/pull/3299) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-string \(\#2594\) [\#3295](https://github.com/apache/arrow-rs/pull/3295) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Skip null buffer when importing FFI ArrowArray struct if no null buffer in the spec [\#3293](https://github.com/apache/arrow-rs/pull/3293) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Don't use dangling NonNull as sentinel [\#3289](https://github.com/apache/arrow-rs/pull/3289) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Set bloom filter on byte array [\#3284](https://github.com/apache/arrow-rs/pull/3284) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Fix ipc schema custom\_metadata serialization [\#3282](https://github.com/apache/arrow-rs/pull/3282) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Disable const-random ahash feature on non-WASM \(\#3271\) [\#3277](https://github.com/apache/arrow-rs/pull/3277) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- fix\(ffi\): handle null data buffers from empty arrays [\#3276](https://github.com/apache/arrow-rs/pull/3276) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Support casting from unsigned numeric to Decimal256 [\#3273](https://github.com/apache/arrow-rs/pull/3273) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add parquet-layout binary [\#3269](https://github.com/apache/arrow-rs/pull/3269) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support casting from Decimal256 to float types [\#3267](https://github.com/apache/arrow-rs/pull/3267) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Simplify decimal cast logic [\#3264](https://github.com/apache/arrow-rs/pull/3264) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix panic on nullif empty array \(\#3261\) [\#3263](https://github.com/apache/arrow-rs/pull/3263) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BooleanArray::from\_unary and BooleanArray::from\_binary [\#3258](https://github.com/apache/arrow-rs/pull/3258) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Remove parquet build script [\#3257](https://github.com/apache/arrow-rs/pull/3257) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make arithmetic kernels supports DictionaryArray of DecimalType [\#3255](https://github.com/apache/arrow-rs/pull/3255) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support List and LargeList in Row format \(\#3159\) 
[\#3251](https://github.com/apache/arrow-rs/pull/3251) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Don't recurse to children in ArrayData::try\_new [\#3248](https://github.com/apache/arrow-rs/pull/3248) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Validate dictionaries read over IPC [\#3247](https://github.com/apache/arrow-rs/pull/3247) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix MapBuilder example [\#3246](https://github.com/apache/arrow-rs/pull/3246) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Loosen nullability restrictions added in \#3205 \(\#3226\) [\#3244](https://github.com/apache/arrow-rs/pull/3244) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Better document implications of offsets \(\#3228\) [\#3243](https://github.com/apache/arrow-rs/pull/3243) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add new API to validate the precision for decimal array [\#3242](https://github.com/apache/arrow-rs/pull/3242) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Move nullif to arrow-select \(\#2594\) [\#3241](https://github.com/apache/arrow-rs/pull/3241) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Casting from Decimal256 to unsigned numeric [\#3240](https://github.com/apache/arrow-rs/pull/3240) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable casting between Dictionary of DecimalArray and DecimalArray [\#3238](https://github.com/apache/arrow-rs/pull/3238) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([viirya](https://github.com/viirya)) +- Remove unwraps from 'create\_primitive\_array' [\#3232](https://github.com/apache/arrow-rs/pull/3232) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([aarashy](https://github.com/aarashy)) +- Fix CI build by upgrading tonic-build to 0.8.4 [\#3231](https://github.com/apache/arrow-rs/pull/3231) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Remove negative scale check [\#3230](https://github.com/apache/arrow-rs/pull/3230) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update prost-build requirement from =0.11.2 to =0.11.3 [\#3225](https://github.com/apache/arrow-rs/pull/3225) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Get the round result for decimal to a decimal with smaller scale [\#3224](https://github.com/apache/arrow-rs/pull/3224) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Move tests which require chrono-tz feature from `arrow-cast` to `arrow` [\#3222](https://github.com/apache/arrow-rs/pull/3222) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- add test cases for extracting week with/without timezone [\#3218](https://github.com/apache/arrow-rs/pull/3218) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Use RegexSet for matching DataType [\#3217](https://github.com/apache/arrow-rs/pull/3217) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Update tonic-build to 0.8.3 [\#3214](https://github.com/apache/arrow-rs/pull/3214) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Support StructArray in Row Format \(\#3159\) 
[\#3212](https://github.com/apache/arrow-rs/pull/3212) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Infer timestamps from CSV files [\#3209](https://github.com/apache/arrow-rs/pull/3209) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- fix bug: cast decimal256 to other decimal with no-safe [\#3208](https://github.com/apache/arrow-rs/pull/3208) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- FlightSQL Client & integration test [\#3207](https://github.com/apache/arrow-rs/pull/3207) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Ensure StructArrays check nullability of fields [\#3205](https://github.com/apache/arrow-rs/pull/3205) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Remove special case ArrayData equality for decimals [\#3204](https://github.com/apache/arrow-rs/pull/3204) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add a cast test case for decimal negative scale [\#3203](https://github.com/apache/arrow-rs/pull/3203) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Move zip and shift kernels to arrow-select [\#3201](https://github.com/apache/arrow-rs/pull/3201) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate limit kernel [\#3200](https://github.com/apache/arrow-rs/pull/3200) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use SlicesIterator for ArrayData Equality [\#3198](https://github.com/apache/arrow-rs/pull/3198) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add 
\_dyn kernels of like, ilike, nlike, nilike kernels for dictionary support [\#3197](https://github.com/apache/arrow-rs/pull/3197) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Adding scalar nlike\_dyn, ilike\_dyn, nilike\_dyn kernels [\#3195](https://github.com/apache/arrow-rs/pull/3195) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Use self capture in DataType [\#3190](https://github.com/apache/arrow-rs/pull/3190) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- To pyarrow with schema [\#3188](https://github.com/apache/arrow-rs/pull/3188) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([doki23](https://github.com/doki23)) +- Support Duration in array\_value\_to\_string [\#3183](https://github.com/apache/arrow-rs/pull/3183) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Support `FixedSizeBinary` in Row format [\#3182](https://github.com/apache/arrow-rs/pull/3182) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add binary\_mut and try\_binary\_mut [\#3144](https://github.com/apache/arrow-rs/pull/3144) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add try\_unary\_mut [\#3134](https://github.com/apache/arrow-rs/pull/3134) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +## [28.0.0](https://github.com/apache/arrow-rs/tree/28.0.0) (2022-11-25) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/27.0.0...28.0.0) + +**Breaking changes:** + +- StructArray::columns return slice [\#3186](https://github.com/apache/arrow-rs/pull/3186) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Return slice from GenericByteArray::value\_data [\#3171](https://github.com/apache/arrow-rs/pull/3171) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support decimal negative scale [\#3152](https://github.com/apache/arrow-rs/pull/3152) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- refactor: convert `Field::metadata` to `HashMap` [\#3148](https://github.com/apache/arrow-rs/pull/3148) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Don't Skip Serializing Empty Metadata \(\#3082\) [\#3126](https://github.com/apache/arrow-rs/pull/3126) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Add Decimal128, Decimal256, Float16 to DataType::is\_numeric [\#3121](https://github.com/apache/arrow-rs/pull/3121) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Upgrade to thrift 0.17 and fix issues [\#3104](https://github.com/apache/arrow-rs/pull/3104) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- Fix prettyprint for Interval second fractions [\#3093](https://github.com/apache/arrow-rs/pull/3093) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Remove Option from `Field::metadata` [\#3091](https://github.com/apache/arrow-rs/pull/3091) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +**Implemented enhancements:** + +- Add iterator to RowSelection [\#3172](https://github.com/apache/arrow-rs/issues/3172) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- create an integration test set for parquet crate against pyspark for working with bloom filters [\#3167](https://github.com/apache/arrow-rs/issues/3167) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Row Format Size Tracking [\#3160](https://github.com/apache/arrow-rs/issues/3160) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add ArrayBuilder::finish\_cloned\(\) [\#3154](https://github.com/apache/arrow-rs/issues/3154) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Optimize memory usage of json reader [\#3150](https://github.com/apache/arrow-rs/issues/3150) +- Add `Field::size` and `DataType::size` [\#3147](https://github.com/apache/arrow-rs/issues/3147) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add like\_utf8\_scalar\_dyn kernel [\#3145](https://github.com/apache/arrow-rs/issues/3145) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- support comparison for decimal128 array with scalar in kernel [\#3140](https://github.com/apache/arrow-rs/issues/3140) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- audit and create a document for bloom filter configurations [\#3138](https://github.com/apache/arrow-rs/issues/3138) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Should be the rounding vs truncation when cast decimal to smaller scale [\#3137](https://github.com/apache/arrow-rs/issues/3137) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Upgrade chrono to 0.4.23 [\#3120](https://github.com/apache/arrow-rs/issues/3120) +- Implements more temporal kernels using time\_fraction\_dyn [\#3108](https://github.com/apache/arrow-rs/issues/3108) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Upgrade to thrift 0.17 [\#3105](https://github.com/apache/arrow-rs/issues/3105) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Be able to parse time formatted strings [\#3100](https://github.com/apache/arrow-rs/issues/3100) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve "Fail to merge schema" error messages [\#3095](https://github.com/apache/arrow-rs/issues/3095) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Expose `SortingColumn` when reading and writing parquet metadata [\#3090](https://github.com/apache/arrow-rs/issues/3090) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Change Field::metadata to HashMap [\#3086](https://github.com/apache/arrow-rs/issues/3086) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support bloom filter reading and writing for parquet [\#3023](https://github.com/apache/arrow-rs/issues/3023) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- API to take back ownership of an ArrayRef [\#2901](https://github.com/apache/arrow-rs/issues/2901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Specialized Interleave Kernel [\#2864](https://github.com/apache/arrow-rs/issues/2864) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- arithmetic overflow leads to segfault in `concat_batches` [\#3123](https://github.com/apache/arrow-rs/issues/3123) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Clippy failing on master : error: use of deprecated associated function chrono::NaiveDate::from\_ymd: use from\_ymd\_opt\(\) instead [\#3097](https://github.com/apache/arrow-rs/issues/3097) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Pretty print for interval types has wrong formatting [\#3092](https://github.com/apache/arrow-rs/issues/3092) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Field is not serializable with binary formats [\#3082](https://github.com/apache/arrow-rs/issues/3082) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Decimal Casts are Unchecked [\#2986](https://github.com/apache/arrow-rs/issues/2986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Release Arrow `27.0.0` \(next release after `26.0.0`\) [\#3045](https://github.com/apache/arrow-rs/issues/3045) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Perf about ParquetRecordBatchStream vs ParquetRecordBatchReader [\#2916](https://github.com/apache/arrow-rs/issues/2916) + +**Merged pull requests:** + +- Improve regex related kernels by up to 85% [\#3192](https://github.com/apache/arrow-rs/pull/3192) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Derive clone for arrays [\#3184](https://github.com/apache/arrow-rs/pull/3184) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Row decode cleanups [\#3180](https://github.com/apache/arrow-rs/pull/3180) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update zstd requirement from 0.11.1 to 0.12.0 [\#3178](https://github.com/apache/arrow-rs/pull/3178) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Move decimal constants from `arrow-data` to `arrow-schema` crate [\#3177](https://github.com/apache/arrow-rs/pull/3177) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- bloom filter part V: add an integration with pytest against pyspark
[\#3176](https://github.com/apache/arrow-rs/pull/3176) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Bloom filter config tweaks \(\#3023\) [\#3175](https://github.com/apache/arrow-rs/pull/3175) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add RowParser [\#3174](https://github.com/apache/arrow-rs/pull/3174) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `RowSelection::iter()`, `Into<Vec<RowSelector>>` and example [\#3173](https://github.com/apache/arrow-rs/pull/3173) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add read parquet examples [\#3170](https://github.com/apache/arrow-rs/pull/3170) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([xudong963](https://github.com/xudong963)) +- Faster BinaryArray to StringArray conversion \(~67%\) [\#3168](https://github.com/apache/arrow-rs/pull/3168) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove unnecessary downcasts in builders [\#3166](https://github.com/apache/arrow-rs/pull/3166) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- bloom filter part IV: adjust writer properties, bloom filter properties, and incorporate into column encoder [\#3165](https://github.com/apache/arrow-rs/pull/3165) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Fix parquet decimal precision [\#3164](https://github.com/apache/arrow-rs/pull/3164) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([psvri](https://github.com/psvri)) +- Add Row size methods \(\#3160\) [\#3163](https://github.com/apache/arrow-rs/pull/3163) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- 
Prevent precision=0 for decimal type [\#3162](https://github.com/apache/arrow-rs/pull/3162) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Remove unnecessary Buffer::from\_slice\_ref reference [\#3161](https://github.com/apache/arrow-rs/pull/3161) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add finish\_cloned to ArrayBuilder [\#3158](https://github.com/apache/arrow-rs/pull/3158) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Check overflow in MutableArrayData extend offsets \(\#3123\) [\#3157](https://github.com/apache/arrow-rs/pull/3157) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Extend Decimal256 as Primitive [\#3156](https://github.com/apache/arrow-rs/pull/3156) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Doc improvements [\#3155](https://github.com/apache/arrow-rs/pull/3155) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Add collect.rs example [\#3153](https://github.com/apache/arrow-rs/pull/3153) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Implement Neg for i256 [\#3151](https://github.com/apache/arrow-rs/pull/3151) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: `{Field,DataType}::size` [\#3149](https://github.com/apache/arrow-rs/pull/3149) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add like\_utf8\_scalar\_dyn kernel [\#3146](https://github.com/apache/arrow-rs/pull/3146) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- comparison op: decimal128 array with scalar 
[\#3141](https://github.com/apache/arrow-rs/pull/3141) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Cast: should get the round result for decimal to a decimal with smaller scale [\#3139](https://github.com/apache/arrow-rs/pull/3139) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Fix Panic on Reading Corrupt Parquet Schema \(\#2855\) [\#3130](https://github.com/apache/arrow-rs/pull/3130) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([psvri](https://github.com/psvri)) +- Clippy parquet fixes [\#3124](https://github.com/apache/arrow-rs/pull/3124) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Add GenericByteBuilder \(\#2969\) [\#3122](https://github.com/apache/arrow-rs/pull/3122) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- parquet bloom filter part III: add sbbf writer, remove `bloom` default feature, add reader properties [\#3119](https://github.com/apache/arrow-rs/pull/3119) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Add downcast\_array \(\#2901\) [\#3117](https://github.com/apache/arrow-rs/pull/3117) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add COW conversion for Buffer and PrimitiveArray and unary\_mut [\#3115](https://github.com/apache/arrow-rs/pull/3115) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Include field name in merge error message [\#3113](https://github.com/apache/arrow-rs/pull/3113) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- Add PrimitiveArray::unary\_opt 
[\#3110](https://github.com/apache/arrow-rs/pull/3110) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implements more temporal kernels using time\_fraction\_dyn [\#3107](https://github.com/apache/arrow-rs/pull/3107) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- cast: support unsigned numeric type to decimal128 [\#3106](https://github.com/apache/arrow-rs/pull/3106) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Expose `SortingColumn` in parquet files [\#3103](https://github.com/apache/arrow-rs/pull/3103) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([askoa](https://github.com/askoa)) +- parquet bloom filter part II: read sbbf bitset from row group reader, update API, and add cli demo [\#3102](https://github.com/apache/arrow-rs/pull/3102) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Parse Time32/Time64 from formatted string [\#3101](https://github.com/apache/arrow-rs/pull/3101) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Cleanup temporal \_internal functions [\#3099](https://github.com/apache/arrow-rs/pull/3099) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Improve schema mismatch error message [\#3098](https://github.com/apache/arrow-rs/pull/3098) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Fix clippy by avoiding deprecated functions in chrono [\#3096](https://github.com/apache/arrow-rs/pull/3096) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Minor: Add diagrams and documentation to row format 
[\#3094](https://github.com/apache/arrow-rs/pull/3094) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: Use ArrowNativeTypeOp instead of total\_cmp directly [\#3087](https://github.com/apache/arrow-rs/pull/3087) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Check overflow while casting between decimal types [\#3076](https://github.com/apache/arrow-rs/pull/3076) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- add bloom filter implementation based on split block \(sbbf\) spec [\#3057](https://github.com/apache/arrow-rs/pull/3057) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Add FixedSizeBinaryArray::try\_from\_sparse\_iter\_with\_size [\#3054](https://github.com/apache/arrow-rs/pull/3054) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +## [27.0.0](https://github.com/apache/arrow-rs/tree/27.0.0) (2022-11-11) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/26.0.0...27.0.0) + +**Breaking changes:** + +- Recurse into Dictionary value type in DataType::is\_nested [\#3083](https://github.com/apache/arrow-rs/pull/3083) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- early type checks in `RowConverter` [\#3080](https://github.com/apache/arrow-rs/pull/3080) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add Decimal128 and Decimal256 to downcast\_primitive [\#3056](https://github.com/apache/arrow-rs/pull/3056) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Replace remaining \_generic temporal kernels with \_dyn kernels [\#3046](https://github.com/apache/arrow-rs/pull/3046) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Replace year\_generic with year\_dyn [\#3041](https://github.com/apache/arrow-rs/pull/3041) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Validate decimal256 with i256 directly [\#3025](https://github.com/apache/arrow-rs/pull/3025) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Hadoop LZ4 Support for LZ4 Codec [\#3013](https://github.com/apache/arrow-rs/pull/3013) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([marioloko](https://github.com/marioloko)) +- Replace hour\_generic with hour\_dyn [\#3006](https://github.com/apache/arrow-rs/pull/3006) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Accept any &dyn Array in nullif kernel [\#2940](https://github.com/apache/arrow-rs/pull/2940) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Row Format: Option to detach/own a row [\#3078](https://github.com/apache/arrow-rs/issues/3078) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row Format: API to check if datatypes are supported [\#3077](https://github.com/apache/arrow-rs/issues/3077) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Deprecate Buffer::count\_set\_bits [\#3067](https://github.com/apache/arrow-rs/issues/3067) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Decimal128 and Decimal256 to downcast\_primitive [\#3055](https://github.com/apache/arrow-rs/issues/3055) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improved UX of creating `TimestampNanosecondArray` with timezones [\#3042](https://github.com/apache/arrow-rs/issues/3042) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast decimal256 to signed integer 
[\#3039](https://github.com/apache/arrow-rs/issues/3039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting Date64 to Timestamp [\#3037](https://github.com/apache/arrow-rs/issues/3037) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Check overflow when casting floating point value to decimal256 [\#3032](https://github.com/apache/arrow-rs/issues/3032) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare i256 in validate\_decimal256\_precision [\#3024](https://github.com/apache/arrow-rs/issues/3024) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Check overflow when casting floating point value to decimal128 [\#3020](https://github.com/apache/arrow-rs/issues/3020) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add macro downcast\_temporal\_array [\#3008](https://github.com/apache/arrow-rs/issues/3008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace hour\_generic with hour\_dyn [\#3005](https://github.com/apache/arrow-rs/issues/3005) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace temporal \_generic kernels with dyn [\#3004](https://github.com/apache/arrow-rs/issues/3004) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `RowSelection::intersection` [\#3003](https://github.com/apache/arrow-rs/issues/3003) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- I would like to round rather than truncate when casting f64 to decimal [\#2997](https://github.com/apache/arrow-rs/issues/2997) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow::compute::kernels::temporal should support nanoseconds [\#2995](https://github.com/apache/arrow-rs/issues/2995) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release Arrow `26.0.0` \(next release after `25.0.0`\) [\#2953](https://github.com/apache/arrow-rs/issues/2953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add timezone offset for debug format of Timestamp with Timezone [\#2917](https://github.com/apache/arrow-rs/issues/2917) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support merge RowSelectors when creating RowSelection [\#2858](https://github.com/apache/arrow-rs/issues/2858) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Inconsistent Nan Handling Between Scalar and Non-Scalar Comparison Kernels [\#3074](https://github.com/apache/arrow-rs/issues/3074) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Debug format for timestamp ignores timezone [\#3069](https://github.com/apache/arrow-rs/issues/3069) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row format decode loses timezone [\#3063](https://github.com/apache/arrow-rs/issues/3063) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- binary operator produces incorrect result on arrays with resized null buffer [\#3061](https://github.com/apache/arrow-rs/issues/3061) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RLEDecoder Panics on Null Padded Pages [\#3035](https://github.com/apache/arrow-rs/issues/3035) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Nullif with incorrect valid\_count [\#3031](https://github.com/apache/arrow-rs/issues/3031) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RLEDecoder::get\_batch\_with\_dict may panic on bit-packed runs longer than 1024 [\#3029](https://github.com/apache/arrow-rs/issues/3029) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Converted type is None according to Parquet Tools when utilizing logical types [\#3017](https://github.com/apache/arrow-rs/issues/3017) +- CompressionCodec LZ4 incompatible with C++ implementation [\#2988](https://github.com/apache/arrow-rs/issues/2988)
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Mark parquet predicate pushdown as complete [\#2987](https://github.com/apache/arrow-rs/pull/2987) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Merged pull requests:** + +- Improved UX of creating `TimestampNanosecondArray` with timezones [\#3088](https://github.com/apache/arrow-rs/pull/3088) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([src255](https://github.com/src255)) +- Remove unused range module [\#3085](https://github.com/apache/arrow-rs/pull/3085) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make intersect\_row\_selections a member function [\#3084](https://github.com/apache/arrow-rs/pull/3084) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update hashbrown requirement from 0.12 to 0.13 [\#3081](https://github.com/apache/arrow-rs/pull/3081) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: add `OwnedRow` [\#3079](https://github.com/apache/arrow-rs/pull/3079) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Use ArrowNativeTypeOp on non-scalar comparison kernels [\#3075](https://github.com/apache/arrow-rs/pull/3075) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add missing inline to ArrowNativeTypeOp [\#3073](https://github.com/apache/arrow-rs/pull/3073) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix debug information for Timestamp with Timezone [\#3072](https://github.com/apache/arrow-rs/pull/3072) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([waitingkuo](https://github.com/waitingkuo)) +- Deprecate Buffer::count\_set\_bits \(\#3067\) [\#3071](https://github.com/apache/arrow-rs/pull/3071) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add compare to ArrowNativeTypeOp [\#3070](https://github.com/apache/arrow-rs/pull/3070) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Improve docstrings on WriterPropertiesBuilder [\#3068](https://github.com/apache/arrow-rs/pull/3068) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Faster f64 inequality [\#3065](https://github.com/apache/arrow-rs/pull/3065) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix row format decode loses timezone \(\#3063\) [\#3064](https://github.com/apache/arrow-rs/pull/3064) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix null\_count computation in binary [\#3062](https://github.com/apache/arrow-rs/pull/3062) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Faster f64 equality [\#3060](https://github.com/apache/arrow-rs/pull/3060) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update arrow-flight subcrates \(\#3044\) [\#3052](https://github.com/apache/arrow-rs/pull/3052) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Minor: Remove cloning ArrayData in with\_precision\_and\_scale [\#3050](https://github.com/apache/arrow-rs/pull/3050) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Split out arrow-json \(\#3044\) [\#3049](https://github.com/apache/arrow-rs/pull/3049) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Move `intersect_row_selections` from datafusion to arrow-rs. [\#3047](https://github.com/apache/arrow-rs/pull/3047) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Split out arrow-csv \(\#2594\) [\#3044](https://github.com/apache/arrow-rs/pull/3044) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move reader\_parser to arrow-cast \(\#3022\) [\#3043](https://github.com/apache/arrow-rs/pull/3043) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cast decimal256 to signed integer [\#3040](https://github.com/apache/arrow-rs/pull/3040) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable casting from Date64 to Timestamp [\#3038](https://github.com/apache/arrow-rs/pull/3038) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gruuya](https://github.com/gruuya)) +- Fix decoding long and/or padded RLE data \(\#3029\) \(\#3035\) [\#3036](https://github.com/apache/arrow-rs/pull/3036) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix nullif when existing array has no nulls [\#3034](https://github.com/apache/arrow-rs/pull/3034) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Check overflow when casting floating point value to decimal256 [\#3033](https://github.com/apache/arrow-rs/pull/3033) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update parquet to depend on arrow subcrates [\#3028](https://github.com/apache/arrow-rs/pull/3028) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make various i256 methods const [\#3026](https://github.com/apache/arrow-rs/pull/3026) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-ipc [\#3022](https://github.com/apache/arrow-rs/pull/3022) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Check overflow while casting floating point value to decimal128 [\#3021](https://github.com/apache/arrow-rs/pull/3021) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update arrow-flight [\#3019](https://github.com/apache/arrow-rs/pull/3019) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Move ArrowNativeTypeOp to arrow-array \(\#2594\) [\#3018](https://github.com/apache/arrow-rs/pull/3018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support cast timestamp to time [\#3016](https://github.com/apache/arrow-rs/pull/3016) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([naosense](https://github.com/naosense)) +- Add filter example [\#3014](https://github.com/apache/arrow-rs/pull/3014) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Check overflow when casting integer to decimal [\#3009](https://github.com/apache/arrow-rs/pull/3009) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add macro downcast\_temporal\_array [\#3007](https://github.com/apache/arrow-rs/pull/3007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Parquet Writer: Make column descriptor public on the writer [\#3002](https://github.com/apache/arrow-rs/pull/3002) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([pier-oliviert](https://github.com/pier-oliviert)) +- Update chrono-tz requirement from 0.7 to 0.8 [\#3001](https://github.com/apache/arrow-rs/pull/3001) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Round instead of Truncate while casting float to decimal [\#3000](https://github.com/apache/arrow-rs/pull/3000) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Support Predicate Pushdown for Parquet Lists \(\#2108\) [\#2999](https://github.com/apache/arrow-rs/pull/2999) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-cast \(\#2594\) [\#2998](https://github.com/apache/arrow-rs/pull/2998) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- `arrow::compute::kernels::temporal` should support nanoseconds [\#2996](https://github.com/apache/arrow-rs/pull/2996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Add `RowSelection::from_selectors_and_combine` to merge RowSelectors [\#2994](https://github.com/apache/arrow-rs/pull/2994) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Simplify Single-Column Dictionary Sort [\#2993](https://github.com/apache/arrow-rs/pull/2993) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Add entry to changelog for 26.0.0 RC2 fix [\#2992](https://github.com/apache/arrow-rs/pull/2992) ([alamb](https://github.com/alamb)) +- Fix ignored limit on `lexsort_to_indices` [\#2991](https://github.com/apache/arrow-rs/pull/2991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add clone and equal functions for CastOptions [\#2985](https://github.com/apache/arrow-rs/pull/2985) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- minor: remove redundant prefix 
[\#2983](https://github.com/apache/arrow-rs/pull/2983) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([jackwener](https://github.com/jackwener)) +- Compare dictionary decimal arrays [\#2982](https://github.com/apache/arrow-rs/pull/2982) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Compare dictionary and non-dictionary decimal arrays [\#2980](https://github.com/apache/arrow-rs/pull/2980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add decimal comparison kernel support [\#2978](https://github.com/apache/arrow-rs/pull/2978) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Move concat kernel to arrow-select \(\#2594\) [\#2976](https://github.com/apache/arrow-rs/pull/2976) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Specialize interleave for byte arrays \(\#2864\) [\#2975](https://github.com/apache/arrow-rs/pull/2975) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use unary function for numeric to decimal cast [\#2973](https://github.com/apache/arrow-rs/pull/2973) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Specialize filter kernel for binary arrays \(\#2969\) [\#2971](https://github.com/apache/arrow-rs/pull/2971) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Combine take\_utf8 and take\_binary \(\#2969\) [\#2970](https://github.com/apache/arrow-rs/pull/2970) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster Scalar Dictionary Comparison ~10% [\#2968](https://github.com/apache/arrow-rs/pull/2968) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move `byte_size` from datafusion::physical\_expr [\#2965](https://github.com/apache/arrow-rs/pull/2965) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Pass decompressed size to parquet Codec::decompress \(\#2956\) [\#2959](https://github.com/apache/arrow-rs/pull/2959) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([marioloko](https://github.com/marioloko)) +- Add Decimal Arithmetic [\#2881](https://github.com/apache/arrow-rs/pull/2881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +## [26.0.0](https://github.com/apache/arrow-rs/tree/26.0.0) (2022-10-28) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/25.0.0...26.0.0) + +**Breaking changes:** + +- Cast Timestamps to RFC3339 strings [\#2934](https://github.com/apache/arrow-rs/issues/2934) +- Remove Unused NativeDecimalType [\#2945](https://github.com/apache/arrow-rs/pull/2945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Format Timestamps as RFC3339 [\#2939](https://github.com/apache/arrow-rs/pull/2939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Update flatbuffers to resolve RUSTSEC-2021-0122 [\#2895](https://github.com/apache/arrow-rs/pull/2895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- replace `from_timestamp` by `from_timestamp_opt` [\#2894](https://github.com/apache/arrow-rs/pull/2894) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) + +**Implemented enhancements:** + +- Optimized way to count the numbers of `true` and `false` values in a BooleanArray [\#2963](https://github.com/apache/arrow-rs/issues/2963) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add pow to i256 [\#2954](https://github.com/apache/arrow-rs/issues/2954) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Write Generic Code over \[Large\]BinaryArray and \[Large\]StringArray [\#2946](https://github.com/apache/arrow-rs/issues/2946) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Page Row Count Limit [\#2941](https://github.com/apache/arrow-rs/issues/2941) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- prettyprint to show timezone offset for timestamp with timezone [\#2937](https://github.com/apache/arrow-rs/issues/2937) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast numeric to decimal256 [\#2922](https://github.com/apache/arrow-rs/issues/2922) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `freeze_with_dictionary` API to `MutableArrayData` [\#2914](https://github.com/apache/arrow-rs/issues/2914) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support decimal256 array in sort kernels [\#2911](https://github.com/apache/arrow-rs/issues/2911) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- support `[+/-]hhmm` and `[+/-]hh` as fixedoffset timezone format [\#2910](https://github.com/apache/arrow-rs/issues/2910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cleanup decimal sort function [\#2907](https://github.com/apache/arrow-rs/issues/2907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- replace `from_timestamp` by `from_timestamp_opt` [\#2892](https://github.com/apache/arrow-rs/issues/2892) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Move Primitive arity kernels to arrow-array [\#2787](https://github.com/apache/arrow-rs/issues/2787) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- add overflow-checking for negative arithmetic kernel [\#2662](https://github.com/apache/arrow-rs/issues/2662) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Subtle compatibility issue with serve\_arrow [\#2952](https://github.com/apache/arrow-rs/issues/2952) +- error\[E0599\]: no method named `total_cmp` found for struct `f16` in the current scope [\#2926](https://github.com/apache/arrow-rs/issues/2926) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fail at rowSelection `and_then` method [\#2925](https://github.com/apache/arrow-rs/issues/2925) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Ordering not implemented for FixedSizeBinary types [\#2904](https://github.com/apache/arrow-rs/issues/2904) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet API: Could not convert timestamp before unix epoch to string/json [\#2897](https://github.com/apache/arrow-rs/issues/2897) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Overly Pessimistic RLE Size Estimation [\#2889](https://github.com/apache/arrow-rs/issues/2889) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Memory alignment error in `RawPtrBox::new` [\#2882](https://github.com/apache/arrow-rs/issues/2882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compilation error under chrono-tz feature [\#2878](https://github.com/apache/arrow-rs/issues/2878) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- AHash Statically Allocates 64 bytes [\#2875](https://github.com/apache/arrow-rs/issues/2875) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `parquet::arrow::arrow_writer::ArrowWriter` ignores page size properties [\#2853](https://github.com/apache/arrow-rs/issues/2853) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Document crate topology \(\#2594\) [\#2913](https://github.com/apache/arrow-rs/pull/2913) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Closed 
issues:** + +- SerializedFileWriter comments about multiple call on consumed self [\#2935](https://github.com/apache/arrow-rs/issues/2935) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Pointer freed error when deallocating ArrayData with shared memory buffer [\#2874](https://github.com/apache/arrow-rs/issues/2874) +- Release Arrow `25.0.0` \(next release after `24.0.0`\) [\#2820](https://github.com/apache/arrow-rs/issues/2820) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Replace DecimalArray with PrimitiveArray [\#2637](https://github.com/apache/arrow-rs/issues/2637) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix ignored limit on lexsort\_to\_indices \(\#2991\) [\#2991](https://github.com/apache/arrow-rs/pull/2991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix GenericListArray::try\_new\_from\_array\_data error message \(\#526\) [\#2961](https://github.com/apache/arrow-rs/pull/2961) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix take string on sliced indices [\#2960](https://github.com/apache/arrow-rs/pull/2960) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BooleanArray::true\_count and BooleanArray::false\_count [\#2957](https://github.com/apache/arrow-rs/pull/2957) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add pow to i256 [\#2955](https://github.com/apache/arrow-rs/pull/2955) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix datatype for timestamptz debug fmt
[\#2948](https://github.com/apache/arrow-rs/pull/2948) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Add GenericByteArray \(\#2946\) [\#2947](https://github.com/apache/arrow-rs/pull/2947) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Specialize interleave string ~2-3x faster [\#2944](https://github.com/apache/arrow-rs/pull/2944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Added support for LZ4\_RAW compression. \(\#1604\) [\#2943](https://github.com/apache/arrow-rs/pull/2943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([marioloko](https://github.com/marioloko)) +- Add optional page row count limit for parquet `WriterProperties` \(\#2941\) [\#2942](https://github.com/apache/arrow-rs/pull/2942) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup orphaned doc comments \(\#2935\) [\#2938](https://github.com/apache/arrow-rs/pull/2938) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- support more fixedoffset tz format [\#2936](https://github.com/apache/arrow-rs/pull/2936) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Benchmark with prepared row converter [\#2930](https://github.com/apache/arrow-rs/pull/2930) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add lexsort benchmark \(\#2871\) [\#2929](https://github.com/apache/arrow-rs/pull/2929) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve panic messages for RowSelection::and\_then \(\#2925\) [\#2928](https://github.com/apache/arrow-rs/pull/2928) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
([tustvold](https://github.com/tustvold)) +- Update required half from 2.0 --\> 2.1 [\#2927](https://github.com/apache/arrow-rs/pull/2927) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Cast numeric to decimal256 [\#2923](https://github.com/apache/arrow-rs/pull/2923) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cleanup generated proto code [\#2921](https://github.com/apache/arrow-rs/pull/2921) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Deprecate TimestampArray from\_vec and from\_opt\_vec [\#2919](https://github.com/apache/arrow-rs/pull/2919) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support decimal256 array in sort kernels [\#2912](https://github.com/apache/arrow-rs/pull/2912) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add timezone abstraction [\#2909](https://github.com/apache/arrow-rs/pull/2909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup decimal sort function [\#2908](https://github.com/apache/arrow-rs/pull/2908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Simplify TimestampArray from\_vec with timezone [\#2906](https://github.com/apache/arrow-rs/pull/2906) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement ord for FixedSizeBinary types [\#2905](https://github.com/apache/arrow-rs/pull/2905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +- Update chrono-tz requirement from 0.6 to 0.7 [\#2903](https://github.com/apache/arrow-rs/pull/2903) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Parquet record api support timestamp before epoch [\#2899](https://github.com/apache/arrow-rs/pull/2899) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AnthonyPoncet](https://github.com/AnthonyPoncet)) +- Specialize interleave integer [\#2898](https://github.com/apache/arrow-rs/pull/2898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support overflow-checking variant of negate kernel [\#2893](https://github.com/apache/arrow-rs/pull/2893) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Respect Page Size Limits in ArrowWriter \(\#2853\) [\#2890](https://github.com/apache/arrow-rs/pull/2890) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Improve row format docs [\#2888](https://github.com/apache/arrow-rs/pull/2888) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add FixedSizeList::from\_iter\_primitive [\#2887](https://github.com/apache/arrow-rs/pull/2887) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify ListArray::from\_iter\_primitive [\#2886](https://github.com/apache/arrow-rs/pull/2886) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out value selection kernels into arrow-select \(\#2594\) [\#2885](https://github.com/apache/arrow-rs/pull/2885) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Increase default IPC alignment to 64 \(\#2883\) [\#2884](https://github.com/apache/arrow-rs/pull/2884) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Copying inappropriately aligned 
buffer in ipc reader [\#2883](https://github.com/apache/arrow-rs/pull/2883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Validate decimal IPC read \(\#2387\) [\#2880](https://github.com/apache/arrow-rs/pull/2880) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix compilation error under `chrono-tz` feature [\#2879](https://github.com/apache/arrow-rs/pull/2879) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Don't validate decimal precision in ArrayData \(\#2637\) [\#2873](https://github.com/apache/arrow-rs/pull/2873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add downcast\_integer and downcast\_primitive [\#2872](https://github.com/apache/arrow-rs/pull/2872) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Filter DecimalArray as PrimitiveArray ~5x Faster \(\#2637\) [\#2870](https://github.com/apache/arrow-rs/pull/2870) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Treat DecimalArray as PrimitiveArray in row format [\#2866](https://github.com/apache/arrow-rs/pull/2866) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +## [25.0.0](https://github.com/apache/arrow-rs/tree/25.0.0) (2022-10-14) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/24.0.0...25.0.0) + +**Breaking changes:** + +- Make DecimalArray as PrimitiveArray [\#2857](https://github.com/apache/arrow-rs/pull/2857) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix timestamp parsing while no explicit timezone given [\#2814](https://github.com/apache/arrow-rs/pull/2814) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Support Arbitrary Number of Arrays in downcast\_primitive\_array [\#2809](https://github.com/apache/arrow-rs/pull/2809) ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Restore Integration test JSON schema serialization [\#2876](https://github.com/apache/arrow-rs/issues/2876) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix various invalid\_html\_tags clippy error [\#2861](https://github.com/apache/arrow-rs/issues/2861) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Replace complicated temporal macro with generic functions [\#2851](https://github.com/apache/arrow-rs/issues/2851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add NaN handling in dyn scalar comparison kernels [\#2829](https://github.com/apache/arrow-rs/issues/2829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variant of sum kernel [\#2821](https://github.com/apache/arrow-rs/issues/2821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update to Clap 4 [\#2817](https://github.com/apache/arrow-rs/issues/2817) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Safe API to Operate on Dictionary Values [\#2797](https://github.com/apache/arrow-rs/issues/2797) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add modulus op into `ArrowNativeTypeOp` [\#2753](https://github.com/apache/arrow-rs/issues/2753) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow creating of TimeUnit instances without direct dependency on parquet-format [\#2708](https://github.com/apache/arrow-rs/issues/2708) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Arrow Row Format 
[\#2677](https://github.com/apache/arrow-rs/issues/2677) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Don't try to infer nulls in CSV schema inference [\#2859](https://github.com/apache/arrow-rs/issues/2859) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `parquet::arrow::arrow_writer::ArrowWriter` ignores page size properties [\#2853](https://github.com/apache/arrow-rs/issues/2853) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Introducing ArrowNativeTypeOp made it impossible to call kernels from generics [\#2839](https://github.com/apache/arrow-rs/issues/2839) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Unsound ArrayData to Array Conversions [\#2834](https://github.com/apache/arrow-rs/issues/2834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression: `the trait bound for<'de> arrow::datatypes::Schema: serde::de::Deserialize<'de> is not satisfied` [\#2825](https://github.com/apache/arrow-rs/issues/2825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- convert string to timestamp shouldn't apply local timezone offset if there's no explicit timezone info in the string [\#2813](https://github.com/apache/arrow-rs/issues/2813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Add pub api for checking column index is sorted [\#2848](https://github.com/apache/arrow-rs/issues/2848) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Take decimal as primitive \(\#2637\) [\#2869](https://github.com/apache/arrow-rs/pull/2869) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-integration-test crate [\#2868](https://github.com/apache/arrow-rs/pull/2868) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Decimal cleanup \(\#2637\) [\#2865](https://github.com/apache/arrow-rs/pull/2865) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix various invalid\_html\_tags clippy errors [\#2862](https://github.com/apache/arrow-rs/pull/2862) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Don't try to infer nullability in CSV reader [\#2860](https://github.com/apache/arrow-rs/pull/2860) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Fix page size on dictionary fallback [\#2854](https://github.com/apache/arrow-rs/pull/2854) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Replace complicated temporal macro with generic functions [\#2850](https://github.com/apache/arrow-rs/pull/2850) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- \[feat\] Add pub api for checking column index is sorted. 
[\#2849](https://github.com/apache/arrow-rs/pull/2849) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- parquet: Add `snap` option to README [\#2847](https://github.com/apache/arrow-rs/pull/2847) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([exyi](https://github.com/exyi)) +- Cleanup cast kernel [\#2846](https://github.com/apache/arrow-rs/pull/2846) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify ArrowNativeType [\#2841](https://github.com/apache/arrow-rs/pull/2841) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Expose ArrowNativeTypeOp trait to make it useful for type bound [\#2840](https://github.com/apache/arrow-rs/pull/2840) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add `interleave` kernel \(\#1523\) [\#2838](https://github.com/apache/arrow-rs/pull/2838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Handle empty offsets buffer \(\#1824\) [\#2836](https://github.com/apache/arrow-rs/pull/2836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Validate ArrayData type when converting to Array \(\#2834\) [\#2835](https://github.com/apache/arrow-rs/pull/2835) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Derive ArrowPrimitiveType for Decimal128Type and Decimal256Type \(\#2637\) [\#2833](https://github.com/apache/arrow-rs/pull/2833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add NaN handling in dyn scalar comparison kernels [\#2830](https://github.com/apache/arrow-rs/pull/2830) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Simplify OrderPreservingInterner allocation strategy ~97% faster \(\#2677\) [\#2827](https://github.com/apache/arrow-rs/pull/2827) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Convert rows to arrays \(\#2677\) [\#2826](https://github.com/apache/arrow-rs/pull/2826) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add overflow-checking variant of sum kernel [\#2822](https://github.com/apache/arrow-rs/pull/2822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update Clap dependency to version 4 [\#2819](https://github.com/apache/arrow-rs/pull/2819) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jgoday](https://github.com/jgoday)) +- Fix i256 checked multiplication [\#2818](https://github.com/apache/arrow-rs/pull/2818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add string\_dictionary benches for row format \(\#2677\) [\#2816](https://github.com/apache/arrow-rs/pull/2816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add OrderPreservingInterner::lookup \(\#2677\) [\#2815](https://github.com/apache/arrow-rs/pull/2815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify FixedLengthEncoding [\#2812](https://github.com/apache/arrow-rs/pull/2812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement ArrowNumericType for Float16Type [\#2810](https://github.com/apache/arrow-rs/pull/2810) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add DictionaryArray::with\_values to make it easier to operate on 
dictionary values [\#2798](https://github.com/apache/arrow-rs/pull/2798) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add i256 \(\#2637\) [\#2781](https://github.com/apache/arrow-rs/pull/2781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add modulus ops into `ArrowNativeTypeOp` [\#2756](https://github.com/apache/arrow-rs/pull/2756) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- feat: cast List / LargeList to Utf8 / LargeUtf8 [\#2588](https://github.com/apache/arrow-rs/pull/2588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gandronchik](https://github.com/gandronchik)) + +## [24.0.0](https://github.com/apache/arrow-rs/tree/24.0.0) (2022-09-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/23.0.0...24.0.0) + +**Breaking changes:** + +- Cleanup `ArrowNativeType` \(\#1918\) [\#2793](https://github.com/apache/arrow-rs/pull/2793) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove `ArrowNativeType::FromStr` [\#2775](https://github.com/apache/arrow-rs/pull/2775) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out `arrow-array` crate \(\#2594\) [\#2769](https://github.com/apache/arrow-rs/pull/2769) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `dyn_arith_dict` feature flag [\#2760](https://github.com/apache/arrow-rs/pull/2760) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out `arrow-data` into a separate crate [\#2746](https://github.com/apache/arrow-rs/pull/2746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Split out arrow-schema \(\#2594\) [\#2711](https://github.com/apache/arrow-rs/pull/2711) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Include field name in Parquet PrimitiveTypeBuilder error messages [\#2804](https://github.com/apache/arrow-rs/issues/2804) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add PrimitiveArray::reinterpret\_cast [\#2785](https://github.com/apache/arrow-rs/issues/2785) +- BinaryBuilder and StringBuilder initialization parameters in struct\_builder may be wrong [\#2783](https://github.com/apache/arrow-rs/issues/2783) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add divide scalar dyn kernel which produces null for division by zero [\#2767](https://github.com/apache/arrow-rs/issues/2767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add divide dyn kernel which produces null for division by zero [\#2763](https://github.com/apache/arrow-rs/issues/2763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of checked kernels on non-null data [\#2747](https://github.com/apache/arrow-rs/issues/2747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variants of arithmetic dyn kernels [\#2739](https://github.com/apache/arrow-rs/issues/2739) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- The `binary` function should not panic on unequaled array length. 
[\#2721](https://github.com/apache/arrow-rs/issues/2721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- min compute kernel is incorrect with sliced buffers in arrow 23 [\#2779](https://github.com/apache/arrow-rs/issues/2779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `try_unary_dict` should check value type of dictionary array [\#2754](https://github.com/apache/arrow-rs/issues/2754) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Add back JSON import/export for schema [\#2762](https://github.com/apache/arrow-rs/issues/2762) +- null casting and coercion for Decimal128 [\#2761](https://github.com/apache/arrow-rs/issues/2761) +- Json decoder behavior changed from versions 21 to 21 and returns non-sensical num\_rows for RecordBatch [\#2722](https://github.com/apache/arrow-rs/issues/2722) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release Arrow `23.0.0` \(next release after `22.0.0`\) [\#2665](https://github.com/apache/arrow-rs/issues/2665) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Merged pull requests:** + +- add field name to parquet PrimitiveTypeBuilder error messages [\#2805](https://github.com/apache/arrow-rs/pull/2805) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([andygrove](https://github.com/andygrove)) +- Add struct equality test case \(\#514\) [\#2791](https://github.com/apache/arrow-rs/pull/2791) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move unary kernels to arrow-array \(\#2787\) [\#2789](https://github.com/apache/arrow-rs/pull/2789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Disable test harness for string\_dictionary\_builder benchmark 
[\#2788](https://github.com/apache/arrow-rs/pull/2788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add PrimitiveArray::reinterpret\_cast \(\#2785\) [\#2786](https://github.com/apache/arrow-rs/pull/2786) ([tustvold](https://github.com/tustvold)) +- Fix BinaryBuilder and StringBuilder Capacity Allocation in StructBuilder [\#2784](https://github.com/apache/arrow-rs/pull/2784) ([chunshao90](https://github.com/chunshao90)) +- Fix min/max computation for sliced arrays \(\#2779\) [\#2780](https://github.com/apache/arrow-rs/pull/2780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix Backwards Compatible Parquet List Encodings \(\#1915\) [\#2774](https://github.com/apache/arrow-rs/pull/2774) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- MINOR: Fix clippy for rust 1.64.0 [\#2772](https://github.com/apache/arrow-rs/pull/2772) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- MINOR: Fix clippy for rust 1.64.0 [\#2771](https://github.com/apache/arrow-rs/pull/2771) ([viirya](https://github.com/viirya)) +- Add divide scalar dyn kernel which produces null for division by zero [\#2768](https://github.com/apache/arrow-rs/pull/2768) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add divide dyn kernel which produces null for division by zero [\#2764](https://github.com/apache/arrow-rs/pull/2764) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add value type check in try\_unary\_dict [\#2755](https://github.com/apache/arrow-rs/pull/2755) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix `verify_release_candidate.sh` for new arrow subcrates 
[\#2752](https://github.com/apache/arrow-rs/pull/2752) ([alamb](https://github.com/alamb)) +- Fix: Issue 2721 : binary function should not panic but return error w… [\#2750](https://github.com/apache/arrow-rs/pull/2750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([aksharau](https://github.com/aksharau)) +- Speed up checked kernels for non-null data \(~1.4-5x faster\) [\#2749](https://github.com/apache/arrow-rs/pull/2749) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Add overflow-checking variants of arithmetic dyn kernels [\#2740](https://github.com/apache/arrow-rs/pull/2740) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Trim parquet row selection [\#2705](https://github.com/apache/arrow-rs/pull/2705) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +## [23.0.0](https://github.com/apache/arrow-rs/tree/23.0.0) (2022-09-16) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/22.0.0...23.0.0) + +**Breaking changes:** + +- Move JSON Test Format To integration-testing [\#2724](https://github.com/apache/arrow-rs/pull/2724) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-buffer crate \(\#2594\) [\#2693](https://github.com/apache/arrow-rs/pull/2693) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify DictionaryBuilder constructors \(\#2684\) \(\#2054\) [\#2685](https://github.com/apache/arrow-rs/pull/2685) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate RecordBatch::concat replace with concat\_batches \(\#2594\) [\#2683](https://github.com/apache/arrow-rs/pull/2683) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add overflow-checking variant for primitive arithmetic kernels and explicitly define overflow behavior [\#2643](https://github.com/apache/arrow-rs/pull/2643) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update thrift v0.16 and vendor parquet-format \(\#2502\) [\#2626](https://github.com/apache/arrow-rs/pull/2626) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update flight definitions including backwards-incompatible change to GetSchema [\#2586](https://github.com/apache/arrow-rs/pull/2586) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([liukun4515](https://github.com/liukun4515)) + +**Implemented enhancements:** + +- Cleanup like and nlike utf8 kernels [\#2744](https://github.com/apache/arrow-rs/issues/2744) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speedup eq and neq kernels for utf8 arrays [\#2742](https://github.com/apache/arrow-rs/issues/2742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- API for more ergonomic construction of `RecordBatchOptions` [\#2728](https://github.com/apache/arrow-rs/issues/2728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Automate updates to `CHANGELOG-old.md` [\#2726](https://github.com/apache/arrow-rs/issues/2726) +- Don't check the `DivideByZero` error for float modulus [\#2720](https://github.com/apache/arrow-rs/issues/2720) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `try_binary` should not panic on unequaled array length. 
[\#2715](https://github.com/apache/arrow-rs/issues/2715) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add benchmark for bitwise operation [\#2714](https://github.com/apache/arrow-rs/issues/2714) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variants of arithmetic scalar dyn kernels [\#2712](https://github.com/apache/arrow-rs/issues/2712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add divide\_opt kernel which produce null values on division by zero error [\#2709](https://github.com/apache/arrow-rs/issues/2709) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `DataType` function to detect nested types [\#2704](https://github.com/apache/arrow-rs/issues/2704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support of sorting dictionary of other primitive types [\#2700](https://github.com/apache/arrow-rs/issues/2700) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Sort indices of dictionary string values [\#2697](https://github.com/apache/arrow-rs/issues/2697) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support empty projection in `RecordBatch::project` [\#2690](https://github.com/apache/arrow-rs/issues/2690) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support sorting dictionary encoded primitive integer arrays [\#2679](https://github.com/apache/arrow-rs/issues/2679) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use BitIndexIterator in min\_max\_helper [\#2674](https://github.com/apache/arrow-rs/issues/2674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support building comparator for dictionaries of primitive integer values [\#2672](https://github.com/apache/arrow-rs/issues/2672) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Change max/min string macro to generic helper function `min_max_helper` [\#2657](https://github.com/apache/arrow-rs/issues/2657) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variant of arithmetic scalar kernels [\#2651](https://github.com/apache/arrow-rs/issues/2651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with binary array [\#2644](https://github.com/apache/arrow-rs/issues/2644) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variant for primitive arithmetic kernels [\#2642](https://github.com/apache/arrow-rs/issues/2642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `downcast_primitive_array` in arithmetic kernels [\#2639](https://github.com/apache/arrow-rs/issues/2639) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support DictionaryArray in temporal kernels [\#2622](https://github.com/apache/arrow-rs/issues/2622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Inline Generated Thrift Code Into Parquet Crate [\#2502](https://github.com/apache/arrow-rs/issues/2502) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Escape contains patterns for utf8 like kernels [\#2745](https://github.com/apache/arrow-rs/issues/2745) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Float Array should not panic on `DivideByZero` in the `Divide` kernel [\#2719](https://github.com/apache/arrow-rs/issues/2719) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- DictionaryBuilders can Create Invalid DictionaryArrays [\#2684](https://github.com/apache/arrow-rs/issues/2684) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow` crate does not build with `features = ["ffi"]` and `default_features = false`. 
[\#2670](https://github.com/apache/arrow-rs/issues/2670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Invalid results with `RowSelector` having `row_count` of 0 [\#2669](https://github.com/apache/arrow-rs/issues/2669) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- clippy error: unresolved import `crate::array::layout` [\#2659](https://github.com/apache/arrow-rs/issues/2659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast the numeric without the `CastOptions` [\#2648](https://github.com/apache/arrow-rs/issues/2648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Explicitly define overflow behavior for primitive arithmetic kernels [\#2641](https://github.com/apache/arrow-rs/issues/2641) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- update the `flight.proto` and fix schema to SchemaResult [\#2571](https://github.com/apache/arrow-rs/issues/2571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Panic when first data page is skipped using ColumnChunkData::Sparse [\#2543](https://github.com/apache/arrow-rs/issues/2543) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `SchemaResult` in IPC deviates from other implementations [\#2445](https://github.com/apache/arrow-rs/issues/2445) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Closed issues:** + +- Implement collect for int values [\#2696](https://github.com/apache/arrow-rs/issues/2696) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Speedup string equal/not equal to empty string, cleanup like/ilike kernels, fix escape bug [\#2743](https://github.com/apache/arrow-rs/pull/2743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Partially flatten 
arrow-buffer [\#2737](https://github.com/apache/arrow-rs/pull/2737) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Automate updates to `CHANGELOG-old.md` [\#2732](https://github.com/apache/arrow-rs/pull/2732) ([iajoiner](https://github.com/iajoiner)) +- Update read parquet example in parquet/arrow home [\#2730](https://github.com/apache/arrow-rs/pull/2730) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([datapythonista](https://github.com/datapythonista)) +- Better construction of RecordBatchOptions [\#2729](https://github.com/apache/arrow-rs/pull/2729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- benchmark: bitwise operation [\#2718](https://github.com/apache/arrow-rs/pull/2718) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Update `try_binary` and `checked_ops`, and remove `math_checked_op` [\#2717](https://github.com/apache/arrow-rs/pull/2717) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Support bitwise op in kernel: or,xor,not [\#2716](https://github.com/apache/arrow-rs/pull/2716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Add overflow-checking variants of arithmetic scalar dyn kernels [\#2713](https://github.com/apache/arrow-rs/pull/2713) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add divide\_opt kernel which produce null values on division by zero error [\#2710](https://github.com/apache/arrow-rs/pull/2710) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add DataType::is\_nested\(\) [\#2707](https://github.com/apache/arrow-rs/pull/2707) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kfastov](https://github.com/kfastov)) +- Update 
criterion requirement from 0.3 to 0.4 [\#2706](https://github.com/apache/arrow-rs/pull/2706) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support bitwise and operation in the kernel [\#2703](https://github.com/apache/arrow-rs/pull/2703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Add support of sorting dictionary of other primitive arrays [\#2701](https://github.com/apache/arrow-rs/pull/2701) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Clarify docs of binary and string builders [\#2699](https://github.com/apache/arrow-rs/pull/2699) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([datapythonista](https://github.com/datapythonista)) +- Sort indices of dictionary string values [\#2698](https://github.com/apache/arrow-rs/pull/2698) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add support for empty projection in RecordBatch::project [\#2691](https://github.com/apache/arrow-rs/pull/2691) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Temporarily disable Golang integration tests re-enable JS [\#2689](https://github.com/apache/arrow-rs/pull/2689) ([tustvold](https://github.com/tustvold)) +- Verify valid UTF-8 when converting byte array \(\#2205\) [\#2686](https://github.com/apache/arrow-rs/pull/2686) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support sorting dictionary encoded primitive integer arrays [\#2680](https://github.com/apache/arrow-rs/pull/2680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Skip RowSelectors with zero rows 
[\#2678](https://github.com/apache/arrow-rs/pull/2678) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([askoa](https://github.com/askoa)) +- Faster Null Path Selection in ArrayData Equality [\#2676](https://github.com/apache/arrow-rs/pull/2676) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dhruv9vats](https://github.com/dhruv9vats)) +- Use BitIndexIterator in min\_max\_helper [\#2675](https://github.com/apache/arrow-rs/pull/2675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support building comparator for dictionaries of primitive integer values [\#2673](https://github.com/apache/arrow-rs/pull/2673) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- json feature always requires base64 feature [\#2668](https://github.com/apache/arrow-rs/pull/2668) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([eagletmt](https://github.com/eagletmt)) +- Add try\_unary, binary, try\_binary kernels ~90% faster [\#2666](https://github.com/apache/arrow-rs/pull/2666) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use downcast\_dictionary\_array in unary\_dyn [\#2663](https://github.com/apache/arrow-rs/pull/2663) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- optimize the `numeric_cast_with_error` [\#2661](https://github.com/apache/arrow-rs/pull/2661) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- ffi feature also requires layout [\#2660](https://github.com/apache/arrow-rs/pull/2660) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Change max/min string macro to generic helper function min\_max\_helper [\#2658](https://github.com/apache/arrow-rs/pull/2658) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([viirya](https://github.com/viirya)) +- Fix flaky test `test_fuzz_async_reader_selection` [\#2656](https://github.com/apache/arrow-rs/pull/2656) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- MINOR: Ignore flaky test test\_fuzz\_async\_reader\_selection [\#2655](https://github.com/apache/arrow-rs/pull/2655) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- MutableBuffer::typed\_data - shared ref access to the typed slice [\#2652](https://github.com/apache/arrow-rs/pull/2652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([medwards](https://github.com/medwards)) +- Overflow-checking variant of arithmetic scalar kernels [\#2650](https://github.com/apache/arrow-rs/pull/2650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- support `CastOption` for casting numeric [\#2649](https://github.com/apache/arrow-rs/pull/2649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Help LLVM vectorize comparison kernel ~50-80% faster [\#2646](https://github.com/apache/arrow-rs/pull/2646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support comparison between dictionary array and binary array [\#2645](https://github.com/apache/arrow-rs/pull/2645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Use `downcast_primitive_array` in arithmetic kernels [\#2640](https://github.com/apache/arrow-rs/pull/2640) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fully qualifying parquet items [\#2638](https://github.com/apache/arrow-rs/pull/2638) ([dingxiangfei2009](https://github.com/dingxiangfei2009)) +- Support DictionaryArray in temporal kernels 
[\#2623](https://github.com/apache/arrow-rs/pull/2623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Comparable Row Format [\#2593](https://github.com/apache/arrow-rs/pull/2593) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix bug in page skipping [\#2552](https://github.com/apache/arrow-rs/pull/2552) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) + +## [22.0.0](https://github.com/apache/arrow-rs/tree/22.0.0) (2022-09-02) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/21.0.0...22.0.0) + +**Breaking changes:** + +- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2614](https://github.com/apache/arrow-rs/pull/2614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Gate dyn comparison of dictionary arrays behind `dyn_cmp_dict` [\#2597](https://github.com/apache/arrow-rs/pull/2597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move JsonSerializable to json module \(\#2300\) [\#2595](https://github.com/apache/arrow-rs/pull/2595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Decimal precision scale datatype change [\#2532](https://github.com/apache/arrow-rs/pull/2532) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor PrimitiveBuilder Constructors [\#2518](https://github.com/apache/arrow-rs/pull/2518) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactoring DecimalBuilder constructors [\#2517](https://github.com/apache/arrow-rs/pull/2517) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor FixedSizeBinaryBuilder Constructors [\#2516](https://github.com/apache/arrow-rs/pull/2516) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor BooleanBuilder Constructors [\#2515](https://github.com/apache/arrow-rs/pull/2515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor UnionBuilder Constructors [\#2488](https://github.com/apache/arrow-rs/pull/2488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) + +**Implemented enhancements:** + +- Add Macros to assist with static dispatch [\#2635](https://github.com/apache/arrow-rs/issues/2635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support comparison between DictionaryArray and BooleanArray [\#2617](https://github.com/apache/arrow-rs/issues/2617) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2613](https://github.com/apache/arrow-rs/issues/2613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support empty projection in CSV, JSON readers [\#2603](https://github.com/apache/arrow-rs/issues/2603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support SQL-compliant NaN ordering between for DictionaryArray and non-DictionaryArray [\#2599](https://github.com/apache/arrow-rs/issues/2599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `dyn_cmp_dict` feature flag to gate dyn comparison of dictionary arrays [\#2596](https://github.com/apache/arrow-rs/issues/2596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2584](https://github.com/apache/arrow-rs/issues/2584) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow FlightSQL implementers to extend `do_get()` [\#2581](https://github.com/apache/arrow-rs/issues/2581) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support SQL-compliant behavior on `eq_dyn`, `neq_dyn`, `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2569](https://github.com/apache/arrow-rs/issues/2569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add sql-compliant feature for enabling sql-compliant kernel behavior [\#2568](https://github.com/apache/arrow-rs/issues/2568) +- Calculate `sum` for dictionary array [\#2565](https://github.com/apache/arrow-rs/issues/2565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add test for float nan comparison [\#2556](https://github.com/apache/arrow-rs/issues/2556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with string array [\#2548](https://github.com/apache/arrow-rs/issues/2548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with primitive array in `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2538](https://github.com/apache/arrow-rs/issues/2538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2535](https://github.com/apache/arrow-rs/issues/2535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- UnionBuilder Create Children With Capacity [\#2523](https://github.com/apache/arrow-rs/issues/2523) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up `like_utf8_scalar` for `%pat%` [\#2519](https://github.com/apache/arrow-rs/issues/2519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace macro with TypedDictionaryArray in comparison kernels [\#2513](https://github.com/apache/arrow-rs/issues/2513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use same codebase for boolean kernels 
[\#2507](https://github.com/apache/arrow-rs/issues/2507) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use u8 for Decimal Precision and Scale [\#2496](https://github.com/apache/arrow-rs/issues/2496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Integrate skip row without pageIndex in SerializedPageReader in Fuzz Test [\#2475](https://github.com/apache/arrow-rs/issues/2475) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Avoid unnecessary copies in Arrow IPC reader [\#2437](https://github.com/apache/arrow-rs/issues/2437) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add GenericColumnReader::skip\_records Missing OffsetIndex Fallback [\#2433](https://github.com/apache/arrow-rs/issues/2433) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support Reading PageIndex with ParquetRecordBatchStream [\#2430](https://github.com/apache/arrow-rs/issues/2430) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Specialize FixedLenByteArrayReader for Parquet [\#2318](https://github.com/apache/arrow-rs/issues/2318) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make JSON support Optional via Feature Flag [\#2300](https://github.com/apache/arrow-rs/issues/2300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Casting timestamp array to string should not ignore timezone [\#2607](https://github.com/apache/arrow-rs/issues/2607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Ilike\_ut8\_scalar kernels have incorrect logic [\#2544](https://github.com/apache/arrow-rs/issues/2544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Always validate the array data when creating array in IPC reader [\#2541](https://github.com/apache/arrow-rs/issues/2541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Int96Converter Truncates Timestamps [\#2480](https://github.com/apache/arrow-rs/issues/2480) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Error Reading Page Index When Not Available [\#2434](https://github.com/apache/arrow-rs/issues/2434) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `ParquetFileArrowReader::get_record_reader[_by_column]` `batch_size` overallocates [\#2321](https://github.com/apache/arrow-rs/issues/2321) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Document All Arrow Features in docs.rs [\#2633](https://github.com/apache/arrow-rs/issues/2633) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Add support for CAST from `Interval(DayTime)` to `Timestamp(Nanosecond, None)` [\#2606](https://github.com/apache/arrow-rs/issues/2606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Why do we check for null in TypedDictionaryArray value function [\#2564](https://github.com/apache/arrow-rs/issues/2564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add the `length` field for `Buffer` [\#2524](https://github.com/apache/arrow-rs/issues/2524) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Avoid large over allocate buffer in async reader [\#2512](https://github.com/apache/arrow-rs/issues/2512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Rewriting Decimal Builders using `const_generic`. 
[\#2390](https://github.com/apache/arrow-rs/issues/2390) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Rewrite Decimal Array using `const_generic` [\#2384](https://github.com/apache/arrow-rs/issues/2384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Add downcast macros \(\#2635\) [\#2636](https://github.com/apache/arrow-rs/pull/2636) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document all arrow features in docs.rs \(\#2633\) [\#2634](https://github.com/apache/arrow-rs/pull/2634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document dyn\_cmp\_dict [\#2624](https://github.com/apache/arrow-rs/pull/2624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support comparison between DictionaryArray and BooleanArray [\#2618](https://github.com/apache/arrow-rs/pull/2618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cast timestamp array to string array with timezone [\#2608](https://github.com/apache/arrow-rs/pull/2608) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support empty projection in CSV and JSON readers [\#2604](https://github.com/apache/arrow-rs/pull/2604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Make JSON support optional via a feature flag \(\#2300\) [\#2601](https://github.com/apache/arrow-rs/pull/2601) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support SQL-compliant NaN ordering for DictionaryArray and non-DictionaryArray [\#2600](https://github.com/apache/arrow-rs/pull/2600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([viirya](https://github.com/viirya)) +- Split out integration test plumbing \(\#2594\) \(\#2300\) [\#2598](https://github.com/apache/arrow-rs/pull/2598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Refactor Binary Builder and String Builder Constructors [\#2592](https://github.com/apache/arrow-rs/pull/2592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Dictionary like scalar kernels [\#2591](https://github.com/apache/arrow-rs/pull/2591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Validate dictionary key in TypedDictionaryArray \(\#2578\) [\#2589](https://github.com/apache/arrow-rs/pull/2589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2585](https://github.com/apache/arrow-rs/pull/2585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Code cleanup of array value functions [\#2583](https://github.com/apache/arrow-rs/pull/2583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Allow overriding of do\_get & export useful macro [\#2582](https://github.com/apache/arrow-rs/pull/2582) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- MINOR: Upgrade to pyo3 0.17 [\#2576](https://github.com/apache/arrow-rs/pull/2576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- Support SQL-compliant NaN behavior on eq\_dyn, neq\_dyn, lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn [\#2570](https://github.com/apache/arrow-rs/pull/2570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([viirya](https://github.com/viirya)) +- Add sum\_dyn to calculate sum for dictionary array [\#2566](https://github.com/apache/arrow-rs/pull/2566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- struct UnionBuilder will create child buffers with capacity [\#2560](https://github.com/apache/arrow-rs/pull/2560) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kastolars](https://github.com/kastolars)) +- Don't panic on RleValueEncoder::flush\_buffer if empty \(\#2558\) [\#2559](https://github.com/apache/arrow-rs/pull/2559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add the `length` field for Buffer and use more `Buffer` in IPC reader to avoid memory copy. [\#2557](https://github.com/apache/arrow-rs/pull/2557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([HaoYang670](https://github.com/HaoYang670)) +- Add test for float nan comparison [\#2555](https://github.com/apache/arrow-rs/pull/2555) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Compare dictionary array with string array [\#2549](https://github.com/apache/arrow-rs/pull/2549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Always validate the array data \(except the `Decimal`\) when creating array in IPC reader [\#2547](https://github.com/apache/arrow-rs/pull/2547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- MINOR: Fix test\_row\_type\_validation test [\#2546](https://github.com/apache/arrow-rs/pull/2546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix ilike\_utf8\_scalar kernels [\#2545](https://github.com/apache/arrow-rs/pull/2545) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- fix typo [\#2540](https://github.com/apache/arrow-rs/pull/2540) ([00Masato](https://github.com/00Masato)) +- Compare dictionary array and primitive array in lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn kernels [\#2539](https://github.com/apache/arrow-rs/pull/2539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- \[MINOR\]Avoid large over allocate buffer in async reader [\#2537](https://github.com/apache/arrow-rs/pull/2537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2533](https://github.com/apache/arrow-rs/pull/2533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add iterator for FixedSizeBinaryArray [\#2531](https://github.com/apache/arrow-rs/pull/2531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- add bench: decimal with byte array and fixed length byte array [\#2529](https://github.com/apache/arrow-rs/pull/2529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Add FixedLengthByteArrayReader Remove ComplexObjectArrayReader [\#2528](https://github.com/apache/arrow-rs/pull/2528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Split out byte array decoders \(\#2318\) [\#2527](https://github.com/apache/arrow-rs/pull/2527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use offset index in ParquetRecordBatchStream [\#2526](https://github.com/apache/arrow-rs/pull/2526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Clean the 
`create_array` in IPC reader. [\#2525](https://github.com/apache/arrow-rs/pull/2525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove DecimalByteArrayConvert \(\#2480\) [\#2522](https://github.com/apache/arrow-rs/pull/2522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Improve performance of `%pat%` \(\>3x speedup\) [\#2521](https://github.com/apache/arrow-rs/pull/2521) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- remove len field from MapBuilder [\#2520](https://github.com/apache/arrow-rs/pull/2520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Replace macro with TypedDictionaryArray in comparison kernels [\#2514](https://github.com/apache/arrow-rs/pull/2514) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Avoid large over allocate buffer in sync reader [\#2511](https://github.com/apache/arrow-rs/pull/2511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Avoid useless memory copies in IPC reader. 
[\#2510](https://github.com/apache/arrow-rs/pull/2510) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Refactor boolean kernels to use same codebase [\#2508](https://github.com/apache/arrow-rs/pull/2508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove Int96Converter \(\#2480\) [\#2481](https://github.com/apache/arrow-rs/pull/2481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + ## [21.0.0](https://github.com/apache/arrow-rs/tree/21.0.0) (2022-08-18) [Full Changelog](https://github.com/apache/arrow-rs/compare/20.0.0...21.0.0) @@ -430,7 +3231,7 @@ - Incorrect `null_count` of DictionaryArray [\#1962](https://github.com/apache/arrow-rs/issues/1962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Support multi diskRanges for ChunkReader [\#1955](https://github.com/apache/arrow-rs/issues/1955) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] - Persisting Arrow timestamps with Parquet produces missing `TIMESTAMP` in schema [\#1920](https://github.com/apache/arrow-rs/issues/1920) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Sperate get\_next\_page\_header from get\_next\_page in PageReader [\#1834](https://github.com/apache/arrow-rs/issues/1834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Separate get\_next\_page\_header from get\_next\_page in PageReader [\#1834](https://github.com/apache/arrow-rs/issues/1834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] **Merged pull requests:** @@ -487,7 +3288,7 @@ - `PrimitiveArray::from_iter` should omit validity buffer if all values are valid [\#1856](https://github.com/apache/arrow-rs/issues/1856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `from(v: Vec<Option<&[u8]>>)` and `from(v: Vec<&[u8]>)` for `FixedSizedBInaryArray` 
[\#1852](https://github.com/apache/arrow-rs/issues/1852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `Vec`-inspired APIs to `BufferBuilder` [\#1850](https://github.com/apache/arrow-rs/issues/1850) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- PyArrow intergation test for C Stream Interface [\#1847](https://github.com/apache/arrow-rs/issues/1847) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PyArrow integration test for C Stream Interface [\#1847](https://github.com/apache/arrow-rs/issues/1847) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `nilike` support in `comparison` [\#1845](https://github.com/apache/arrow-rs/issues/1845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Split up `arrow::array::builder` module [\#1843](https://github.com/apache/arrow-rs/issues/1843) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `quarter` support in `temporal` kernels [\#1835](https://github.com/apache/arrow-rs/issues/1835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] @@ -884,7 +3685,7 @@ **Fixed bugs:** -- Error Infering Schema for LogicalType::UNKNOWN [\#1557](https://github.com/apache/arrow-rs/issues/1557) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Error Inferring Schema for LogicalType::UNKNOWN [\#1557](https://github.com/apache/arrow-rs/issues/1557) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] - Read dictionary from nested struct in ipc stream reader panics [\#1549](https://github.com/apache/arrow-rs/issues/1549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - `filter` produces invalid sparse `UnionArray`s [\#1547](https://github.com/apache/arrow-rs/issues/1547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Documentation for `GenericListBuilder` is not exposed. 
[\#1518](https://github.com/apache/arrow-rs/issues/1518) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] @@ -1410,7 +4211,7 @@ * [094037d418381584178db1d886cad3b5024b414a](https://github.com/apache/arrow-rs/commit/094037d418381584178db1d886cad3b5024b414a) Update comfy-table to 5.0 ([#957](https://github.com/apache/arrow-rs/pull/957)) ([#964](https://github.com/apache/arrow-rs/pull/964)) * [9f635021eee6786c5377c891218c5f88ebce07c3](https://github.com/apache/arrow-rs/commit/9f635021eee6786c5377c891218c5f88ebce07c3) Fix csv writing of timestamps to show timezone. ([#849](https://github.com/apache/arrow-rs/pull/849)) ([#963](https://github.com/apache/arrow-rs/pull/963)) * [f7deba4c3a050a52608462ee8a827bb8f6364140](https://github.com/apache/arrow-rs/commit/f7deba4c3a050a52608462ee8a827bb8f6364140) Adding ability to parse float from number with leading decimal ([#831](https://github.com/apache/arrow-rs/pull/831)) ([#962](https://github.com/apache/arrow-rs/pull/962)) -* [59f96e842d05b63882f7ba285c66a9739761cf84](https://github.com/apache/arrow-rs/commit/59f96e842d05b63882f7ba285c66a9739761cf84) add ilike comparitor ([#874](https://github.com/apache/arrow-rs/pull/874)) ([#961](https://github.com/apache/arrow-rs/pull/961)) +* [59f96e842d05b63882f7ba285c66a9739761cf84](https://github.com/apache/arrow-rs/commit/59f96e842d05b63882f7ba285c66a9739761cf84) add ilike comparator ([#874](https://github.com/apache/arrow-rs/pull/874)) ([#961](https://github.com/apache/arrow-rs/pull/961)) * [54023c8a5543c9f9fa4955afa01189029f3e96f5](https://github.com/apache/arrow-rs/commit/54023c8a5543c9f9fa4955afa01189029f3e96f5) Remove unpassable cargo publish check from verify-release-candidate.sh ([#882](https://github.com/apache/arrow-rs/pull/882)) ([#949](https://github.com/apache/arrow-rs/pull/949)) @@ -1507,7 +4308,7 @@ **Fixed bugs:** - Converting from string to timestamp uses microseconds instead of milliseconds [\#780](https://github.com/apache/arrow-rs/issues/780) -- Document 
has no link to `RowColumIter` [\#762](https://github.com/apache/arrow-rs/issues/762) +- Document has no link to `RowColumnIter` [\#762](https://github.com/apache/arrow-rs/issues/762) - length on slices with null doesn't work [\#744](https://github.com/apache/arrow-rs/issues/744) ## [5.4.0](https://github.com/apache/arrow-rs/tree/5.4.0) (2021-09-10) @@ -1565,7 +4366,7 @@ - Remove undefined behavior in `value` method of boolean and primitive arrays [\#645](https://github.com/apache/arrow-rs/issues/645) - Avoid materialization of indices in filter\_record\_batch for single arrays [\#636](https://github.com/apache/arrow-rs/issues/636) - Add a note about arrow crate security / safety [\#627](https://github.com/apache/arrow-rs/issues/627) -- Allow the creation of String arrays from an interator of &Option\<&str\> [\#598](https://github.com/apache/arrow-rs/issues/598) +- Allow the creation of String arrays from an iterator of &Option\<&str\> [\#598](https://github.com/apache/arrow-rs/issues/598) - Support arrow map datatype [\#395](https://github.com/apache/arrow-rs/issues/395) **Fixed bugs:** @@ -1694,7 +4495,7 @@ - Add C data interface for decimal128 and timestamp [\#453](https://github.com/apache/arrow-rs/pull/453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alippai](https://github.com/alippai)) - Implement the Iterator trait for the json Reader. 
[\#451](https://github.com/apache/arrow-rs/pull/451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([LaurentMazare](https://github.com/LaurentMazare)) - Update release docs + release email template [\#450](https://github.com/apache/arrow-rs/pull/450) ([alamb](https://github.com/alamb)) -- remove clippy unnecessary wraps suppresions in cast kernel [\#449](https://github.com/apache/arrow-rs/pull/449) ([Jimexist](https://github.com/Jimexist)) +- remove clippy unnecessary wraps suppression in cast kernel [\#449](https://github.com/apache/arrow-rs/pull/449) ([Jimexist](https://github.com/Jimexist)) - Use partition for bool sort [\#448](https://github.com/apache/arrow-rs/pull/448) ([Jimexist](https://github.com/Jimexist)) - remove unnecessary wraps in sort [\#445](https://github.com/apache/arrow-rs/pull/445) ([Jimexist](https://github.com/Jimexist)) - Python FFI bridge for Schema, Field and DataType [\#439](https://github.com/apache/arrow-rs/pull/439) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszucs](https://github.com/kszucs)) @@ -1767,7 +4568,7 @@ - ARROW-12504: Buffer::from\_slice\_ref set correct capacity [\#18](https://github.com/apache/arrow-rs/pull/18) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) - Add GitHub templates [\#17](https://github.com/apache/arrow-rs/pull/17) ([andygrove](https://github.com/andygrove)) - ARROW-12493: Add support for writing dictionary arrays to CSV and JSON [\#16](https://github.com/apache/arrow-rs/pull/16) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- ARROW-12426: \[Rust\] Fix concatentation of arrow dictionaries [\#15](https://github.com/apache/arrow-rs/pull/15) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- ARROW-12426: \[Rust\] Fix concatenation of arrow dictionaries [\#15](https://github.com/apache/arrow-rs/pull/15) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) - Update repository and homepage urls [\#14](https://github.com/apache/arrow-rs/pull/14) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Dandandan](https://github.com/Dandandan)) - Added rebase-needed bot [\#13](https://github.com/apache/arrow-rs/pull/13) ([jorgecarleitao](https://github.com/jorgecarleitao)) - Added Integration tests against arrow [\#10](https://github.com/apache/arrow-rs/pull/10) ([jorgecarleitao](https://github.com/jorgecarleitao)) @@ -1911,7 +4712,7 @@ - Support sort [\#215](https://github.com/apache/arrow-rs/issues/215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Support stable Rust [\#214](https://github.com/apache/arrow-rs/issues/214) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Remove Rust and point integration tests to arrow-rs repo [\#211](https://github.com/apache/arrow-rs/issues/211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- ArrayData buffers are inconsistent accross implementations [\#207](https://github.com/apache/arrow-rs/issues/207) +- ArrayData buffers are inconsistent across implementations [\#207](https://github.com/apache/arrow-rs/issues/207) - 3.0.1 patch release [\#204](https://github.com/apache/arrow-rs/issues/204) - Document patch release process [\#202](https://github.com/apache/arrow-rs/issues/202) - Simplify Offset [\#186](https://github.com/apache/arrow-rs/issues/186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] diff --git a/CHANGELOG.md b/CHANGELOG.md index 69f2b8af6cf8..20c13ffebbd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,119 +19,161 @@ # Changelog -## [22.0.0](https://github.com/apache/arrow-rs/tree/22.0.0) (2022-09-02) +## [53.0.0](https://github.com/apache/arrow-rs/tree/53.0.0) (2024-08-31) 
-[Full Changelog](https://github.com/apache/arrow-rs/compare/21.0.0...22.0.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/52.2.0...53.0.0) **Breaking changes:** -- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2614](https://github.com/apache/arrow-rs/pull/2614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Gate dyn comparison of dictionary arrays behind `dyn_cmp_dict` [\#2597](https://github.com/apache/arrow-rs/pull/2597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Move JsonSerializable to json module \(\#2300\) [\#2595](https://github.com/apache/arrow-rs/pull/2595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Decimal precision scale datatype change [\#2532](https://github.com/apache/arrow-rs/pull/2532) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactor PrimitiveBuilder Constructors [\#2518](https://github.com/apache/arrow-rs/pull/2518) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactoring DecimalBuilder constructors [\#2517](https://github.com/apache/arrow-rs/pull/2517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactor FixedSizeBinaryBuilder Constructors [\#2516](https://github.com/apache/arrow-rs/pull/2516) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactor BooleanBuilder Constructors [\#2515](https://github.com/apache/arrow-rs/pull/2515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) 
-- Refactor UnionBuilder Constructors [\#2488](https://github.com/apache/arrow-rs/pull/2488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- parquet\_derive: Match fields by name, support reading selected fields rather than all [\#6269](https://github.com/apache/arrow-rs/pull/6269) ([double-free](https://github.com/double-free)) +- Update parquet object\_store dependency to 0.11.0 [\#6264](https://github.com/apache/arrow-rs/pull/6264) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- parquet Statistics - deprecate `has_*` APIs and add `_opt` functions that return `Option<T>` [\#6216](https://github.com/apache/arrow-rs/pull/6216) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Michael-J-Ward](https://github.com/Michael-J-Ward)) +- Expose bulk ingest in flight sql client and server [\#6201](https://github.com/apache/arrow-rs/pull/6201) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([djanderson](https://github.com/djanderson)) +- Upgrade protobuf definitions to flightsql 17.0 \(\#6133\) [\#6169](https://github.com/apache/arrow-rs/pull/6169) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Remove automatic buffering in `ipc::reader::FileReader` for consistent buffering [\#6132](https://github.com/apache/arrow-rs/pull/6132) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([V0ldek](https://github.com/V0ldek)) +- No longer write Parquet column metadata after column chunks \*and\* in the footer [\#6117](https://github.com/apache/arrow-rs/pull/6117) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Remove `impl<T: AsRef<[u8]>> From<T> for Buffer` that easily accidentally copies data [\#6043](https://github.com/apache/arrow-rs/pull/6043) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) **Implemented enhancements:** -- Add Macros to assist with static dispatch [\#2635](https://github.com/apache/arrow-rs/issues/2635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support comparison between DictionaryArray and BooleanArray [\#2617](https://github.com/apache/arrow-rs/issues/2617) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2613](https://github.com/apache/arrow-rs/issues/2613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support empty projection in CSV, JSON readers [\#2603](https://github.com/apache/arrow-rs/issues/2603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support SQL-compliant NaN ordering between for DictionaryArray and non-DictionaryArray [\#2599](https://github.com/apache/arrow-rs/issues/2599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add `dyn_cmp_dict` feature flag to gate dyn comparison of dictionary arrays [\#2596](https://github.com/apache/arrow-rs/issues/2596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2584](https://github.com/apache/arrow-rs/issues/2584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Allow FlightSQL implementers to extend `do_get()` [\#2581](https://github.com/apache/arrow-rs/issues/2581) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Support SQL-compliant behavior on `eq_dyn`, `neq_dyn`, `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2569](https://github.com/apache/arrow-rs/issues/2569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add sql-compliant feature for enabling sql-compliant kernel behavior [\#2568](https://github.com/apache/arrow-rs/issues/2568) -- Calculate `sum` for dictionary array 
[\#2565](https://github.com/apache/arrow-rs/issues/2565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add test for float nan comparison [\#2556](https://github.com/apache/arrow-rs/issues/2556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Compare dictionary with string array [\#2548](https://github.com/apache/arrow-rs/issues/2548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Compare dictionary with primitive array in `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2538](https://github.com/apache/arrow-rs/issues/2538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2535](https://github.com/apache/arrow-rs/issues/2535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- UnionBuilder Create Children With Capacity [\#2523](https://github.com/apache/arrow-rs/issues/2523) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Speed up `like_utf8_scalar` for `%pat%` [\#2519](https://github.com/apache/arrow-rs/issues/2519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Replace macro with TypedDictionaryArray in comparison kernels [\#2513](https://github.com/apache/arrow-rs/issues/2513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Use same codebase for boolean kernels [\#2507](https://github.com/apache/arrow-rs/issues/2507) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Use u8 for Decimal Precision and Scale [\#2496](https://github.com/apache/arrow-rs/issues/2496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Integrate skip row without pageIndex in SerializedPageReader in Fuzz Test [\#2475](https://github.com/apache/arrow-rs/issues/2475) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Avoid unecessary copies in Arrow IPC reader [\#2437](https://github.com/apache/arrow-rs/issues/2437) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- 
Add GenericColumnReader::skip\_records Missing OffsetIndex Fallback [\#2433](https://github.com/apache/arrow-rs/issues/2433) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Support Reading PageIndex with ParquetRecordBatchStream [\#2430](https://github.com/apache/arrow-rs/issues/2430) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Specialize FixedLenByteArrayReader for Parquet [\#2318](https://github.com/apache/arrow-rs/issues/2318) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Make JSON support Optional via Feature Flag [\#2300](https://github.com/apache/arrow-rs/issues/2300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Derive `PartialEq` and `Eq` for `parquet::arrow::ProjectionMask` [\#6329](https://github.com/apache/arrow-rs/issues/6329) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Allow converting empty `pyarrow.RecordBatch` to `arrow::RecordBatch` [\#6318](https://github.com/apache/arrow-rs/issues/6318) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet writer should not write any min/max data to ColumnIndex when all values are null [\#6315](https://github.com/apache/arrow-rs/issues/6315) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet: Add `union` method to `RowSelection` [\#6307](https://github.com/apache/arrow-rs/issues/6307) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support writing `UTC adjusted time` arrow array to parquet [\#6277](https://github.com/apache/arrow-rs/issues/6277) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- A better way to resize the buffer for the snappy encode/decode [\#6276](https://github.com/apache/arrow-rs/issues/6276) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- parquet\_derive: support reading selected columns from parquet file [\#6268](https://github.com/apache/arrow-rs/issues/6268) +- Tests for invalid parquet files 
[\#6261](https://github.com/apache/arrow-rs/issues/6261) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Implement `date_part` for `Duration` [\#6245](https://github.com/apache/arrow-rs/issues/6245) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Avoid unnecessary null buffer construction when converting arrays to a different type [\#6243](https://github.com/apache/arrow-rs/issues/6243) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `parquet_opendal` in related projects [\#6235](https://github.com/apache/arrow-rs/issues/6235) +- Look into optimizing reading FixedSizeBinary arrays from parquet [\#6219](https://github.com/apache/arrow-rs/issues/6219) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add benchmarks for `BYTE_STREAM_SPLIT` encoded Parquet `FIXED_LEN_BYTE_ARRAY` data [\#6203](https://github.com/apache/arrow-rs/issues/6203) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make it easy to write parquet to object\_store -- Implement `AsyncFileWriter` for a type that implements `obj_store::MultipartUpload` for `AsyncArrowWriter` [\#6200](https://github.com/apache/arrow-rs/issues/6200) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove test duplication in parquet statistics tests [\#6185](https://github.com/apache/arrow-rs/issues/6185) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support BinaryView Types in C Schema FFI [\#6170](https://github.com/apache/arrow-rs/issues/6170) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- speedup take\_byte\_view kernel [\#6167](https://github.com/apache/arrow-rs/issues/6167) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for `StringView` and `BinaryView` statistics in `StatisticsConverter` 
[\#6164](https://github.com/apache/arrow-rs/issues/6164) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support casting `BinaryView` --\> `Utf8` and `LargeUtf8` [\#6162](https://github.com/apache/arrow-rs/issues/6162) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `filter` kernel specially for `FixedSizeByteArray` [\#6153](https://github.com/apache/arrow-rs/issues/6153) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `LevelHistogram` throughout Parquet metadata [\#6134](https://github.com/apache/arrow-rs/issues/6134) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support DoPutStatementIngest from Arrow Flight SQL 17.0 [\#6124](https://github.com/apache/arrow-rs/issues/6124) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- ColumnMetaData should no longer be written inline with data [\#6115](https://github.com/apache/arrow-rs/issues/6115) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Implement date\_part for `Interval` [\#6113](https://github.com/apache/arrow-rs/issues/6113) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `Into<Arc<dyn Array>>` for `ArrayData` [\#6104](https://github.com/apache/arrow-rs/issues/6104) +- Allow flushing or non-buffered writes from `arrow::ipc::writer::StreamWriter` [\#6099](https://github.com/apache/arrow-rs/issues/6099) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Default block\_size for `StringViewArray` [\#6094](https://github.com/apache/arrow-rs/issues/6094) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove `Statistics::has_min_max_set` and `ValueStatistics::has_min_max_set` and use `Option` instead [\#6093](https://github.com/apache/arrow-rs/issues/6093) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Upgrade arrow-flight to tonic 0.12 
[\#6072](https://github.com/apache/arrow-rs/issues/6072) +- Improve speed of row converter by skipping utf8 checks [\#6058](https://github.com/apache/arrow-rs/issues/6058) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Extend support for BYTE\_STREAM\_SPLIT to FIXED\_LEN\_BYTE\_ARRAY, INT32, and INT64 primitive types [\#6048](https://github.com/apache/arrow-rs/issues/6048) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Release arrow-rs / parquet minor version `52.2.0` \(August 2024\) [\#5998](https://github.com/apache/arrow-rs/issues/5998) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Fixed bugs:** -- Casting timestamp array to string should not ignore timezone [\#2607](https://github.com/apache/arrow-rs/issues/2607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Ilike\_ut8\_scalar kernals have incorrect logic [\#2544](https://github.com/apache/arrow-rs/issues/2544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Always validate the array data when creating array in IPC reader [\#2541](https://github.com/apache/arrow-rs/issues/2541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Int96Converter Truncates Timestamps [\#2480](https://github.com/apache/arrow-rs/issues/2480) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Error Reading Page Index When Not Available [\#2434](https://github.com/apache/arrow-rs/issues/2434) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- `ParquetFileArrowReader::get_record_reader[_by_colum]` `batch_size` overallocates [\#2321](https://github.com/apache/arrow-rs/issues/2321) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Invalid `ColumnIndex` written in parquet [\#6310](https://github.com/apache/arrow-rs/issues/6310) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- comparison\_kernels benchmarks panic 
[\#6283](https://github.com/apache/arrow-rs/issues/6283) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Printing schema metadata includes possibly incorrect compression level [\#6270](https://github.com/apache/arrow-rs/issues/6270) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Don't panic when creating `Field` from `FFI_ArrowSchema` with no name [\#6251](https://github.com/apache/arrow-rs/issues/6251) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- lexsort\_to\_indices should not fallback to non-lexical sort if the datatype is not supported [\#6226](https://github.com/apache/arrow-rs/issues/6226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet Statistics null\_count does not distinguish between `0` and not specified [\#6215](https://github.com/apache/arrow-rs/issues/6215) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Using a take kernel on a dense union can result in reaching "unreachable" code [\#6206](https://github.com/apache/arrow-rs/issues/6206) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Adding sub day seconds to Date64 is ignored. 
[\#6198](https://github.com/apache/arrow-rs/issues/6198) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- mismatch between parquet type `is_optional` codes and comment [\#6191](https://github.com/apache/arrow-rs/issues/6191) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] **Documentation updates:** -- Document All Arrow Features in docs.rs [\#2633](https://github.com/apache/arrow-rs/issues/2633) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Minor: improve filter documentation [\#6317](https://github.com/apache/arrow-rs/pull/6317) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: Improve comments on GenericByteViewArray::bytes\_iter\(\), prefix\_iter\(\) and suffix\_iter\(\) [\#6306](https://github.com/apache/arrow-rs/pull/6306) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: improve `RowFilter` and `ArrowPredicate` docs [\#6301](https://github.com/apache/arrow-rs/pull/6301) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve documentation for `MutableArrayData` [\#6272](https://github.com/apache/arrow-rs/pull/6272) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add examples to `StringViewBuilder` and `BinaryViewBuilder` [\#6240](https://github.com/apache/arrow-rs/pull/6240) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- minor: enhance document for ParquetField [\#6239](https://github.com/apache/arrow-rs/pull/6239) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- Minor: Improve Type documentation [\#6224](https://github.com/apache/arrow-rs/pull/6224) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: Update `DateType::Date64` docs 
[\#6223](https://github.com/apache/arrow-rs/pull/6223) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add \(more\) Parquet Metadata Documentation [\#6184](https://github.com/apache/arrow-rs/pull/6184) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add additional documentation and examples to `ArrayAccessor` [\#6141](https://github.com/apache/arrow-rs/pull/6141) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: improve comments in temporal.rs tests [\#6140](https://github.com/apache/arrow-rs/pull/6140) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: Update release schedule in README [\#6125](https://github.com/apache/arrow-rs/pull/6125) ([alamb](https://github.com/alamb)) **Closed issues:** -- Add support for CAST from `Interval(DayTime)` to `Timestamp(Nanosecond, None)` [\#2606](https://github.com/apache/arrow-rs/issues/2606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Why do we check for null in TypedDictionaryArray value function [\#2564](https://github.com/apache/arrow-rs/issues/2564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add the `length` field for `Buffer` [\#2524](https://github.com/apache/arrow-rs/issues/2524) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Avoid large over allocate buffer in async reader [\#2512](https://github.com/apache/arrow-rs/issues/2512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Rewriting Decimal Builders using `const_generic`. 
[\#2390](https://github.com/apache/arrow-rs/issues/2390) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Rewrite Decimal Array using `const_generic` [\#2384](https://github.com/apache/arrow-rs/issues/2384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Simplify take octokit workflow [\#6279](https://github.com/apache/arrow-rs/issues/6279) +- Make the bearer token visible in FlightSqlServiceClient [\#6253](https://github.com/apache/arrow-rs/issues/6253) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Port `take` workflow to use `octokit` [\#6242](https://github.com/apache/arrow-rs/issues/6242) +- Remove `SchemaBuilder` dependency from `StructArray` constructors [\#6138](https://github.com/apache/arrow-rs/issues/6138) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Merged pull requests:** -- Add downcast macros \(\#2635\) [\#2636](https://github.com/apache/arrow-rs/pull/2636) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Document all arrow features in docs.rs \(\#2633\) [\#2634](https://github.com/apache/arrow-rs/pull/2634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Document dyn\_cmp\_dict [\#2624](https://github.com/apache/arrow-rs/pull/2624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Support comparison between DictionaryArray and BooleanArray [\#2618](https://github.com/apache/arrow-rs/pull/2618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Cast timestamp array to string array with timezone [\#2608](https://github.com/apache/arrow-rs/pull/2608) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Support empty projection in CSV and JSON readers 
[\#2604](https://github.com/apache/arrow-rs/pull/2604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Make JSON support optional via a feature flag \(\#2300\) [\#2601](https://github.com/apache/arrow-rs/pull/2601) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Support SQL-compliant NaN ordering for DictionaryArray and non-DictionaryArray [\#2600](https://github.com/apache/arrow-rs/pull/2600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Split out integration test plumbing \(\#2594\) \(\#2300\) [\#2598](https://github.com/apache/arrow-rs/pull/2598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Refactor Binary Builder and String Builder Constructors [\#2592](https://github.com/apache/arrow-rs/pull/2592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Dictionary like scalar kernels [\#2591](https://github.com/apache/arrow-rs/pull/2591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Validate dictionary key in TypedDictionaryArray \(\#2578\) [\#2589](https://github.com/apache/arrow-rs/pull/2589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2585](https://github.com/apache/arrow-rs/pull/2585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Code cleanup of array value functions [\#2583](https://github.com/apache/arrow-rs/pull/2583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Allow overriding of do\_get & 
export useful macro [\#2582](https://github.com/apache/arrow-rs/pull/2582) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) -- MINOR: Upgrade to pyo3 0.17 [\#2576](https://github.com/apache/arrow-rs/pull/2576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) -- Support SQL-compliant NaN behavior on eq\_dyn, neq\_dyn, lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn [\#2570](https://github.com/apache/arrow-rs/pull/2570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Add sum\_dyn to calculate sum for dictionary array [\#2566](https://github.com/apache/arrow-rs/pull/2566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- struct UnionBuilder will create child buffers with capacity [\#2560](https://github.com/apache/arrow-rs/pull/2560) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kastolars](https://github.com/kastolars)) -- Don't panic on RleValueEncoder::flush\_buffer if empty \(\#2558\) [\#2559](https://github.com/apache/arrow-rs/pull/2559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Add the `length` field for Buffer and use more `Buffer` in IPC reader to avoid memory copy. 
[\#2557](https://github.com/apache/arrow-rs/pull/2557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([HaoYang670](https://github.com/HaoYang670)) -- Add test for float nan comparison [\#2555](https://github.com/apache/arrow-rs/pull/2555) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Compare dictionary array with string array [\#2549](https://github.com/apache/arrow-rs/pull/2549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Always validate the array data \(except the `Decimal`\) when creating array in IPC reader [\#2547](https://github.com/apache/arrow-rs/pull/2547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- MINOR: Fix test\_row\_type\_validation test [\#2546](https://github.com/apache/arrow-rs/pull/2546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Fix ilike\_utf8\_scalar kernals [\#2545](https://github.com/apache/arrow-rs/pull/2545) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- fix typo [\#2540](https://github.com/apache/arrow-rs/pull/2540) ([00Masato](https://github.com/00Masato)) -- Compare dictionary array and primitive array in lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn kernels [\#2539](https://github.com/apache/arrow-rs/pull/2539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- \[MINOR\]Avoid large over allocate buffer in async reader [\#2537](https://github.com/apache/arrow-rs/pull/2537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) -- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2533](https://github.com/apache/arrow-rs/pull/2533) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Add iterator for FixedSizeBinaryArray [\#2531](https://github.com/apache/arrow-rs/pull/2531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- add bench: decimal with byte array and fixed length byte array [\#2529](https://github.com/apache/arrow-rs/pull/2529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) -- Add FixedLengthByteArrayReader Remove ComplexObjectArrayReader [\#2528](https://github.com/apache/arrow-rs/pull/2528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Split out byte array decoders \(\#2318\) [\#2527](https://github.com/apache/arrow-rs/pull/2527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Use offset index in ParquetRecordBatchStream [\#2526](https://github.com/apache/arrow-rs/pull/2526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) -- Clean the `create_array` in IPC reader. 
[\#2525](https://github.com/apache/arrow-rs/pull/2525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Remove DecimalByteArrayConvert \(\#2480\) [\#2522](https://github.com/apache/arrow-rs/pull/2522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Improve performance of `%pat%` \(\>3x speedup\) [\#2521](https://github.com/apache/arrow-rs/pull/2521) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- remove len field from MapBuilder [\#2520](https://github.com/apache/arrow-rs/pull/2520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Replace macro with TypedDictionaryArray in comparison kernels [\#2514](https://github.com/apache/arrow-rs/pull/2514) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Avoid large over allocate buffer in sync reader [\#2511](https://github.com/apache/arrow-rs/pull/2511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) -- Avoid useless memory copies in IPC reader. 
[\#2510](https://github.com/apache/arrow-rs/pull/2510) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Refactor boolean kernels to use same codebase [\#2508](https://github.com/apache/arrow-rs/pull/2508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Remove Int96Converter \(\#2480\) [\#2481](https://github.com/apache/arrow-rs/pull/2481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Derive PartialEq and Eq for parquet::arrow::ProjectionMask [\#6330](https://github.com/apache/arrow-rs/pull/6330) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Support zero column `RecordBatch`es in pyarrow integration \(use RecordBatchOptions when converting a pyarrow RecordBatch\) [\#6320](https://github.com/apache/arrow-rs/pull/6320) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Michael-J-Ward](https://github.com/Michael-J-Ward)) +- Fix writing of invalid Parquet ColumnIndex when row group contains null pages [\#6319](https://github.com/apache/arrow-rs/pull/6319) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adriangb](https://github.com/adriangb)) +- Pass empty vectors as min/max for all null pages when building ColumnIndex [\#6316](https://github.com/apache/arrow-rs/pull/6316) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update tonic-build requirement from =0.12.0 to =0.12.2 [\#6314](https://github.com/apache/arrow-rs/pull/6314) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Parquet: add `union` method to `RowSelection` [\#6308](https://github.com/apache/arrow-rs/pull/6308) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sdd](https://github.com/sdd)) +- Specialize filter for structs and sparse unions [\#6304](https://github.com/apache/arrow-rs/pull/6304) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gstvg](https://github.com/gstvg)) +- Err on `try_from_le_slice` [\#6295](https://github.com/apache/arrow-rs/pull/6295) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([samuelcolvin](https://github.com/samuelcolvin)) +- fix reference in doctest to size\_of which is not imported by default [\#6286](https://github.com/apache/arrow-rs/pull/6286) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rtyler](https://github.com/rtyler)) +- Support writing UTC adjusted time arrays to parquet [\#6278](https://github.com/apache/arrow-rs/pull/6278) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([aykut-bozkurt](https://github.com/aykut-bozkurt)) +- Minor: `pub use ByteView` in arrow and improve documentation [\#6275](https://github.com/apache/arrow-rs/pull/6275) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix accessing name from ffi schema [\#6273](https://github.com/apache/arrow-rs/pull/6273) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) +- Do not print compression level in schema printer [\#6271](https://github.com/apache/arrow-rs/pull/6271) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ttencate](https://github.com/ttencate)) +- ci: use octokit to add assignee [\#6267](https://github.com/apache/arrow-rs/pull/6267) ([dsgibbons](https://github.com/dsgibbons)) +- Add tests for bad parquet files [\#6262](https://github.com/apache/arrow-rs/pull/6262) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add `Statistics::distinct_count_opt` and deprecate `Statistics::distinct_count` 
[\#6259](https://github.com/apache/arrow-rs/pull/6259) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: move `FallibleRequestStream` and `FallibleTonicResponseStream` to a module [\#6258](https://github.com/apache/arrow-rs/pull/6258) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Make the bearer token visible in FlightSqlServiceClient [\#6254](https://github.com/apache/arrow-rs/pull/6254) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([ccciudatu](https://github.com/ccciudatu)) +- Use `unary()` for array conversion in Parquet array readers, speed up `Decimal128`, `Decimal256` and `Float16` [\#6252](https://github.com/apache/arrow-rs/pull/6252) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([etseidl](https://github.com/etseidl)) +- Update tower requirement from 0.4.13 to 0.5.0 [\#6250](https://github.com/apache/arrow-rs/pull/6250) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Implement date\_part for durations [\#6246](https://github.com/apache/arrow-rs/pull/6246) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nrc](https://github.com/nrc)) +- Remove unnecessary null buffer construction when converting arrays to a different type [\#6244](https://github.com/apache/arrow-rs/pull/6244) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([etseidl](https://github.com/etseidl)) +- Implement PartialEq for GenericByteViewArray [\#6241](https://github.com/apache/arrow-rs/pull/6241) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: Remove non standard footer from LICENSE.txt / reference to Apache Aurora [\#6237](https://github.com/apache/arrow-rs/pull/6237) ([alamb](https://github.com/alamb)) +- docs: Add parquet\_opendal in related projects [\#6236](https://github.com/apache/arrow-rs/pull/6236) ([Xuanwo](https://github.com/Xuanwo)) +- Avoid infinite loop in bad parquet by checking the number of rep levels [\#6232](https://github.com/apache/arrow-rs/pull/6232) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jp0317](https://github.com/jp0317)) +- Specialize Prefix/Suffix Match for `Like/ILike` between Array and Scalar for StringViewArray [\#6231](https://github.com/apache/arrow-rs/pull/6231) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xinlifoobar](https://github.com/xinlifoobar)) +- fix: lexsort\_to\_indices should not fallback to non-lexical sort if the datatype is not supported [\#6225](https://github.com/apache/arrow-rs/pull/6225) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Modest improvement to FixedLenByteArray BYTE\_STREAM\_SPLIT arrow decoder [\#6222](https://github.com/apache/arrow-rs/pull/6222) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Improve performance of `FixedLengthBinary` decoding [\#6220](https://github.com/apache/arrow-rs/pull/6220) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update documentation for Parquet BYTE\_STREAM\_SPLIT encoding [\#6212](https://github.com/apache/arrow-rs/pull/6212) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Improve interval parsing [\#6211](https://github.com/apache/arrow-rs/pull/6211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([samuelcolvin](https://github.com/samuelcolvin)) +- minor: Suggest take on interleave docs [\#6210](https://github.com/apache/arrow-rs/pull/6210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gstvg](https://github.com/gstvg)) +- fix: Correctly handle take on dense union of a single selected type [\#6209](https://github.com/apache/arrow-rs/pull/6209) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gstvg](https://github.com/gstvg)) +- Add time dictionary coercions [\#6208](https://github.com/apache/arrow-rs/pull/6208) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) +- fix\(arrow\): restrict the range of temporal values produced via `data_gen` [\#6205](https://github.com/apache/arrow-rs/pull/6205) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kyle-mccarthy](https://github.com/kyle-mccarthy)) +- Add benchmarks for `BYTE_STREAM_SPLIT` encoded Parquet `FIXED_LEN_BYTE_ARRAY` data [\#6204](https://github.com/apache/arrow-rs/pull/6204) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Move `ParquetMetadataWriter` to its own module, update documentation [\#6202](https://github.com/apache/arrow-rs/pull/6202) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add `ThriftMetadataWriter` for writing Parquet metadata [\#6197](https://github.com/apache/arrow-rs/pull/6197) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adriangb](https://github.com/adriangb)) +- Update zstd-sys requirement from \>=2.0.0, \<2.0.13 to \>=2.0.0, \<2.0.14 [\#6196](https://github.com/apache/arrow-rs/pull/6196) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix parquet type `is_optional` comments [\#6192](https://github.com/apache/arrow-rs/pull/6192) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
([jp0317](https://github.com/jp0317)) +- Remove duplicated statistics tests in parquet [\#6190](https://github.com/apache/arrow-rs/pull/6190) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Kev1n8](https://github.com/Kev1n8)) +- Benchmarks for `bool_and` [\#6189](https://github.com/apache/arrow-rs/pull/6189) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([simonvandel](https://github.com/simonvandel)) +- Fix typo in documentation of Float64Array [\#6188](https://github.com/apache/arrow-rs/pull/6188) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mesejo](https://github.com/mesejo)) +- Make it clear that `StatisticsConverter` can not panic [\#6187](https://github.com/apache/arrow-rs/pull/6187) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- add filter benchmark for `FixedSizeBinaryArray` [\#6186](https://github.com/apache/arrow-rs/pull/6186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chloro-pn](https://github.com/chloro-pn)) +- Update sysinfo requirement from 0.30.12 to 0.31.2 [\#6182](https://github.com/apache/arrow-rs/pull/6182) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add support for `StringView` and `BinaryView` statistics in `StatisticsConverter` [\#6181](https://github.com/apache/arrow-rs/pull/6181) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Kev1n8](https://github.com/Kev1n8)) +- Support casting between BinaryView \<--\> Utf8 and LargeUtf8 [\#6180](https://github.com/apache/arrow-rs/pull/6180) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xinlifoobar](https://github.com/xinlifoobar)) +- Implement specialized filter kernel for `FixedSizeByteArray` [\#6178](https://github.com/apache/arrow-rs/pull/6178) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chloro-pn](https://github.com/chloro-pn)) +- Support `StringView` and 
`BinaryView` in CDataInterface [\#6171](https://github.com/apache/arrow-rs/pull/6171) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([a10y](https://github.com/a10y)) +- Optimize `take` kernel for `BinaryViewArray` and `StringViewArray` [\#6168](https://github.com/apache/arrow-rs/pull/6168) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([a10y](https://github.com/a10y)) +- Support Parquet `BYTE_STREAM_SPLIT` for INT32, INT64, and FIXED\_LEN\_BYTE\_ARRAY primitive types [\#6159](https://github.com/apache/arrow-rs/pull/6159) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix comparison kernel benchmarks [\#6147](https://github.com/apache/arrow-rs/pull/6147) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samuelcolvin](https://github.com/samuelcolvin)) +- improve `LIKE` regex performance up to 12x [\#6145](https://github.com/apache/arrow-rs/pull/6145) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samuelcolvin](https://github.com/samuelcolvin)) +- Optimize `min_boolean` and `bool_and` [\#6144](https://github.com/apache/arrow-rs/pull/6144) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([simonvandel](https://github.com/simonvandel)) +- Reduce bounds check in `RowIter`, add `unsafe Rows::row_unchecked` [\#6142](https://github.com/apache/arrow-rs/pull/6142) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Minor: Simplify `StructArray` constructors [\#6139](https://github.com/apache/arrow-rs/pull/6139) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Rafferty97](https://github.com/Rafferty97)) +- Implement exponential block size growing strategy for `StringViewBuilder` [\#6136](https://github.com/apache/arrow-rs/pull/6136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Use `LevelHistogram` in `PageIndex` 
[\#6135](https://github.com/apache/arrow-rs/pull/6135) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add ArrowError::ArithmeticError [\#6130](https://github.com/apache/arrow-rs/pull/6130) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- Improve `LIKE` performance for "contains" style queries [\#6128](https://github.com/apache/arrow-rs/pull/6128) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samuelcolvin](https://github.com/samuelcolvin)) +- Add `BooleanArray::new_from_packed` and `BooleanArray::new_from_u8` [\#6127](https://github.com/apache/arrow-rs/pull/6127) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chloro-pn](https://github.com/chloro-pn)) +- improvements to `(i)starts_with` and `(i)ends_with` performance [\#6118](https://github.com/apache/arrow-rs/pull/6118) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samuelcolvin](https://github.com/samuelcolvin)) +- Fix Clippy for the Rust 1.80 release [\#6116](https://github.com/apache/arrow-rs/pull/6116) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- added a flush method to IPC writers [\#6108](https://github.com/apache/arrow-rs/pull/6108) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([V0ldek](https://github.com/V0ldek)) +- Add support for level histograms added in PARQUET-2261 to `ParquetMetaData` [\#6105](https://github.com/apache/arrow-rs/pull/6105) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Implement date\_part for intervals [\#6071](https://github.com/apache/arrow-rs/pull/6071) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nrc](https://github.com/nrc)) +- feat\(parquet\): 
Implement AsyncFileWriter for `object_store::buffered::BufWriter` [\#6013](https://github.com/apache/arrow-rs/pull/6013) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Xuanwo](https://github.com/Xuanwo)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 67121f6cd5a3..e0adc18a9a60 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -25,7 +25,41 @@ We welcome and encourage contributions of all kinds, such as: 2. Documentation improvements 3. Code (PR or PR Review) -In addition to submitting new PRs, we have a healthy tradition of community members helping review each other's PRs. Doing so is a great way to help the community as well as get more familiar with Rust and the relevant codebases. +In addition to submitting new PRs, we have a healthy tradition of community +members helping review each other's PRs. Doing so is a great way to help the +community as well as get more familiar with Rust and the relevant codebases. + +## Finding and Creating Issues to Work On + +You can find a curated [good-first-issue] list to help you get started. + +Arrow-rs is an open contribution project, and thus there is no particular +project imposed deadline for completing any issue or any restriction on who can +work on an issue, nor how many people can work on an issue at the same time. + +Contributors drive the project forward based on their own priorities and +interests and thus you are free to work on any issue that interests you. + +If someone is already working on an issue that you want or need but hasn't +been able to finish it yet, you should feel free to work on it as well. In +general it is both polite and will help avoid unnecessary duplication of work if +you leave a note on an issue when you start working on it. 
+ +If you want to work on an issue which is not already assigned to someone else +and there is no comment indicating that someone is already working on that +issue then you can assign the issue to yourself by submitting a single word +comment `take`. This will assign the issue to yourself. However, if you are +unable to make progress you should unassign the issue by using the `unassign me` +link at the top of the issue page (and ask for help if you are stuck) so that +someone else can get involved in the work. + +If you plan to work on a new feature that doesn't have an existing ticket, it is +a good idea to open a ticket to discuss the feature. Advanced discussion often +helps avoid wasted effort by determining early if the feature is a good fit for +Arrow-rs before too much time is invested. It also often helps to discuss your +ideas with the community to get feedback on implementation. + +[good-first-issue]: https://github.com/apache/arrow-rs/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22 ## Developer's guide to Arrow Rust @@ -92,19 +126,31 @@ export ARROW_TEST_DATA=$(cd ../testing/data; pwd) From here on, this is a pure Rust project and `cargo` can be used to run tests, benchmarks, docs and examples as usual. -### Running the tests +## Running the tests Run tests using the Rust standard `cargo test` command: ```bash -# run all tests. +# run all unit and integration tests cargo test - -# run only tests for the arrow crate +# run tests for the arrow crate cargo test -p arrow ``` +For some changes, you may want to run additional tests. You can find up-to-date information on the current CI tests in [.github/workflows](https://github.com/apache/arrow-rs/tree/master/.github/workflows).
Here are some examples of additional tests you may want to run: + +```bash +# run tests for the parquet crate +cargo test -p parquet + +# run arrow tests with all features enabled +cargo test -p arrow --all-features + +# run the doc tests +cargo test --doc +``` + ## Code Formatting Our CI uses `rustfmt` to check code formatting. Before submitting a @@ -114,14 +160,33 @@ PR be sure to run the following and check for lint issues: cargo +stable fmt --all -- --check ``` +Note that currently the above will not check all source files in the parquet crate. To check all +parquet files run the following from the top-level `arrow-rs` directory: + +```bash +cargo fmt -p parquet -- --check --config skip_children=true `find . -name "*.rs" \! -name format.rs` +``` + +## Breaking Changes + +Our [release schedule] allows breaking API changes only in major releases. +This means that if your PR has a breaking API change, it should be marked as +`api-change` and it will not be merged until development opens for the next +major release. See [this ticket] for details. + +[release schedule]: README.md#release-versioning-and-schedule +[this ticket]: https://github.com/apache/arrow-rs/issues/5907 + ## Clippy Lints -We recommend using `clippy` for checking lints during development. While we do not yet enforce `clippy` checks, we recommend not introducing new `clippy` errors or warnings. +We use `clippy` for checking lints during development, and CI runs `clippy` checks. -Run the following to check for clippy lints. +Run the following to check for `clippy` lints: ```bash -cargo clippy +# run clippy with default settings +cargo clippy --workspace --all-targets --all-features -- -D warnings + ``` If you use Visual Studio Code with the `rust-analyzer` plugin, you can enable `clippy` to run each time you save a file. See https://users.rust-lang.org/t/how-to-use-clippy-in-vs-code-with-rust-analyzer/41881. 
@@ -134,6 +199,33 @@ Search for `allow(clippy::` in the codebase to identify lints that are ignored/a - If you have several lints on a function or module, you may disable the lint on the function or module. - If a lint is pervasive across multiple modules, you may disable it at the crate level. +## Running Benchmarks + +Running benchmarks is a good way to test the performance of a change. As benchmarks usually take a long time to run, we recommend running targeted tests instead of the full suite. + +```bash +# run all benchmarks +cargo bench + +# run arrow benchmarks +cargo bench -p arrow + +# run benchmark for the parse_time function within the arrow-cast crate +cargo bench -p arrow-cast --bench parse_time +``` + +To set the baseline for your benchmarks, use the --save-baseline flag: + +```bash +git checkout master + +cargo bench --bench parse_time -- --save-baseline master + +git checkout feature + +cargo bench --bench parse_time -- --baseline master +``` + ## Git Pre-Commit Hook We can use [git pre-commit hook](https://git-scm.com/book/en/v2/Customizing-Git-Git-Hooks) to automate various kinds of git pre-commit checking/formatting. @@ -150,7 +242,7 @@ If the file already exists, to avoid mistakenly **overriding**, you MAY have to the link source or file content. Else if not exist, let's safely soft link [pre-commit.sh](pre-commit.sh) as file `.git/hooks/pre-commit`: ```bash -ln -s ../../rust/pre-commit.sh .git/hooks/pre-commit +ln -s ../../pre-commit.sh .git/hooks/pre-commit ``` If sometimes you want to commit without checking, just run `git commit` with `--no-verify`: diff --git a/Cargo.toml b/Cargo.toml index 9bf55c0f2360..3b274d583437 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,15 +16,32 @@ # under the License.
[workspace] + members = [ - "arrow", - "parquet", - "parquet_derive", - "parquet_derive_test", - "arrow-flight", - "integration-testing", - "object_store", + "arrow", + "arrow-arith", + "arrow-array", + "arrow-avro", + "arrow-buffer", + "arrow-cast", + "arrow-csv", + "arrow-data", + "arrow-flight", + "arrow-flight/gen", + "arrow-integration-test", + "arrow-integration-testing", + "arrow-ipc", + "arrow-json", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "parquet", + "parquet_derive", + "parquet_derive_test", ] + # Enable the version 2 feature resolver, which avoids unifying features for targets that are not being built # # Critically this prevents dev-dependencies from enabling features even when not building a target that @@ -35,7 +52,45 @@ members = [ # resolver = "2" -# this package is excluded because it requires different compilation flags, thereby significantly changing -# how it is compiled within the workspace, causing the whole workspace to be compiled from scratch -# this way, this is a stand-alone package that compiles independently of the others. -exclude = ["arrow-pyarrow-integration-testing"] +exclude = [ + # arrow-pyarrow-integration-testing is excluded because it requires different compilation flags, thereby + # significantly changing how it is compiled within the workspace, causing the whole workspace to be compiled from + # scratch this way, this is a stand-alone package that compiles independently of the others. 
+ "arrow-pyarrow-integration-testing", + # object_store is excluded because it follows a separate release cycle from the other arrow crates + "object_store" +] + +[workspace.package] +version = "53.0.0" +homepage = "https://github.com/apache/arrow-rs" +repository = "https://github.com/apache/arrow-rs" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = ["arrow"] +include = [ + "benches/*.rs", + "src/**/*.rs", + "Cargo.toml", +] +edition = "2021" +rust-version = "1.62" + +[workspace.dependencies] +arrow = { version = "53.0.0", path = "./arrow", default-features = false } +arrow-arith = { version = "53.0.0", path = "./arrow-arith" } +arrow-array = { version = "53.0.0", path = "./arrow-array" } +arrow-buffer = { version = "53.0.0", path = "./arrow-buffer" } +arrow-cast = { version = "53.0.0", path = "./arrow-cast" } +arrow-csv = { version = "53.0.0", path = "./arrow-csv" } +arrow-data = { version = "53.0.0", path = "./arrow-data" } +arrow-ipc = { version = "53.0.0", path = "./arrow-ipc" } +arrow-json = { version = "53.0.0", path = "./arrow-json" } +arrow-ord = { version = "53.0.0", path = "./arrow-ord" } +arrow-row = { version = "53.0.0", path = "./arrow-row" } +arrow-schema = { version = "53.0.0", path = "./arrow-schema" } +arrow-select = { version = "53.0.0", path = "./arrow-select" } +arrow-string = { version = "53.0.0", path = "./arrow-string" } +parquet = { version = "53.0.0", path = "./parquet", default-features = false } + +chrono = { version = "0.4.34", default-features = false, features = ["clock"] } diff --git a/LICENSE.txt b/LICENSE.txt index d74c6b599d2a..d64569567334 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -200,13 +200,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - - -This project includes code from Apache Aurora. 
- -* dev/release/{release,changelog,release-candidate} are based on the scripts from - Apache Aurora - -Copyright: 2016 The Apache Software Foundation. -Home page: https://aurora.apache.org/ -License: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/README.md b/README.md index 55bdad6cb55c..a525490843cd 100644 --- a/README.md +++ b/README.md @@ -17,39 +17,107 @@ under the License. --> -# Native Rust implementation of Apache Arrow and Parquet +# Native Rust implementation of Apache Arrow and Apache Parquet [![Coverage Status](https://codecov.io/gh/apache/arrow-rs/rust/branch/master/graph/badge.svg)](https://codecov.io/gh/apache/arrow-rs?branch=master) -Welcome to the implementation of Arrow, the popular in-memory columnar format, in [Rust][rust]. +Welcome to the [Rust][rust] implementation of [Apache Arrow], the popular in-memory columnar format. This repo contains the following main components: -| Crate | Description | Documentation | -| ------------ | ------------------------------------------------------------------------- | ------------------------------ | -| arrow | Core functionality (memory layout, arrays, low level computations) | [(README)][arrow-readme] | -| parquet | Support for Parquet columnar file format | [(README)][parquet-readme] | -| arrow-flight | Support for Arrow-Flight IPC protocol | [(README)][flight-readme] | -| object-store | Support for object store interactions (aws, azure, gcp, local, in-memory) | [(README)][objectstore-readme] | +| Crate | Description | Latest API Docs | README | +| ------------------ | ---------------------------------------------------------------------------- | ------------------------------------------------ | --------------------------------- | +| [`arrow`] | Core functionality (memory layout, arrays, low level computations) | [docs.rs](https://docs.rs/arrow/latest) | [(README)][arrow-readme] | +| [`arrow-flight`] | Support for Arrow-Flight IPC protocol | [docs.rs](https://docs.rs/arrow-flight/latest) | 
[(README)][flight-readme] | +| [`object-store`] | Support for object store interactions (aws, azure, gcp, local, in-memory) | [docs.rs](https://docs.rs/object_store/latest) | [(README)][objectstore-readme] | +| [`parquet`] | Support for Parquet columnar file format | [docs.rs](https://docs.rs/parquet/latest) | [(README)][parquet-readme] | +| [`parquet_derive`] | A crate for deriving RecordWriter/RecordReader for arbitrary, simple structs | [docs.rs](https://docs.rs/parquet-derive/latest) | [(README)][parquet-derive-readme] | -There are two related crates in a different repository +The API documentation for the current development version of this repo can be found [here](https://arrow.apache.org/rust). -| Crate | Description | Documentation | -| ---------- | --------------------------------------- | ----------------------------- | -| DataFusion | In-memory query engine with SQL support | [(README)][datafusion-readme] | -| Ballista | Distributed query execution | [(README)][ballista-readme] | +[apache arrow]: https://arrow.apache.org/ +[`arrow`]: https://crates.io/crates/arrow +[`parquet`]: https://crates.io/crates/parquet +[`parquet_derive`]: https://crates.io/crates/parquet-derive +[`arrow-flight`]: https://crates.io/crates/arrow-flight +[`object-store`]: https://crates.io/crates/object-store -Collectively, these crates support a vast array of functionality for analytic computations in Rust. +## Release Versioning and Schedule -For example, you can write an SQL query or a `DataFrame` (using the `datafusion` crate), run it against a parquet file (using the `parquet` crate), evaluate it in-memory using Arrow's columnar format (using the `arrow` crate), and send to another process (using the `arrow-flight` crate). +### `arrow` and `parquet` crates -Generally speaking, the `arrow` crate offers functionality for using Arrow arrays, and `datafusion` offers most operations typically found in SQL, including `join`s and window functions.
+The Arrow Rust project releases approximately monthly and follows [Semantic +Versioning]. + +Due to available maintainer and testing bandwidth, [`arrow`] crates ([`arrow`], +[`arrow-flight`], etc.) are released on the same schedule with the same versions +as the [`parquet`] and [`parquet_derive`] crates. + +These crates are released every month. We release new major versions (with potentially +breaking API changes) at most once a quarter, and release incremental minor +versions in the intervening months. See [this ticket] for more details. + +To keep our maintenance burden down, we do regularly scheduled releases (major +and minor) from the `master` branch. How we handle PRs with breaking API changes +is described in the [contributing] guide. + +[contributing]: CONTRIBUTING.md#breaking-changes + +Planned Release Schedule + +| Approximate Date | Version | Notes | +| ---------------- | -------- | --------------------------------------- | +| Sep 2024 | `53.0.0` | Major, potentially breaking API changes | +| Oct 2024 | `53.1.0` | Minor, NO breaking API changes | +| Nov 2024 | `53.2.0` | Minor, NO breaking API changes | +| Dec 2024 | `54.0.0` | Major, potentially breaking API changes | + +[this ticket]: https://github.com/apache/arrow-rs/issues/5368 +[semantic versioning]: https://semver.org/ + +### `object_store` crate + +The [`object_store`] crate is released independently of the `arrow` and +`parquet` crates and follows [Semantic Versioning]. We aim to release new +versions approximately every 2 months.
+ +[`object_store`]: https://crates.io/crates/object_store + +## Related Projects + +There are several related crates in different repositories + +| Crate | Description | Documentation | +| ------------------------ | ------------------------------------------- | --------------------------------------- | +| [`datafusion`] | In-memory query engine with SQL support | [(README)][datafusion-readme] | +| [`ballista`] | Distributed query execution | [(README)][ballista-readme] | +| [`object_store_opendal`] | Use [`opendal`] as [`object_store`] backend | [(README)][object_store_opendal-readme] | +| [`parquet_opendal`] | Use [`opendal`] for [`parquet`] Arrow IO | [(README)][parquet_opendal-readme] | + +[`datafusion`]: https://crates.io/crates/datafusion +[`ballista`]: https://crates.io/crates/ballista +[`object_store_opendal`]: https://crates.io/crates/object_store_opendal +[`opendal`]: https://crates.io/crates/opendal +[object_store_opendal-readme]: https://github.com/apache/opendal/blob/main/integrations/object_store/README.md +[`parquet_opendal`]: https://crates.io/crates/parquet_opendal +[parquet_opendal-readme]: https://github.com/apache/opendal/blob/main/integrations/parquet/README.md + +Collectively, these crates support a wider array of functionality for analytic computations in Rust. + +For example, you can write SQL queries or a `DataFrame` (using the +[`datafusion`] crate) to read a parquet file (using the [`parquet`] crate), +evaluate it in-memory using Arrow's columnar format (using the [`arrow`] crate), +and send to another process (using the [`arrow-flight`] crate). + +Generally speaking, the [`arrow`] crate offers functionality for using Arrow +arrays, and [`datafusion`] offers most operations typically found in SQL, +including `join`s and window functions. You can find more details about each crate in their respective READMEs. ## Arrow Rust Community -The `dev@arrow.apache.org` mailing list serves as the core communication channel for the Arrow community. 
Instructions for signing up and links to the archives can be found at the [Arrow Community](https://arrow.apache.org/community/) page. All major announcements and communications happen there. +The `dev@arrow.apache.org` mailing list serves as the core communication channel for the Arrow community. Instructions for signing up and links to the archives can be found on the [Arrow Community](https://arrow.apache.org/community/) page. All major announcements and communications happen there. The Rust Arrow community also uses the official [ASF Slack](https://s.apache.org/slack-invite) for informal discussions and coordination. This is a great place to meet other contributors and get guidance on where to contribute. Join us in the `#arrow-rust` channel and feel free to ask for an invite via: @@ -58,8 +126,8 @@ a great place to meet other contributors and get guidance on where to contribute 2. the [GitHub Discussions][discussions] 3. the [Discord channel](https://discord.gg/YAb2TdazKQ) -Unlike other parts of the Arrow ecosystem, the Rust implementation uses [GitHub issues][issues] as the system of record for new features -and bug fixes and this plays a critical role in the release process. +The Rust implementation uses [GitHub issues][issues] as the system of record for new features and bug fixes and +this plays a critical role in the release process. For design discussions we generally collaborate on Google documents and file a GitHub issue linking to the document. @@ -70,8 +138,9 @@ There is more information in the [contributing] guide. 
[contributing]: CONTRIBUTING.md [parquet-readme]: parquet/README.md [flight-readme]: arrow-flight/README.md -[datafusion-readme]: https://github.com/apache/arrow-datafusion/blob/master/README.md -[ballista-readme]: https://github.com/apache/arrow-ballista/blob/master/README.md -[objectstore-readme]: https://github.com/apache/arrow-rs/blob/master/object_store/README.md +[datafusion-readme]: https://github.com/apache/datafusion/blob/main/README.md +[ballista-readme]: https://github.com/apache/datafusion-ballista/blob/main/README.md +[objectstore-readme]: object_store/README.md +[parquet-derive-readme]: parquet_derive/README.md [issues]: https://github.com/apache/arrow-rs/issues [discussions]: https://github.com/apache/arrow-rs/discussions diff --git a/arrow-arith/Cargo.toml b/arrow-arith/Cargo.toml new file mode 100644 index 000000000000..d2ee0b9e2c72 --- /dev/null +++ b/arrow-arith/Cargo.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "arrow-arith" +version = { workspace = true } +description = "Arrow arithmetic kernels" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_arith" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +chrono = { workspace = true } +half = { version = "2.1", default-features = false } +num = { version = "0.4", default-features = false, features = ["std"] } + +[dev-dependencies] diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs new file mode 100644 index 000000000000..a4915f5893df --- /dev/null +++ b/arrow-arith/src/aggregate.rs @@ -0,0 +1,1704 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines aggregations over Arrow arrays. 
+ +use arrow_array::cast::*; +use arrow_array::iterator::ArrayIter; +use arrow_array::*; +use arrow_buffer::{ArrowNativeType, NullBuffer}; +use arrow_data::bit_iterator::try_for_each_valid_idx; +use arrow_schema::*; +use std::borrow::BorrowMut; +use std::cmp::{self, Ordering}; +use std::ops::{BitAnd, BitOr, BitXor}; +use types::ByteViewType; + +/// An accumulator for primitive numeric values. +trait NumericAccumulator: Copy + Default { + /// Accumulate a non-null value. + fn accumulate(&mut self, value: T); + /// Accumulate a nullable values. + /// If `valid` is false the `value` should not affect the accumulator state. + fn accumulate_nullable(&mut self, value: T, valid: bool); + /// Merge another accumulator into this accumulator + fn merge(&mut self, other: Self); + /// Return the aggregated value. + fn finish(&mut self) -> T; +} + +/// Helper for branchlessly selecting either `a` or `b` based on the boolean `m`. +/// After verifying the generated assembly this can be a simple `if`. +#[inline(always)] +fn select(m: bool, a: T, b: T) -> T { + if m { + a + } else { + b + } +} + +#[derive(Clone, Copy)] +struct SumAccumulator { + sum: T, +} + +impl Default for SumAccumulator { + fn default() -> Self { + Self { sum: T::ZERO } + } +} + +impl NumericAccumulator for SumAccumulator { + fn accumulate(&mut self, value: T) { + self.sum = self.sum.add_wrapping(value); + } + + fn accumulate_nullable(&mut self, value: T, valid: bool) { + let sum = self.sum; + self.sum = select(valid, sum.add_wrapping(value), sum) + } + + fn merge(&mut self, other: Self) { + self.sum = self.sum.add_wrapping(other.sum); + } + + fn finish(&mut self) -> T { + self.sum + } +} + +#[derive(Clone, Copy)] +struct MinAccumulator { + min: T, +} + +impl Default for MinAccumulator { + fn default() -> Self { + Self { + min: T::MAX_TOTAL_ORDER, + } + } +} + +impl NumericAccumulator for MinAccumulator { + fn accumulate(&mut self, value: T) { + let min = self.min; + self.min = select(value.is_lt(min), value, 
min); + } + + fn accumulate_nullable(&mut self, value: T, valid: bool) { + let min = self.min; + let is_lt = valid & value.is_lt(min); + self.min = select(is_lt, value, min); + } + + fn merge(&mut self, other: Self) { + self.accumulate(other.min) + } + + fn finish(&mut self) -> T { + self.min + } +} + +#[derive(Clone, Copy)] +struct MaxAccumulator { + max: T, +} + +impl Default for MaxAccumulator { + fn default() -> Self { + Self { + max: T::MIN_TOTAL_ORDER, + } + } +} + +impl NumericAccumulator for MaxAccumulator { + fn accumulate(&mut self, value: T) { + let max = self.max; + self.max = select(value.is_gt(max), value, max); + } + + fn accumulate_nullable(&mut self, value: T, valid: bool) { + let max = self.max; + let is_gt = value.is_gt(max) & valid; + self.max = select(is_gt, value, max); + } + + fn merge(&mut self, other: Self) { + self.accumulate(other.max) + } + + fn finish(&mut self) -> T { + self.max + } +} + +fn reduce_accumulators, const LANES: usize>( + mut acc: [A; LANES], +) -> A { + assert!(LANES > 0 && LANES.is_power_of_two()); + let mut len = LANES; + + // attempt at tree reduction, unfortunately llvm does not fully recognize this pattern, + // but the generated code is still a little faster than purely sequential reduction for floats. 
+ while len >= 2 { + let mid = len / 2; + let (h, t) = acc[..len].split_at_mut(mid); + + for i in 0..mid { + h[i].merge(t[i]); + } + len /= 2; + } + acc[0] +} + +#[inline(always)] +fn aggregate_nonnull_chunk, const LANES: usize>( + acc: &mut [A; LANES], + values: &[T; LANES], +) { + for i in 0..LANES { + acc[i].accumulate(values[i]); + } +} + +#[inline(always)] +fn aggregate_nullable_chunk, const LANES: usize>( + acc: &mut [A; LANES], + values: &[T; LANES], + validity: u64, +) { + let mut bit = 1; + for i in 0..LANES { + acc[i].accumulate_nullable(values[i], (validity & bit) != 0); + bit <<= 1; + } +} + +fn aggregate_nonnull_simple>(values: &[T]) -> T { + return values + .iter() + .copied() + .fold(A::default(), |mut a, b| { + a.accumulate(b); + a + }) + .finish(); +} + +#[inline(never)] +fn aggregate_nonnull_lanes, const LANES: usize>( + values: &[T], +) -> T { + // aggregating into multiple independent accumulators allows the compiler to use vector registers + // with a single accumulator the compiler would not be allowed to reorder floating point addition + let mut acc = [A::default(); LANES]; + let mut chunks = values.chunks_exact(LANES); + chunks.borrow_mut().for_each(|chunk| { + aggregate_nonnull_chunk(&mut acc, chunk[..LANES].try_into().unwrap()); + }); + + let remainder = chunks.remainder(); + for i in 0..remainder.len() { + acc[i].accumulate(remainder[i]); + } + + reduce_accumulators(acc).finish() +} + +#[inline(never)] +fn aggregate_nullable_lanes, const LANES: usize>( + values: &[T], + validity: &NullBuffer, +) -> T { + assert!(LANES > 0 && 64 % LANES == 0); + assert_eq!(values.len(), validity.len()); + + // aggregating into multiple independent accumulators allows the compiler to use vector registers + let mut acc = [A::default(); LANES]; + // we process 64 bits of validity at a time + let mut values_chunks = values.chunks_exact(64); + let validity_chunks = validity.inner().bit_chunks(); + let mut validity_chunks_iter = validity_chunks.iter(); + + 
values_chunks.borrow_mut().for_each(|chunk| { + // Safety: we asserted that values and validity have the same length and trust the iterator impl + let mut validity = unsafe { validity_chunks_iter.next().unwrap_unchecked() }; + // chunk further based on the number of vector lanes + chunk.chunks_exact(LANES).for_each(|chunk| { + aggregate_nullable_chunk(&mut acc, chunk[..LANES].try_into().unwrap(), validity); + validity >>= LANES; + }); + }); + + let remainder = values_chunks.remainder(); + if !remainder.is_empty() { + let mut validity = validity_chunks.remainder_bits(); + + let mut remainder_chunks = remainder.chunks_exact(LANES); + remainder_chunks.borrow_mut().for_each(|chunk| { + aggregate_nullable_chunk(&mut acc, chunk[..LANES].try_into().unwrap(), validity); + validity >>= LANES; + }); + + let remainder = remainder_chunks.remainder(); + if !remainder.is_empty() { + let mut bit = 1; + for i in 0..remainder.len() { + acc[i].accumulate_nullable(remainder[i], (validity & bit) != 0); + bit <<= 1; + } + } + } + + reduce_accumulators(acc).finish() +} + +/// The preferred vector size in bytes for the target platform. +/// Note that the avx512 target feature is still unstable and this also means it is not detected on stable rust. +const PREFERRED_VECTOR_SIZE: usize = + if cfg!(all(target_arch = "x86_64", target_feature = "avx512f")) { + 64 + } else if cfg!(all(target_arch = "x86_64", target_feature = "avx")) { + 32 + } else { + 16 + }; + +/// non-nullable aggregation requires fewer temporary registers so we can use more of them for accumulators +const PREFERRED_VECTOR_SIZE_NON_NULL: usize = PREFERRED_VECTOR_SIZE * 2; + +/// Generic aggregation for any primitive type. +/// Returns None if there are no non-null values in `array`. +fn aggregate, A: NumericAccumulator>( + array: &PrimitiveArray

, +) -> Option { + let null_count = array.null_count(); + if null_count == array.len() { + return None; + } + let values = array.values().as_ref(); + match array.nulls() { + Some(nulls) if null_count > 0 => { + // const generics depending on a generic type parameter are not supported + // so we have to match and call aggregate with the corresponding constant + match PREFERRED_VECTOR_SIZE / std::mem::size_of::() { + 64 => Some(aggregate_nullable_lanes::(values, nulls)), + 32 => Some(aggregate_nullable_lanes::(values, nulls)), + 16 => Some(aggregate_nullable_lanes::(values, nulls)), + 8 => Some(aggregate_nullable_lanes::(values, nulls)), + 4 => Some(aggregate_nullable_lanes::(values, nulls)), + 2 => Some(aggregate_nullable_lanes::(values, nulls)), + _ => Some(aggregate_nullable_lanes::(values, nulls)), + } + } + _ => { + let is_float = matches!( + array.data_type(), + DataType::Float16 | DataType::Float32 | DataType::Float64 + ); + if is_float { + match PREFERRED_VECTOR_SIZE_NON_NULL / std::mem::size_of::() { + 64 => Some(aggregate_nonnull_lanes::(values)), + 32 => Some(aggregate_nonnull_lanes::(values)), + 16 => Some(aggregate_nonnull_lanes::(values)), + 8 => Some(aggregate_nonnull_lanes::(values)), + 4 => Some(aggregate_nonnull_lanes::(values)), + 2 => Some(aggregate_nonnull_lanes::(values)), + _ => Some(aggregate_nonnull_simple::(values)), + } + } else { + // for non-null integers its better to not chunk ourselves and instead + // let llvm fully handle loop unrolling and vectorization + Some(aggregate_nonnull_simple::(values)) + } + } + } +} + +/// Returns the minimum value in the boolean array. 
+/// +/// ``` +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::aggregate::min_boolean; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(min_boolean(&a), Some(false)) +/// ``` +pub fn min_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + if array.null_count() == array.len() { + return None; + } + + // Note the min bool is false (0), so short circuit as soon as we see it + match array.nulls() { + None => { + let bit_chunks = array.values().bit_chunks(); + if bit_chunks.iter().any(|x| { + // u64::MAX has all bits set, so if the value is not that, then there is a false + x != u64::MAX + }) { + return Some(false); + } + // If the remainder bits are not all set, then there is a false + if bit_chunks.remainder_bits().count_ones() as usize != bit_chunks.remainder_len() { + Some(false) + } else { + Some(true) + } + } + Some(nulls) => { + let validity_chunks = nulls.inner().bit_chunks(); + let value_chunks = array.values().bit_chunks(); + + if value_chunks + .iter() + .zip(validity_chunks.iter()) + .any(|(value, validity)| { + // We are looking for a false value, but because applying the validity mask + // can create a false for a true value (e.g. value: true, validity: false), we instead invert the value, so that we have to look for a true. 
+ (!value & validity) != 0 + }) + { + return Some(false); + } + + // Same trick as above: Instead of looking for a false, we invert the value bits and look for a true + if (!value_chunks.remainder_bits() & validity_chunks.remainder_bits()) != 0 { + Some(false) + } else { + Some(true) + } + } + } +} + +/// Returns the maximum value in the boolean array +/// +/// ``` +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::aggregate::max_boolean; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(max_boolean(&a), Some(true)) +/// ``` +pub fn max_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + if array.null_count() == array.len() { + return None; + } + + // Note the max bool is true (1), so short circuit as soon as we see it + match array.nulls() { + None => array + .values() + .bit_chunks() + .iter_padded() + // We found a true if any bit is set + .map(|x| x != 0) + .find(|b| *b) + .or(Some(false)), + Some(nulls) => { + let validity_chunks = nulls.inner().bit_chunks().iter_padded(); + let value_chunks = array.values().bit_chunks().iter_padded(); + value_chunks + .zip(validity_chunks) + // We found a true if the value bit is 1, AND the validity bit is 1 for any bits in the chunk + .map(|(value_bits, validity_bits)| (value_bits & validity_bits) != 0) + .find(|b| *b) + .or(Some(false)) + } + } +} + +/// Helper to compute min/max of [`ArrayAccessor`]. 
+fn min_max_helper, F>(array: A, cmp: F) -> Option +where + F: Fn(&T, &T) -> bool, +{ + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + // JUSTIFICATION + // Benefit: ~8% speedup + // Soundness: `i` is always within the array bounds + (0..array.len()) + .map(|i| unsafe { array.value_unchecked(i) }) + .reduce(|acc, item| if cmp(&acc, &item) { item } else { acc }) + } else { + let nulls = array.nulls().unwrap(); + unsafe { + let idx = nulls.valid_indices().reduce(|acc_idx, idx| { + let acc = array.value_unchecked(acc_idx); + let item = array.value_unchecked(idx); + if cmp(&acc, &item) { + idx + } else { + acc_idx + } + }); + idx.map(|idx| array.value_unchecked(idx)) + } + } +} + +/// Helper to compute min/max of [`GenericByteViewArray`]. +/// The specialized min/max leverages the inlined values to compare the byte views. +/// `swap_cond` is the condition to swap current min/max with the new value. +/// For example, `Ordering::Greater` for max and `Ordering::Less` for min. 
+fn min_max_view_helper( + array: &GenericByteViewArray, + swap_cond: cmp::Ordering, +) -> Option<&T::Native> { + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + let target_idx = (0..array.len()).reduce(|acc, item| { + // SAFETY: array's length is correct so item is within bounds + let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, item, array, acc) }; + if cmp == swap_cond { + item + } else { + acc + } + }); + // SAFETY: idx came from valid range `0..array.len()` + unsafe { target_idx.map(|idx| array.value_unchecked(idx)) } + } else { + let nulls = array.nulls().unwrap(); + + let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| { + let cmp = + unsafe { GenericByteViewArray::compare_unchecked(array, idx, array, acc_idx) }; + if cmp == swap_cond { + idx + } else { + acc_idx + } + }); + + // SAFETY: idx came from valid range `0..array.len()` + unsafe { target_idx.map(|idx| array.value_unchecked(idx)) } + } +} + +/// Returns the maximum value in the binary array, according to the natural order. +pub fn max_binary(array: &GenericBinaryArray) -> Option<&[u8]> { + min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b) +} + +/// Returns the maximum value in the binary view array, according to the natural order. +pub fn max_binary_view(array: &BinaryViewArray) -> Option<&[u8]> { + min_max_view_helper(array, Ordering::Greater) +} + +/// Returns the minimum value in the binary array, according to the natural order. +pub fn min_binary(array: &GenericBinaryArray) -> Option<&[u8]> { + min_max_helper::<&[u8], _, _>(array, |a, b| *a > *b) +} + +/// Returns the minimum value in the binary view array, according to the natural order. +pub fn min_binary_view(array: &BinaryViewArray) -> Option<&[u8]> { + min_max_view_helper(array, Ordering::Less) +} + +/// Returns the maximum value in the string array, according to the natural order. 
+pub fn max_string(array: &GenericStringArray) -> Option<&str> { + min_max_helper::<&str, _, _>(array, |a, b| *a < *b) +} + +/// Returns the maximum value in the string view array, according to the natural order. +pub fn max_string_view(array: &StringViewArray) -> Option<&str> { + min_max_view_helper(array, Ordering::Greater) +} + +/// Returns the minimum value in the string array, according to the natural order. +pub fn min_string(array: &GenericStringArray) -> Option<&str> { + min_max_helper::<&str, _, _>(array, |a, b| *a > *b) +} + +/// Returns the minimum value in the string view array, according to the natural order. +pub fn min_string_view(array: &StringViewArray) -> Option<&str> { + min_max_view_helper(array, Ordering::Less) +} + +/// Returns the sum of values in the array. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `sum_array_checked` instead. +pub fn sum_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let iter = ArrayIter::new(array); + let sum = iter + .into_iter() + .fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator.add_wrapping(value) + } else { + accumulator + } + }); + + Some(sum) + } + _ => sum::(as_primitive_array(&array)), + } +} + +/// Returns the sum of values in the array. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `sum_array` instead. 
+pub fn sum_array_checked>( + array: A, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + let iter = ArrayIter::new(array); + let sum = iter + .into_iter() + .try_fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator.add_checked(value) + } else { + Ok(accumulator) + } + })?; + + Ok(Some(sum)) + } + _ => sum_checked::(as_primitive_array(&array)), + } +} + +/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary +/// array with value of `ArrowNumericType` type. +pub fn min_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeType, +{ + min_max_array_helper::(array, |a, b| a.is_gt(*b), min) +} + +/// Returns the max of values in the array of `ArrowNumericType` type, or dictionary +/// array with value of `ArrowNumericType` type. +pub fn max_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + min_max_array_helper::(array, |a, b| a.is_lt(*b), max) +} + +fn min_max_array_helper, F, M>( + array: A, + cmp: F, + m: M, +) -> Option +where + T: ArrowNumericType, + F: Fn(&T::Native, &T::Native) -> bool, + M: Fn(&PrimitiveArray) -> Option, +{ + match array.data_type() { + DataType::Dictionary(_, _) => min_max_helper::(array, cmp), + _ => m(as_primitive_array(&array)), + } +} + +macro_rules! bit_operation { + ($NAME:ident, $OP:ident, $NATIVE:ident, $DEFAULT:expr, $DOC:expr) => { + #[doc = $DOC] + /// + /// Returns `None` if the array is empty or only contains null values. 
+ pub fn $NAME(array: &PrimitiveArray) -> Option + where + T: ArrowNumericType, + T::Native: $NATIVE + ArrowNativeTypeOp, + { + let default; + if $DEFAULT == -1 { + default = T::Native::ONE.neg_wrapping(); + } else { + default = T::default_value(); + } + + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let data: &[T::Native] = array.values(); + + match array.nulls() { + None => { + let result = data + .iter() + .fold(default, |accumulator, value| accumulator.$OP(*value)); + + Some(result) + } + Some(nulls) => { + let mut result = default; + let data_chunks = data.chunks_exact(64); + let remainder = data_chunks.remainder(); + + let bit_chunks = nulls.inner().bit_chunks(); + data_chunks + .zip(bit_chunks.iter()) + .for_each(|(chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + chunk.iter().for_each(|value| { + if (mask & index_mask) != 0 { + result = result.$OP(*value); + } + index_mask <<= 1; + }); + }); + + let remainder_bits = bit_chunks.remainder_bits(); + + remainder.iter().enumerate().for_each(|(i, value)| { + if remainder_bits & (1 << i) != 0 { + result = result.$OP(*value); + } + }); + + Some(result) + } + } + } + }; +} + +bit_operation!( + bit_and, + bitand, + BitAnd, + -1, + "Returns the bitwise and of all non-null input values." +); +bit_operation!( + bit_or, + bitor, + BitOr, + 0, + "Returns the bitwise or of all non-null input values." +); +bit_operation!( + bit_xor, + bitxor, + BitXor, + 0, + "Returns the bitwise xor of all non-null input values." +); + +/// Returns true if all non-null input values are true, otherwise false. +/// +/// Returns `None` if the array is empty or only contains null values. +pub fn bool_and(array: &BooleanArray) -> Option { + min_boolean(array) +} + +/// Returns true if any non-null input value is true, otherwise false. +/// +/// Returns `None` if the array is empty or only contains null values. 
+pub fn bool_or(array: &BooleanArray) -> Option { + max_boolean(array) +} + +/// Returns the sum of values in the primitive array. +/// +/// Returns `Ok(None)` if the array is empty or only contains null values. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `sum` instead. +pub fn sum_checked(array: &PrimitiveArray) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + let data: &[T::Native] = array.values(); + + match array.nulls() { + None => { + let sum = data + .iter() + .try_fold(T::default_value(), |accumulator, value| { + accumulator.add_checked(*value) + })?; + + Ok(Some(sum)) + } + Some(nulls) => { + let mut sum = T::default_value(); + + try_for_each_valid_idx( + nulls.len(), + nulls.offset(), + nulls.null_count(), + Some(nulls.validity()), + |idx| { + unsafe { sum = sum.add_checked(array.value_unchecked(idx))? }; + Ok::<_, ArrowError>(()) + }, + )?; + + Ok(Some(sum)) + } + } +} + +/// Returns the sum of values in the primitive array. +/// +/// Returns `None` if the array is empty or only contains null values. +/// +/// This doesn't detect overflow in release mode by default. Once overflowing, the result will +/// wrap around. For an overflow-checking variant, use `sum_checked` instead. +pub fn sum(array: &PrimitiveArray) -> Option +where + T::Native: ArrowNativeTypeOp, +{ + aggregate::>(array) +} + +/// Returns the minimum value in the array, according to the natural order. +/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn min(array: &PrimitiveArray) -> Option +where + T::Native: PartialOrd, +{ + aggregate::>(array) +} + +/// Returns the maximum value in the array, according to the natural order. 
+/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn max(array: &PrimitiveArray) -> Option +where + T::Native: PartialOrd, +{ + aggregate::>(array) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::types::*; + use builder::BooleanBuilder; + use std::sync::Arc; + + #[test] + fn test_primitive_array_sum() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(15, sum(&a).unwrap()); + } + + #[test] + fn test_primitive_array_float_sum() { + let a = Float64Array::from(vec![1.1, 2.2, 3.3, 4.4, 5.5]); + assert_eq!(16.5, sum(&a).unwrap()); + } + + #[test] + fn test_primitive_array_sum_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(10, sum(&a).unwrap()); + } + + #[test] + fn test_primitive_array_sum_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, sum(&a)); + } + + #[test] + fn test_primitive_array_sum_large_float_64() { + let c = Float64Array::new((1..=100).map(|x| x as f64).collect(), None); + assert_eq!(Some((1..=100).sum::() as f64), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Float64Array::new((1..=100).map(|x| x as f64).collect(), Some(validity)); + + assert_eq!( + Some((1..=100).filter(|i| i % 3 == 0).sum::() as f64), + sum(&c) + ); + } + + #[test] + fn test_primitive_array_sum_large_float_32() { + let c = Float32Array::new((1..=100).map(|x| x as f32).collect(), None); + assert_eq!(Some((1..=100).sum::() as f32), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Float32Array::new((1..=100).map(|x| x as f32).collect(), Some(validity)); + + assert_eq!( + Some((1..=100).filter(|i| i % 3 == 0).sum::() as f32), + sum(&c) + ); + } + + #[test] + fn 
test_primitive_array_sum_large_64() { + let c = Int64Array::new((1..=100).collect(), None); + assert_eq!(Some((1..=100).sum()), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int64Array::new((1..=100).collect(), Some(validity)); + + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); + } + + #[test] + fn test_primitive_array_sum_large_32() { + let c = Int32Array::new((1..=100).collect(), None); + assert_eq!(Some((1..=100).sum()), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int32Array::new((1..=100).collect(), Some(validity)); + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); + } + + #[test] + fn test_primitive_array_sum_large_16() { + let c = Int16Array::new((1..=100).collect(), None); + assert_eq!(Some((1..=100).sum()), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int16Array::new((1..=100).collect(), Some(validity)); + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); + } + + #[test] + fn test_primitive_array_sum_large_8() { + let c = UInt8Array::new((1..=100).collect(), None); + assert_eq!( + Some((1..=100).fold(0_u8, |a, x| a.wrapping_add(x))), + sum(&c) + ); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = UInt8Array::new((1..=100).collect(), Some(validity)); + assert_eq!( + Some( + (1..=100) + .filter(|i| i % 3 == 0) + .fold(0_u8, |a, x| a.wrapping_add(x)) + ), + sum(&c) + ); + } + + #[test] + fn test_primitive_array_bit_and() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(0, bit_and(&a).unwrap()); 
+ } + + #[test] + fn test_primitive_array_bit_and_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, None]); + assert_eq!(2, bit_and(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_and_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, bit_and(&a)); + } + + #[test] + fn test_primitive_array_bit_or() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(7, bit_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_or_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(7, bit_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_or_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, bit_or(&a)); + } + + #[test] + fn test_primitive_array_bit_xor() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(1, bit_xor(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_xor_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(4, bit_xor(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_xor_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, bit_xor(&a)); + } + + #[test] + fn test_primitive_array_bool_and() { + let a = BooleanArray::from(vec![true, false, true, false, true]); + assert!(!bool_and(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_and_with_nulls() { + let a = BooleanArray::from(vec![None, Some(true), Some(true), None, Some(true)]); + assert!(bool_and(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_and_all_nulls() { + let a = BooleanArray::from(vec![None, None, None]); + assert_eq!(None, bool_and(&a)); + } + + #[test] + fn test_primitive_array_bool_or() { + let a = BooleanArray::from(vec![true, false, true, false, true]); + assert!(bool_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_or_with_nulls() { + let a = 
BooleanArray::from(vec![None, Some(false), Some(false), None, Some(false)]); + assert!(!bool_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_or_all_nulls() { + let a = BooleanArray::from(vec![None, None, None]); + assert_eq!(None, bool_or(&a)); + } + + #[test] + fn test_primitive_array_min_max() { + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); + assert_eq!(5, min(&a).unwrap()); + assert_eq!(9, max(&a).unwrap()); + } + + #[test] + fn test_primitive_array_min_max_with_nulls() { + let a = Int32Array::from(vec![Some(5), None, None, Some(8), Some(9)]); + assert_eq!(5, min(&a).unwrap()); + assert_eq!(9, max(&a).unwrap()); + } + + #[test] + fn test_primitive_min_max_1() { + let a = Int32Array::from(vec![None, None, Some(5), Some(2)]); + assert_eq!(Some(2), min(&a)); + assert_eq!(Some(5), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_large_nonnull_array() { + let a: Float64Array = (0..256).map(|i| Some((i + 1) as f64)).collect(); + // min/max are on boundaries of chunked data + assert_eq!(Some(1.0), min(&a)); + assert_eq!(Some(256.0), max(&a)); + + // max is last value in remainder after chunking + let a: Float64Array = (0..255).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(255.0), max(&a)); + + // max is first value in remainder after chunking + let a: Float64Array = (0..257).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(257.0), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_large_nullable_array() { + let a: Float64Array = (0..256) + .map(|i| { + if (i + 1) % 3 == 0 { + None + } else { + Some((i + 1) as f64) + } + }) + .collect(); + // min/max are on boundaries of chunked data + assert_eq!(Some(1.0), min(&a)); + assert_eq!(Some(256.0), max(&a)); + + let a: Float64Array = (0..256) + .map(|i| { + if i == 0 || i == 255 { + None + } else { + Some((i + 1) as f64) + } + }) + .collect(); + // boundaries of chunked data are null + assert_eq!(Some(2.0), min(&a)); + assert_eq!(Some(255.0), max(&a)); + + let a: 
Float64Array = (0..256) + .map(|i| if i != 100 { None } else { Some((i) as f64) }) + .collect(); + // a single non-null value somewhere in the middle + assert_eq!(Some(100.0), min(&a)); + assert_eq!(Some(100.0), max(&a)); + + // max is last value in remainder after chunking + let a: Float64Array = (0..255).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(255.0), max(&a)); + + // max is first value in remainder after chunking + let a: Float64Array = (0..257).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(257.0), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_edge_cases() { + let a: Float64Array = (0..100).map(|_| Some(f64::NEG_INFINITY)).collect(); + assert_eq!(Some(f64::NEG_INFINITY), min(&a)); + assert_eq!(Some(f64::NEG_INFINITY), max(&a)); + + let a: Float64Array = (0..100).map(|_| Some(f64::MIN)).collect(); + assert_eq!(Some(f64::MIN), min(&a)); + assert_eq!(Some(f64::MIN), max(&a)); + + let a: Float64Array = (0..100).map(|_| Some(f64::MAX)).collect(); + assert_eq!(Some(f64::MAX), min(&a)); + assert_eq!(Some(f64::MAX), max(&a)); + + let a: Float64Array = (0..100).map(|_| Some(f64::INFINITY)).collect(); + assert_eq!(Some(f64::INFINITY), min(&a)); + assert_eq!(Some(f64::INFINITY), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_all_nans_non_null() { + let a: Float64Array = (0..100).map(|_| Some(f64::NAN)).collect(); + assert!(max(&a).unwrap().is_nan()); + assert!(min(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_negative_nan() { + let a: Float64Array = + Float64Array::from(vec![f64::NEG_INFINITY, f64::NAN, f64::INFINITY, -f64::NAN]); + let max = max(&a).unwrap(); + let min = min(&a).unwrap(); + assert!(max.is_nan()); + assert!(max.is_sign_positive()); + + assert!(min.is_nan()); + assert!(min.is_sign_negative()); + } + + #[test] + fn test_primitive_min_max_float_first_nan_nonnull() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 0 { + Some(f64::NAN) + } else { + Some(i as f64) + } 
+ }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_last_nan_nonnull() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 99 { + Some(f64::NAN) + } else { + Some((i + 1) as f64) + } + }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_first_nan_nullable() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 0 { + Some(f64::NAN) + } else if i % 2 == 0 { + None + } else { + Some(i as f64) + } + }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_last_nan_nullable() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 99 { + Some(f64::NAN) + } else if i % 2 == 0 { + None + } else { + Some(i as f64) + } + }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_inf_and_nans() { + let a: Float64Array = (0..100) + .map(|i| { + let x = match i % 10 { + 0 => f64::NEG_INFINITY, + 1 => f64::MIN, + 2 => f64::MAX, + 4 => f64::INFINITY, + 5 => f64::NAN, + _ => i as f64, + }; + Some(x) + }) + .collect(); + assert_eq!(Some(f64::NEG_INFINITY), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + macro_rules! 
test_binary { + ($NAME:ident, $ARRAY:expr, $EXPECTED_MIN:expr, $EXPECTED_MAX: expr) => { + #[test] + fn $NAME() { + let binary = BinaryArray::from($ARRAY); + assert_eq!($EXPECTED_MIN, min_binary(&binary)); + assert_eq!($EXPECTED_MAX, max_binary(&binary)); + + let large_binary = LargeBinaryArray::from($ARRAY); + assert_eq!($EXPECTED_MIN, min_binary(&large_binary)); + assert_eq!($EXPECTED_MAX, max_binary(&large_binary)); + + let binary_view = BinaryViewArray::from($ARRAY); + assert_eq!($EXPECTED_MIN, min_binary_view(&binary_view)); + assert_eq!($EXPECTED_MAX, max_binary_view(&binary_view)); + } + }; + } + + test_binary!( + test_binary_min_max_with_nulls, + vec![ + Some("b01234567890123".as_bytes()), // long bytes + None, + None, + Some(b"a"), + Some(b"c"), + Some(b"abcdedfg0123456"), + ], + Some("a".as_bytes()), + Some("c".as_bytes()) + ); + + test_binary!( + test_binary_min_max_no_null, + vec![ + Some("b".as_bytes()), + Some(b"abcdefghijklmnopqrst"), // long bytes + Some(b"c"), + Some(b"b01234567890123"), // long bytes for view types + ], + Some("abcdefghijklmnopqrst".as_bytes()), + Some("c".as_bytes()) + ); + + test_binary!(test_binary_min_max_all_nulls, vec![None, None], None, None); + + test_binary!( + test_binary_min_max_1, + vec![ + None, + Some("b01234567890123435".as_bytes()), // long bytes for view types + None, + Some(b"b0123xxxxxxxxxxx"), + Some(b"a") + ], + Some("a".as_bytes()), + Some("b0123xxxxxxxxxxx".as_bytes()) + ); + + macro_rules! 
test_string { + ($NAME:ident, $ARRAY:expr, $EXPECTED_MIN:expr, $EXPECTED_MAX: expr) => { + #[test] + fn $NAME() { + let string = StringArray::from($ARRAY); + assert_eq!($EXPECTED_MIN, min_string(&string)); + assert_eq!($EXPECTED_MAX, max_string(&string)); + + let large_string = LargeStringArray::from($ARRAY); + assert_eq!($EXPECTED_MIN, min_string(&large_string)); + assert_eq!($EXPECTED_MAX, max_string(&large_string)); + + let string_view = StringViewArray::from($ARRAY); + assert_eq!($EXPECTED_MIN, min_string_view(&string_view)); + assert_eq!($EXPECTED_MAX, max_string_view(&string_view)); + } + }; + } + + test_string!( + test_string_min_max_with_nulls, + vec![ + Some("b012345678901234"), // long bytes for view types + None, + None, + Some("a"), + Some("c"), + Some("b0123xxxxxxxxxxx") + ], + Some("a"), + Some("c") + ); + + test_string!( + test_string_min_max_no_null, + vec![ + Some("b"), + Some("b012345678901234"), // long bytes for view types + Some("a"), + Some("b012xxxxxxxxxxxx") + ], + Some("a"), + Some("b012xxxxxxxxxxxx") + ); + + test_string!( + test_string_min_max_all_nulls, + Vec::>::from_iter([None, None]), + None, + None + ); + + test_string!( + test_string_min_max_1, + vec![ + None, + Some("c12345678901234"), // long bytes for view types + None, + Some("b"), + Some("c1234xxxxxxxxxx") + ], + Some("b"), + Some("c1234xxxxxxxxxx") + ); + + test_string!( + test_string_min_max_empty, + Vec::>::new(), + None, + None + ); + + #[test] + fn test_boolean_min_max_empty() { + let a = BooleanArray::from(vec![] as Vec>); + assert_eq!(None, min_boolean(&a)); + assert_eq!(None, max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max_all_null() { + let a = BooleanArray::from(vec![None, None]); + assert_eq!(None, min_boolean(&a)); + assert_eq!(None, max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max_no_null() { + let a = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), 
max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max() { + let a = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), None]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![None, Some(true), None, Some(false), None]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(false), Some(true), None, Some(false), None]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(true), None]); + assert_eq!(Some(true), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(false), None]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(false), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(true)]); + assert_eq!(Some(true), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(false)]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(false), max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max_smaller() { + let a = BooleanArray::from(vec![Some(false)]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(false), max_boolean(&a)); + + let a = BooleanArray::from(vec![None, Some(false)]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(false), max_boolean(&a)); + + let a = BooleanArray::from(vec![None, Some(true)]); + assert_eq!(Some(true), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(true)]); + assert_eq!(Some(true), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max_64_true_64_false() { + let mut no_nulls = BooleanBuilder::new(); + no_nulls.append_slice(&[true; 64]); + no_nulls.append_slice(&[false; 64]); + let no_nulls = no_nulls.finish(); + + 
assert_eq!(Some(false), min_boolean(&no_nulls)); + assert_eq!(Some(true), max_boolean(&no_nulls)); + + let mut with_nulls = BooleanBuilder::new(); + with_nulls.append_slice(&[true; 31]); + with_nulls.append_null(); + with_nulls.append_slice(&[true; 32]); + with_nulls.append_slice(&[false; 1]); + with_nulls.append_nulls(63); + let with_nulls = with_nulls.finish(); + + assert_eq!(Some(false), min_boolean(&with_nulls)); + assert_eq!(Some(true), max_boolean(&with_nulls)); + } + + #[test] + fn test_boolean_min_max_64_false_64_true() { + let mut no_nulls = BooleanBuilder::new(); + no_nulls.append_slice(&[false; 64]); + no_nulls.append_slice(&[true; 64]); + let no_nulls = no_nulls.finish(); + + assert_eq!(Some(false), min_boolean(&no_nulls)); + assert_eq!(Some(true), max_boolean(&no_nulls)); + + let mut with_nulls = BooleanBuilder::new(); + with_nulls.append_slice(&[false; 31]); + with_nulls.append_null(); + with_nulls.append_slice(&[false; 32]); + with_nulls.append_slice(&[true; 1]); + with_nulls.append_nulls(63); + let with_nulls = with_nulls.finish(); + + assert_eq!(Some(false), min_boolean(&with_nulls)); + assert_eq!(Some(true), max_boolean(&with_nulls)); + } + + #[test] + fn test_boolean_min_max_96_true() { + let mut no_nulls = BooleanBuilder::new(); + no_nulls.append_slice(&[true; 96]); + let no_nulls = no_nulls.finish(); + + assert_eq!(Some(true), min_boolean(&no_nulls)); + assert_eq!(Some(true), max_boolean(&no_nulls)); + + let mut with_nulls = BooleanBuilder::new(); + with_nulls.append_slice(&[true; 31]); + with_nulls.append_null(); + with_nulls.append_slice(&[true; 32]); + with_nulls.append_slice(&[true; 31]); + with_nulls.append_null(); + let with_nulls = with_nulls.finish(); + + assert_eq!(Some(true), min_boolean(&with_nulls)); + assert_eq!(Some(true), max_boolean(&with_nulls)); + } + + #[test] + fn test_boolean_min_max_96_false() { + let mut no_nulls = BooleanBuilder::new(); + no_nulls.append_slice(&[false; 96]); + let no_nulls = no_nulls.finish(); + + 
assert_eq!(Some(false), min_boolean(&no_nulls)); + assert_eq!(Some(false), max_boolean(&no_nulls)); + + let mut with_nulls = BooleanBuilder::new(); + with_nulls.append_slice(&[false; 31]); + with_nulls.append_null(); + with_nulls.append_slice(&[false; 32]); + with_nulls.append_slice(&[false; 31]); + with_nulls.append_null(); + let with_nulls = with_nulls.finish(); + + assert_eq!(Some(false), min_boolean(&with_nulls)); + assert_eq!(Some(false), max_boolean(&with_nulls)); + } + + #[test] + fn test_sum_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let values = Arc::new(values) as ArrayRef; + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(39, sum_array::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(15, sum_array::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(26, sum_array::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(sum_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + let values = Arc::new(values) as ArrayRef; + + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(14, max_array::(array).unwrap()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(5, max_array::(&a).unwrap()); + assert_eq!(1, 
min_array::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(7)]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(17, max_array::(array).unwrap()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_array::(array).is_none()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(min_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn_nan() { + let values = Float32Array::from(vec![5.0_f32, 2.0_f32, f32::NAN]); + let keys = Int8Array::from_iter_values([0_i8, 1, 2]); + + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_array::(array).unwrap().is_nan()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(2.0_f32, min_array::(array).unwrap()); + } + + #[test] + fn test_min_max_sliced_primitive() { + let expected = Some(4.0); + let input: Float64Array = vec![None, Some(4.0)].into_iter().collect(); + let actual = min(&input); + assert_eq!(actual, expected); + let actual = max(&input); + assert_eq!(actual, expected); + + let sliced_input: Float64Array = vec![None, None, None, None, None, Some(4.0)] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(&sliced_input, &input); + + let actual = min(&sliced_input); + assert_eq!(actual, expected); + let actual = max(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_min_max_sliced_boolean() { + let expected = Some(true); + let input: BooleanArray = vec![None, Some(true)].into_iter().collect(); + let actual = min_boolean(&input); + assert_eq!(actual, expected); + let actual = max_boolean(&input); + assert_eq!(actual, 
expected); + + let sliced_input: BooleanArray = vec![None, None, None, None, None, Some(true)] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(sliced_input, input); + + let actual = min_boolean(&sliced_input); + assert_eq!(actual, expected); + let actual = max_boolean(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_min_max_sliced_string() { + let expected = Some("foo"); + let input: StringArray = vec![None, Some("foo")].into_iter().collect(); + let actual = min_string(&input); + assert_eq!(actual, expected); + let actual = max_string(&input); + assert_eq!(actual, expected); + + let sliced_input: StringArray = vec![None, None, None, None, None, Some("foo")] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(&sliced_input, &input); + + let actual = min_string(&sliced_input); + assert_eq!(actual, expected); + let actual = max_string(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_min_max_sliced_binary() { + let expected: Option<&[u8]> = Some(&[5]); + let input: BinaryArray = vec![None, Some(&[5])].into_iter().collect(); + let actual = min_binary(&input); + assert_eq!(actual, expected); + let actual = max_binary(&input); + assert_eq!(actual, expected); + + let sliced_input: BinaryArray = vec![None, None, None, None, None, Some(&[5])] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(&sliced_input, &input); + + let actual = min_binary(&sliced_input); + assert_eq!(actual, expected); + let actual = max_binary(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_sum_overflow() { + let a = Int32Array::from(vec![i32::MAX, 1]); + + assert_eq!(sum(&a).unwrap(), -2147483648); + assert_eq!(sum_array::(&a).unwrap(), -2147483648); + } + + #[test] + fn test_sum_checked_overflow() { + let a = Int32Array::from(vec![i32::MAX, 1]); + + sum_checked(&a).expect_err("overflow should be 
detected"); + sum_array_checked::(&a).expect_err("overflow should be detected"); + } +} diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs new file mode 100644 index 000000000000..febf5ceabdd9 --- /dev/null +++ b/arrow-arith/src/arithmetic.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines basic arithmetic kernels for `PrimitiveArrays`. +//! +//! These kernels can leverage SIMD if available on your system. Currently no runtime +//! detection is provided, you should enable the specific SIMD intrinsics using +//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation +//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. + +use crate::arity::*; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::i256; +use arrow_buffer::ArrowNativeType; +use arrow_schema::*; +use std::cmp::min; +use std::sync::Arc; + +/// Returns the precision and scale of the result of a multiplication of two decimal types, +/// and the divisor for fixed point multiplication. 
+fn get_fixed_point_info( + left: (u8, i8), + right: (u8, i8), + required_scale: i8, +) -> Result<(u8, i8, i256), ArrowError> { + let product_scale = left.1 + right.1; + let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION); + + if required_scale > product_scale { + return Err(ArrowError::ComputeError(format!( + "Required scale {} is greater than product scale {}", + required_scale, product_scale + ))); + } + + let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); + + Ok((precision, product_scale, divisor)) +} + +/// Perform `left * right` operation on two decimal arrays. If either left or right value is +/// null then the result is also null. +/// +/// This performs decimal multiplication which allows precision loss if an exact representation +/// is not possible for the result, according to the required scale. In that case, the result +/// will be rounded to the required scale. +/// +/// If the required scale is greater than the product scale, an error is returned. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// +/// It is implemented for compatibility with precision loss `multiply` function provided by +/// other data processing engines. For multiplication with precision loss detection, use +/// `multiply_dyn` or `multiply_dyn_checked` instead. +pub fn multiply_fixed_point_dyn( + left: &dyn Array, + right: &dyn Array, + required_scale: i8, +) -> Result { + match (left.data_type(), right.data_type()) { + (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => { + let left = left.as_any().downcast_ref::().unwrap(); + let right = right.as_any().downcast_ref::().unwrap(); + + multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef) + } + (_, _) => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), + right.data_type() + ))), + } +} + +/// Perform `left * right` operation on two decimal arrays. 
If either left or right value is +/// null then the result is also null. +/// +/// This performs decimal multiplication which allows precision loss if an exact representation +/// is not possible for the result, according to the required scale. In that case, the result +/// will be rounded to the required scale. +/// +/// If the required scale is greater than the product scale, an error is returned. +/// +/// It is implemented for compatibility with precision loss `multiply` function provided by +/// other data processing engines. For multiplication with precision loss detection, use +/// `multiply` or `multiply_checked` instead. +pub fn multiply_fixed_point_checked( + left: &PrimitiveArray, + right: &PrimitiveArray, + required_scale: i8, +) -> Result, ArrowError> { + let (precision, product_scale, divisor) = get_fixed_point_info( + (left.precision(), left.scale()), + (right.precision(), right.scale()), + required_scale, + )?; + + if required_scale == product_scale { + return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))? + .with_precision_and_scale(precision, required_scale); + } + + try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { + let a = i256::from_i128(a); + let b = i256::from_i128(b); + + let mut mul = a.wrapping_mul(b); + mul = divide_and_round::(mul, divisor); + mul.to_i128().ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!("Overflow happened on: {:?} * {:?}", a, b)) + }) + }) + .and_then(|a| a.with_precision_and_scale(precision, required_scale)) +} + +/// Perform `left * right` operation on two decimal arrays. If either left or right value is +/// null then the result is also null. +/// +/// This performs decimal multiplication which allows precision loss if an exact representation +/// is not possible for the result, according to the required scale. In that case, the result +/// will be rounded to the required scale. 
+/// +/// If the required scale is greater than the product scale, an error is returned. 
+/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `multiply_fixed_point_checked` instead. +/// +/// It is implemented for compatibility with precision loss `multiply` function provided by +/// other data processing engines. For multiplication with precision loss detection, use +/// `multiply` or `multiply_checked` instead. +pub fn multiply_fixed_point( + left: &PrimitiveArray, + right: &PrimitiveArray, + required_scale: i8, +) -> Result, ArrowError> { + let (precision, product_scale, divisor) = get_fixed_point_info( + (left.precision(), left.scale()), + (right.precision(), right.scale()), + required_scale, + )?; + + if required_scale == product_scale { + return binary(left, right, |a, b| a.mul_wrapping(b))? + .with_precision_and_scale(precision, required_scale); + } + + binary::<_, _, _, Decimal128Type>(left, right, |a, b| { + let a = i256::from_i128(a); + let b = i256::from_i128(b); + + let mut mul = a.wrapping_mul(b); + mul = divide_and_round::(mul, divisor); + mul.as_i128() + }) + .and_then(|a| a.with_precision_and_scale(precision, required_scale)) +} + +/// Divide a decimal native value by given divisor and round the result. +fn divide_and_round(input: I::Native, div: I::Native) -> I::Native +where + I: DecimalType, + I::Native: ArrowNativeTypeOp, +{ + let d = input.div_wrapping(div); + let r = input.mod_wrapping(div); + + let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); + let half_neg = half.neg_wrapping(); + + // Round result + match input >= I::Native::ZERO { + true if r >= half => d.add_wrapping(I::Native::ONE), + false if r <= half_neg => d.sub_wrapping(I::Native::ONE), + _ => d, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::numeric::mul; + + #[test] + fn test_decimal_multiply_allow_precision_loss() { + // Overflow happening as i128 cannot hold multiplying result. 
+ // [123456789] + let a = Decimal128Array::from(vec![123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [10] + let b = Decimal128Array::from(vec![10000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let err = mul(&a, &b).unwrap_err(); + assert!(err + .to_string() + .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000")); + + // Allow precision loss. + let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); + // [1234567890] + let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + assert_eq!( + result.value_as_string(0), + "1234567890.0000000000000000000000000000" + ); + + // Rounding case + // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555] + let a = Decimal128Array::from(vec![1, 123456789555555555555555555, 1555555555555555555]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [1.555555555555555555, 11.222222222222222222, 0.000000000000000001] + let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); + // [ + // 0.0000000000000000015555555556, + // 1385459527.2345679012071330528765432099, + // 0.0000000000000000015555555556 + // ] + let expected = Decimal128Array::from(vec![ + 15555555556, + 13854595272345679012071330528765432099, + 15555555556, + ]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + + // Rounded the value "1385459527.234567901207133052876543209876543210". 
+ assert_eq!( + result.value_as_string(1), + "1385459527.2345679012071330528765432099" + ); + assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556"); + assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556"); + + let a = Decimal128Array::from(vec![1230]) + .with_precision_and_scale(4, 2) + .unwrap(); + + let b = Decimal128Array::from(vec![1000]) + .with_precision_and_scale(4, 2) + .unwrap(); + + // Required scale is same as the product of the input scales. Behavior is same as multiply. + let result = multiply_fixed_point_checked(&a, &b, 4).unwrap(); + assert_eq!(result.precision(), 9); + assert_eq!(result.scale(), 4); + + let expected = mul(&a, &b).unwrap(); + assert_eq!(expected.as_ref(), &result); + + // Required scale cannot be larger than the product of the input scales. + let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err(); + assert!(result + .to_string() + .contains("Required scale 5 is greater than product scale 4")); + } + + #[test] + fn test_decimal_multiply_allow_precision_loss_overflow() { + // [99999999999123456789] + let a = Decimal128Array::from(vec![99999999999123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [9999999999910] + let b = Decimal128Array::from(vec![9999999999910000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let err = multiply_fixed_point_checked(&a, &b, 28).unwrap_err(); + assert!(err.to_string().contains( + "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000" + )); + + let result = multiply_fixed_point(&a, &b, 28).unwrap(); + let expected = Decimal128Array::from(vec![62946009661555981610246871926660136960]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + } + + #[test] + fn test_decimal_multiply_fixed_point() { + // [123456789] + let a = Decimal128Array::from(vec![123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // 
[10] + let b = Decimal128Array::from(vec![10000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // `multiply` overflows on this case. + let err = mul(&a, &b).unwrap_err(); + assert_eq!(err.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000"); + + // Avoid overflow by reducing the scale. + let result = multiply_fixed_point(&a, &b, 28).unwrap(); + // [1234567890] + let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + assert_eq!( + result.value_as_string(0), + "1234567890.0000000000000000000000000000" + ); + } +} diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs new file mode 100644 index 000000000000..bb983e1225ac --- /dev/null +++ b/arrow-arith/src/arity.rs @@ -0,0 +1,668 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Kernels for operating on [`PrimitiveArray`]s + +use arrow_array::builder::BufferBuilder; +use arrow_array::types::ArrowDictionaryKeyType; +use arrow_array::*; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::ArrowNativeType; +use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_data::ArrayData; +use arrow_schema::ArrowError; +use std::sync::Arc; + +/// See [`PrimitiveArray::unary`] +pub fn unary(array: &PrimitiveArray, op: F) -> PrimitiveArray +where + I: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(I::Native) -> O::Native, +{ + array.unary(op) +} + +/// See [`PrimitiveArray::unary_mut`] +pub fn unary_mut( + array: PrimitiveArray, + op: F, +) -> Result, PrimitiveArray> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> I::Native, +{ + array.unary_mut(op) +} + +/// See [`PrimitiveArray::try_unary`] +pub fn try_unary(array: &PrimitiveArray, op: F) -> Result, ArrowError> +where + I: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(I::Native) -> Result, +{ + array.try_unary(op) +} + +/// See [`PrimitiveArray::try_unary_mut`] +pub fn try_unary_mut( + array: PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> Result, +{ + array.try_unary_mut(op) +} + +/// A helper function that applies an infallible unary function to a dictionary array with primitive value type. +fn unary_dict(array: &DictionaryArray, op: F) -> Result +where + K: ArrowDictionaryKeyType + ArrowNumericType, + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + let dict_values = array.values().as_any().downcast_ref().unwrap(); + let values = unary::(dict_values, op); + Ok(Arc::new(array.with_values(Arc::new(values)))) +} + +/// A helper function that applies a fallible unary function to a dictionary array with primitive value type. 
+fn try_unary_dict(array: &DictionaryArray, op: F) -> Result +where + K: ArrowDictionaryKeyType + ArrowNumericType, + T: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, +{ + if !PrimitiveArray::::is_compatible(&array.value_type()) { + return Err(ArrowError::CastError(format!( + "Cannot perform the unary operation of type {} on dictionary array of value type {}", + T::DATA_TYPE, + array.value_type() + ))); + } + + let dict_values = array.values().as_any().downcast_ref().unwrap(); + let values = try_unary::(dict_values, op)?; + Ok(Arc::new(array.with_values(Arc::new(values)))) +} + +/// Applies an infallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] +pub fn unary_dyn(array: &dyn Array, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + downcast_dictionary_array! { + array => unary_dict::<_, F, T>(array, op), + t => { + if PrimitiveArray::::is_compatible(t) { + Ok(Arc::new(unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + ))) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + t + ))) + } + } + } +} + +/// Applies a fallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] +pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, +{ + downcast_dictionary_array! 
{ + array => if array.values().data_type() == &T::DATA_TYPE { + try_unary_dict::<_, F, T>(array, op) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation on dictionary array of type {}", + array.data_type() + ))) + }, + t => { + if PrimitiveArray::::is_compatible(t) { + Ok(Arc::new(try_unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + )?)) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + t + ))) + } + } + } +} + +/// Applies a binary infallible function to two [`PrimitiveArray`]s, +/// producing a new [`PrimitiveArray`] +/// +/// # Details +/// +/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting +/// the results in a [`PrimitiveArray`]. +/// +/// If any index is null in either `a` or `b`, the +/// corresponding index in the result will also be null +/// +/// Like [`unary`], the `op` is evaluated for every element in the two arrays, +/// including those elements which are NULL. This is beneficial as the cost of +/// the operation is low compared to the cost of branching, and especially when +/// the operation can be vectorised, however, requires `op` to be infallible for +/// all possible values of its inputs +/// +/// # Errors +/// +/// * if the arrays have different lengths. 
+/// +/// # Example +/// ``` +/// # use arrow_arith::arity::binary; +/// # use arrow_array::{Float32Array, Int32Array}; +/// # use arrow_array::types::Int32Type; +/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8), Some(7.2)]); +/// let b = Int32Array::from(vec![1, 2, 4, 9]); +/// // compute int(a) + b for each element +/// let c = binary(&a, &b, |a, b| a as i32 + b).unwrap(); +/// assert_eq!(c, Int32Array::from(vec![Some(6), None, Some(10), Some(16)])); +/// ``` +pub fn binary( + a: &PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError> +where + A: ArrowPrimitiveType, + B: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(A::Native, B::Native) -> O::Native, +{ + if a.len() != b.len() { + return Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + )); + } + + if a.is_empty() { + return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); + } + + let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()); + + let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r)); + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size from a PrimitiveArray + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + Ok(PrimitiveArray::new(buffer.into(), nulls)) +} + +/// Applies a binary and infallible function to values in two arrays, replacing +/// the values in the first array in place. +/// +/// # Details +/// +/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in +/// `0..len`, modifying the [`PrimitiveArray`] `a` in place, if possible. +/// +/// If any index is null in either `a` or `b`, the corresponding index in the +/// result will also be null. +/// +/// # Buffer Reuse +/// +/// If the underlying buffers in `a` are not shared with other arrays, mutates +/// the underlying buffer in place, without allocating. 
+/// +/// If the underlying buffers in `a` are shared, returns `Err(a)` +/// +/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This +/// is beneficial when the cost of the operation is low compared to the cost of branching, and +/// especially when the operation can be vectorised, however, requires `op` to be infallible +/// for all possible values of its inputs +/// +/// # Errors +/// +/// * If the arrays have different lengths +/// * If the array is not mutable (see "Buffer Reuse") +/// +/// # See Also +/// +/// * Documentation on [`PrimitiveArray::unary_mut`] for operating on [`ArrayRef`]. +/// +/// # Example +/// ``` +/// # use arrow_arith::arity::binary_mut; +/// # use arrow_array::{Float32Array, Int32Array}; +/// # use arrow_array::types::Int32Type; +/// // compute a + b for each element +/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8)]); +/// let b = Int32Array::from(vec![Some(1), None, Some(2)]); +/// // compute a + b, updating the value in a in place if possible +/// let a = binary_mut(a, &b, |a, b| a + b as f32).unwrap().unwrap(); +/// // a is updated in place +/// assert_eq!(a, Float32Array::from(vec![Some(6.1), None, Some(8.8)])); +/// ``` +/// +/// # Example with shared buffers +/// ``` +/// # use arrow_arith::arity::binary_mut; +/// # use arrow_array::Float32Array; +/// # use arrow_array::types::Int32Type; +/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8)]); +/// let b = Float32Array::from(vec![Some(1.0f32), None, Some(2.0)]); +/// // a_cloned shares the buffer with a +/// let a_cloned = a.clone(); +/// // try to update a in place, but it is shared. 
Returns Err(a) +/// let a = binary_mut(a, &b, |a, b| a + b).unwrap_err(); +/// assert_eq!(a_cloned, a); +/// // drop shared reference +/// drop(a_cloned); +/// // now a is not shared, so we can update it in place +/// let a = binary_mut(a, &b, |a, b| a + b).unwrap().unwrap(); +/// assert_eq!(a, Float32Array::from(vec![Some(6.1), None, Some(8.8)])); +/// ``` +pub fn binary_mut( + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + T: ArrowPrimitiveType, + U: ArrowPrimitiveType, + F: Fn(T::Native, U::Native) -> T::Native, +{ + if a.len() != b.len() { + return Ok(Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + ))); + } + + if a.is_empty() { + return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty( + &T::DATA_TYPE, + )))); + } + + let mut builder = a.into_builder()?; + + builder + .values_slice_mut() + .iter_mut() + .zip(b.values()) + .for_each(|(l, r)| *l = op(*l, *r)); + + let array = builder.finish(); + + // The builder has the null buffer from `a`, it is not changed. + let nulls = NullBuffer::union(array.logical_nulls().as_ref(), b.logical_nulls().as_ref()); + + let array_builder = array.into_data().into_builder().nulls(nulls); + + let array_data = unsafe { array_builder.build_unchecked() }; + Ok(Ok(PrimitiveArray::::from(array_data))) +} + +/// Applies the provided fallible binary operation across `a` and `b`. +/// +/// This will return any error encountered, or collect the results into +/// a [`PrimitiveArray`]. 
If any index is null in either `a` +/// or `b`, the corresponding index in the result will also be null +/// +/// Like [`try_unary`] the function is only evaluated for non-null indices +/// +/// # Error +/// +/// Return an error if the arrays have different lengths or +/// the operation is under erroneous +pub fn try_binary( + a: A, + b: B, + op: F, +) -> Result, ArrowError> +where + O: ArrowPrimitiveType, + F: Fn(A::Item, B::Item) -> Result, +{ + if a.len() != b.len() { + return Err(ArrowError::ComputeError( + "Cannot perform a binary operation on arrays of different length".to_string(), + )); + } + if a.is_empty() { + return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); + } + let len = a.len(); + + if a.null_count() == 0 && b.null_count() == 0 { + try_binary_no_nulls(len, a, b, op) + } else { + let nulls = + NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap(); + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + nulls.try_for_each_valid_idx(|idx| { + unsafe { + *slice.get_unchecked_mut(idx) = op(a.value_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + })?; + + let values = buffer.finish().into(); + Ok(PrimitiveArray::new(values, Some(nulls))) + } +} + +/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable +/// [`PrimitiveArray`] `a` with the results. +/// +/// Returns any error encountered, or collects the results into a [`PrimitiveArray`] as return +/// value. If any index is null in either `a` or `b`, the corresponding index in the result will +/// also be null. +/// +/// Like [`try_unary`] the function is only evaluated for non-null indices. +/// +/// See [`binary_mut`] for errors and buffer reuse information. 
+pub fn try_binary_mut( + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> Result, +{ + if a.len() != b.len() { + return Ok(Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + ))); + } + let len = a.len(); + + if a.is_empty() { + return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty( + &T::DATA_TYPE, + )))); + } + + if a.null_count() == 0 && b.null_count() == 0 { + try_binary_no_nulls_mut(len, a, b, op) + } else { + let nulls = + create_union_null_buffer(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()) + .unwrap(); + + let mut builder = a.into_builder()?; + + let slice = builder.values_slice_mut(); + + let r = nulls.try_for_each_valid_idx(|idx| { + unsafe { + *slice.get_unchecked_mut(idx) = + op(*slice.get_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + }); + if let Err(err) = r { + return Ok(Err(err)); + } + let array_builder = builder.finish().into_data().into_builder(); + let array_data = unsafe { array_builder.nulls(Some(nulls)).build_unchecked() }; + Ok(Ok(PrimitiveArray::::from(array_data))) + } +} + +/// Computes the union of the nulls in two optional [`NullBuffer`] which +/// is not shared with the input buffers. +/// +/// The union of the nulls is the same as `NullBuffer::union(lhs, rhs)` but +/// it does not increase the reference count of the null buffer. +fn create_union_null_buffer( + lhs: Option<&NullBuffer>, + rhs: Option<&NullBuffer>, +) -> Option { + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Some(NullBuffer::new(lhs.inner() & rhs.inner())), + (Some(n), None) | (None, Some(n)) => Some(NullBuffer::new(n.inner() & n.inner())), + (None, None) => None, + } +} + +/// This intentional inline(never) attribute helps LLVM optimize the loop. 
+#[inline(never)] +fn try_binary_no_nulls( + len: usize, + a: A, + b: B, + op: F, +) -> Result, ArrowError> +where + O: ArrowPrimitiveType, + F: Fn(A::Item, B::Item) -> Result, +{ + let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width()); + for idx in 0..len { + unsafe { + buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?); + }; + } + Ok(PrimitiveArray::new(buffer.into(), None)) +} + +/// This intentional inline(never) attribute helps LLVM optimize the loop. +#[inline(never)] +fn try_binary_no_nulls_mut( + len: usize, + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> Result, +{ + let mut builder = a.into_builder()?; + let slice = builder.values_slice_mut(); + + for idx in 0..len { + unsafe { + match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) { + Ok(value) => *slice.get_unchecked_mut(idx) = value, + Err(err) => return Ok(Err(err)), + }; + }; + } + Ok(Ok(builder.finish())) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::builder::*; + use arrow_array::types::*; + + #[test] + #[allow(deprecated)] + fn test_unary_f64_slice() { + let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); + let input_slice = input.slice(1, 4); + let result = unary(&input_slice, |n| n.round()); + assert_eq!( + result, + Float64Array::from(vec![None, Some(7.0), None, Some(7.0)]) + ); + + let result = unary_dyn::<_, Float64Type>(&input_slice, |n| n + 1.0).unwrap(); + + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)]) + ); + } + + #[test] + #[allow(deprecated)] + fn test_unary_dict_and_unary_dyn() { + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(5).unwrap(); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append_null(); + 
builder.append(9).unwrap(); + let dictionary_array = builder.finish(); + + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + builder.append_null(); + builder.append(10).unwrap(); + let expected = builder.finish(); + + let result = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap(); + assert_eq!( + result + .as_any() + .downcast_ref::>() + .unwrap(), + &expected + ); + + let result = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap(); + assert_eq!( + result + .as_any() + .downcast_ref::>() + .unwrap(), + &expected + ); + } + + #[test] + fn test_binary_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap(); + + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + } + + #[test] + fn test_binary_mut_null_buffer() { + let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]); + + let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]); + + let r1 = binary_mut(a, &b, |a, b| a + b).unwrap(); + + let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]); + let b = Int32Array::new( + vec![10, 11, 12, 13, 14].into(), + Some(vec![true, true, true, true, true].into()), + ); + + // unwrap here means that no copying occured + let r2 = binary_mut(a, &b, |a, b| a + b).unwrap(); + assert_eq!(r1.unwrap(), r2.unwrap()); + } + + #[test] + fn test_try_binary_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap(); + + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + + let a = 
Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![1, 2, 3, 4, 5]); + let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap(); + let expected = Int32Array::from(vec![16, 16, 12, 12, 6]); + assert_eq!(c, expected); + + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let _ = try_binary_mut(a, &b, |l, r| { + if l == 1 { + Err(ArrowError::InvalidArgumentError( + "got error".parse().unwrap(), + )) + } else { + Ok(l + r) + } + }) + .unwrap() + .expect_err("should got error"); + } + + #[test] + fn test_try_binary_mut_null_buffer() { + let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]); + + let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]); + + let r1 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap(); + + let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]); + let b = Int32Array::new( + vec![10, 11, 12, 13, 14].into(), + Some(vec![true, true, true, true, true].into()), + ); + + // unwrap here means that no copying occured + let r2 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap(); + assert_eq!(r1.unwrap(), r2.unwrap()); + } + + #[test] + fn test_unary_dict_mut() { + let values = Int32Array::from(vec![Some(10), Some(20), None]); + let keys = Int8Array::from_iter_values([0, 0, 1, 2]); + let dictionary = DictionaryArray::new(keys, Arc::new(values)); + + let updated = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap(); + let typed = updated.downcast_dict::().unwrap(); + assert_eq!(typed.value(0), 11); + assert_eq!(typed.value(1), 11); + assert_eq!(typed.value(2), 21); + + let values = updated.values(); + assert!(values.is_null(2)); + } +} diff --git a/arrow-arith/src/bitwise.rs b/arrow-arith/src/bitwise.rs new file mode 100644 index 000000000000..a3c18136c5eb --- /dev/null +++ b/arrow-arith/src/bitwise.rs @@ -0,0 +1,392 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or 
more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module contains bitwise operations on arrays + +use crate::arity::{binary, unary}; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::ArrowError; +use num::traits::{WrappingShl, WrappingShr}; +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +/// The helper function for bitwise operation with two array +fn bitwise_op( + left: &PrimitiveArray, + right: &PrimitiveArray, + op: F, +) -> Result, ArrowError> +where + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> T::Native, +{ + binary(left, right, op) +} + +/// Perform `left & right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_and( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitAnd, +{ + bitwise_op(left, right, |a, b| a & b) +} + +/// Perform `left | right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_or( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitOr, +{ + bitwise_op(left, right, |a, b| a | b) +} + +/// Perform `left ^ right` operation on two arrays. 
If either left or right value is null +/// then the result is also null. +pub fn bitwise_xor( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitXor, +{ + bitwise_op(left, right, |a, b| a ^ b) +} + +/// Perform bitwise `left << right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_shift_left( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShl, +{ + bitwise_op(left, right, |a, b| { + let b = b.as_usize(); + a.wrapping_shl(b as u32) + }) +} + +/// Perform bitwise `left >> right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_shift_right( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShr, +{ + bitwise_op(left, right, |a, b| { + let b = b.as_usize(); + a.wrapping_shr(b as u32) + }) +} + +/// Perform `!array` operation on array. If array value is null +/// then the result is also null. +pub fn bitwise_not(array: &PrimitiveArray) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: Not, +{ + Ok(unary(array, |value| !value)) +} + +/// Perform `left & !right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_and_not( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitAnd, + T::Native: Not, +{ + bitwise_op(left, right, |a, b| a & !b) +} + +/// Perform bitwise `and` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. 
+pub fn bitwise_and_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitAnd, +{ + Ok(unary(array, |value| value & scalar)) +} + +/// Perform bitwise `or` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_or_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitOr, +{ + Ok(unary(array, |value| value | scalar)) +} + +/// Perform bitwise `xor` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_xor_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitXor, +{ + Ok(unary(array, |value| value ^ scalar)) +} + +/// Perform bitwise `left << right` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_shift_left_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShl, +{ + Ok(unary(array, |value| { + let scalar = scalar.as_usize(); + value.wrapping_shl(scalar as u32) + })) +} + +/// Perform bitwise `left >> right` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. 
+pub fn bitwise_shift_right_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShr, +{ + Ok(unary(array, |value| { + let scalar = scalar.as_usize(); + value.wrapping_shr(scalar as u32) + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bitwise_and_array() -> Result<(), ArrowError> { + // unsigned value + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12)]); + let expected = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let result = bitwise_and(&left, &right)?; + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = Int32Array::from(vec![Some(5), Some(-10), Some(8), Some(12)]); + let expected = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let result = bitwise_and(&left, &right)?; + assert_eq!(expected, result); + Ok(()) + } + + #[test] + fn test_bitwise_shift_left() { + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(8)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), Some(u64::MAX)]); + let expected = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(0)]); + let result = bitwise_shift_left(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_shift_left_scalar() { + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(8)]); + let scalar = 2; + let expected = UInt64Array::from(vec![Some(4), Some(8), None, Some(16), Some(32)]); + let result = bitwise_shift_left_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_shift_right() { + let left = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), 
Some(65)]); + let expected = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(1)]); + let result = bitwise_shift_right(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_shift_right_scalar() { + let left = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); + let scalar = 2; + let expected = UInt64Array::from(vec![Some(8), Some(512), None, Some(4096), Some(0)]); + let result = bitwise_shift_right_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_and_array_scalar() { + // unsigned value + let left = UInt64Array::from(vec![Some(15), Some(2), None, Some(4)]); + let scalar = 7; + let expected = UInt64Array::from(vec![Some(7), Some(2), None, Some(4)]); + let result = bitwise_and_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let scalar = -20; + let expected = Int32Array::from(vec![Some(0), Some(0), None, Some(4)]); + let result = bitwise_and_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_or_array() { + // unsigned value + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = UInt64Array::from(vec![Some(7), Some(5), Some(8), Some(13)]); + let expected = UInt64Array::from(vec![Some(7), Some(7), None, Some(13)]); + let result = bitwise_or(&left, &right).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = Int32Array::from(vec![Some(-7), Some(-5), Some(8), Some(13)]); + let expected = Int32Array::from(vec![Some(-7), Some(-5), None, Some(13)]); + let result = bitwise_or(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_not_array() { + // unsigned value + let array = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let expected 
= UInt64Array::from(vec![ + Some(18446744073709551614), + Some(18446744073709551613), + None, + Some(18446744073709551611), + ]); + let result = bitwise_not(&array).unwrap(); + assert_eq!(expected, result); + // signed value + let array = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let expected = Int32Array::from(vec![Some(-2), Some(-3), None, Some(-5)]); + let result = bitwise_not(&array).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_and_not_array() { + // unsigned value + let left = UInt64Array::from(vec![Some(8), Some(2), None, Some(4)]); + let right = UInt64Array::from(vec![Some(7), Some(5), Some(8), Some(13)]); + let expected = UInt64Array::from(vec![Some(8), Some(2), None, Some(0)]); + let result = bitwise_and_not(&left, &right).unwrap(); + assert_eq!(expected, result); + assert_eq!( + bitwise_and(&left, &bitwise_not(&right).unwrap()).unwrap(), + result + ); + + // signed value + let left = Int32Array::from(vec![Some(2), Some(1), None, Some(3)]); + let right = Int32Array::from(vec![Some(-7), Some(-5), Some(8), Some(13)]); + let expected = Int32Array::from(vec![Some(2), Some(0), None, Some(2)]); + let result = bitwise_and_not(&left, &right).unwrap(); + assert_eq!(expected, result); + assert_eq!( + bitwise_and(&left, &bitwise_not(&right).unwrap()).unwrap(), + result + ); + } + + #[test] + fn test_bitwise_or_array_scalar() { + // unsigned value + let left = UInt64Array::from(vec![Some(15), Some(2), None, Some(4)]); + let scalar = 7; + let expected = UInt64Array::from(vec![Some(15), Some(7), None, Some(7)]); + let result = bitwise_or_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let scalar = 20; + let expected = Int32Array::from(vec![Some(21), Some(22), None, Some(20)]); + let result = bitwise_or_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_xor_array() { + // 
unsigned value + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = UInt64Array::from(vec![Some(7), Some(5), Some(8), Some(13)]); + let expected = UInt64Array::from(vec![Some(6), Some(7), None, Some(9)]); + let result = bitwise_xor(&left, &right).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = Int32Array::from(vec![Some(-7), Some(5), Some(8), Some(-13)]); + let expected = Int32Array::from(vec![Some(-8), Some(7), None, Some(-9)]); + let result = bitwise_xor(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_xor_array_scalar() { + // unsigned value + let left = UInt64Array::from(vec![Some(15), Some(2), None, Some(4)]); + let scalar = 7; + let expected = UInt64Array::from(vec![Some(8), Some(5), None, Some(3)]); + let result = bitwise_xor_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let scalar = -20; + let expected = Int32Array::from(vec![Some(-19), Some(-18), None, Some(-24)]); + let result = bitwise_xor_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } +} diff --git a/arrow/src/compute/kernels/boolean.rs b/arrow-arith/src/boolean.rs similarity index 56% rename from arrow/src/compute/kernels/boolean.rs rename to arrow-arith/src/boolean.rs index c51953a7540c..ea8e12abbe2c 100644 --- a/arrow/src/compute/kernels/boolean.rs +++ b/arrow-arith/src/boolean.rs @@ -22,36 +22,52 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. 
-use std::ops::Not; - -use crate::array::{Array, ArrayData, BooleanArray, PrimitiveArray}; -use crate::buffer::{ - bitwise_bin_op_helper, bitwise_quaternary_op_helper, buffer_bin_and, buffer_bin_or, - buffer_unary_not, Buffer, MutableBuffer, -}; -use crate::compute::util::combine_option_bitmap; -use crate::datatypes::{ArrowNumericType, DataType}; -use crate::error::{ArrowError, Result}; -use crate::util::bit_util::ceil; - -/// Updates null buffer based on data buffer and null buffer of the operand at other side -/// in boolean AND kernel with Kleene logic. In short, because for AND kernel, null AND false -/// results false. So we cannot simply AND two null buffers. This function updates null buffer -/// of one side if other side is a false value. -pub(crate) fn build_null_buffer_for_and_kleene( - left_data: &ArrayData, - left_offset: usize, - right_data: &ArrayData, - right_offset: usize, - len_in_bits: usize, -) -> Option { - let left_buffer = &left_data.buffers()[0]; - let right_buffer = &right_data.buffers()[0]; - - let left_null_buffer = left_data.null_buffer(); - let right_null_buffer = right_data.null_buffer(); - - match (left_null_buffer, right_null_buffer) { +use arrow_array::*; +use arrow_buffer::buffer::{bitwise_bin_op_helper, bitwise_quaternary_op_helper}; +use arrow_buffer::{buffer_bin_and_not, BooleanBuffer, NullBuffer}; +use arrow_schema::ArrowError; + +/// Logical 'and' boolean values with Kleene logic +/// +/// # Behavior +/// +/// This function behaves as follows with nulls: +/// +/// * `true` and `null` = `null` +/// * `null` and `true` = `null` +/// * `false` and `null` = `false` +/// * `null` and `false` = `false` +/// * `null` and `null` = `null` +/// +/// In other words, in this context a null value really means \"unknown\", +/// and an unknown value 'and' false is always false. +/// For a different null behavior, see function \"and\". 
+/// +/// # Example +/// +/// ```rust +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::and_kleene; +/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); +/// let b = BooleanArray::from(vec![None, None, None]); +/// let and_ab = and_kleene(&a, &b).unwrap(); +/// assert_eq!(and_ab, BooleanArray::from(vec![None, Some(false), None])); +/// ``` +/// +/// # Fails +/// +/// If the operands have different lengths +pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform bitwise operation on arrays of different length".to_string(), + )); + } + + let left_values = left.values(); + let right_values = right.values(); + + let buffer = match (left.nulls(), right.nulls()) { (None, None) => None, (Some(left_null_buffer), None) => { // The right side has no null values. @@ -59,22 +75,22 @@ pub(crate) fn build_null_buffer_for_and_kleene( // 1. left null bit is set, or // 2. right data bit is false (because null AND false = false). 
Some(bitwise_bin_op_helper( - left_null_buffer, - left_offset, - right_buffer, - right_offset, - len_in_bits, + left_null_buffer.buffer(), + left_null_buffer.offset(), + right_values.inner(), + right_values.offset(), + left.len(), |a, b| a | !b, )) } (None, Some(right_null_buffer)) => { // Same as above Some(bitwise_bin_op_helper( - right_null_buffer, - right_offset, - left_buffer, - left_offset, - len_in_bits, + right_null_buffer.buffer(), + right_null_buffer.offset(), + left_values.inner(), + left_values.offset(), + left.len(), |a, b| a | !b, )) } @@ -85,109 +101,131 @@ pub(crate) fn build_null_buffer_for_and_kleene( // The final null bits are: // (a | (c & !d)) & (c | (a & !b)) Some(bitwise_quaternary_op_helper( - left_null_buffer, - left_offset, - left_buffer, - left_offset, - right_null_buffer, - right_offset, - right_buffer, - right_offset, - len_in_bits, + [ + left_null_buffer.buffer(), + left_values.inner(), + right_null_buffer.buffer(), + right_values.inner(), + ], + [ + left_null_buffer.offset(), + left_values.offset(), + right_null_buffer.offset(), + right_values.offset(), + ], + left.len(), |a, b, c, d| (a | (c & !d)) & (c | (a & !b)), )) } - } + }; + let nulls = buffer.map(|b| NullBuffer::new(BooleanBuffer::new(b, 0, left.len()))); + Ok(BooleanArray::new(left_values & right_values, nulls)) } -/// For AND/OR kernels, the result of null buffer is simply a bitwise `and` operation. -pub(crate) fn build_null_buffer_for_and_or( - left_data: &ArrayData, - _left_offset: usize, - right_data: &ArrayData, - _right_offset: usize, - len_in_bits: usize, -) -> Option { - // `arrays` are not empty, so safely do `unwrap` directly. 
- combine_option_bitmap(&[left_data, right_data], len_in_bits).unwrap() -} +/// Logical 'or' boolean values with Kleene logic +/// +/// # Behavior +/// +/// This function behaves as follows with nulls: +/// +/// * `true` or `null` = `true` +/// * `null` or `true` = `true` +/// * `false` or `null` = `null` +/// * `null` or `false` = `null` +/// * `null` or `null` = `null` +/// +/// In other words, in this context a null value really means \"unknown\", +/// and an unknown value 'or' true is always true. +/// For a different null behavior, see function \"or\". +/// +/// # Example +/// +/// ```rust +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::or_kleene; +/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); +/// let b = BooleanArray::from(vec![None, None, None]); +/// let or_ab = or_kleene(&a, &b).unwrap(); +/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), None, None])); +/// ``` +/// +/// # Fails +/// +/// If the operands have different lengths +pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform bitwise operation on arrays of different length".to_string(), + )); + } + + let left_values = left.values(); + let right_values = right.values(); -/// Updates null buffer based on data buffer and null buffer of the operand at other side -/// in boolean OR kernel with Kleene logic. In short, because for OR kernel, null OR true -/// results true. So we cannot simply AND two null buffers. This function updates null -/// buffer of one side if other side is a true value. 
-pub(crate) fn build_null_buffer_for_or_kleene( - left_data: &ArrayData, - left_offset: usize, - right_data: &ArrayData, - right_offset: usize, - len_in_bits: usize, -) -> Option { - let left_buffer = &left_data.buffers()[0]; - let right_buffer = &right_data.buffers()[0]; - - let left_null_buffer = left_data.null_buffer(); - let right_null_buffer = right_data.null_buffer(); - - match (left_null_buffer, right_null_buffer) { + let buffer = match (left.nulls(), right.nulls()) { (None, None) => None, - (Some(left_null_buffer), None) => { + (Some(left_nulls), None) => { // The right side has no null values. // The final null bit is set only if: // 1. left null bit is set, or // 2. right data bit is true (because null OR true = true). Some(bitwise_bin_op_helper( - left_null_buffer, - left_offset, - right_buffer, - right_offset, - len_in_bits, + left_nulls.buffer(), + left_nulls.offset(), + right_values.inner(), + right_values.offset(), + left.len(), |a, b| a | b, )) } - (None, Some(right_null_buffer)) => { + (None, Some(right_nulls)) => { // Same as above Some(bitwise_bin_op_helper( - right_null_buffer, - right_offset, - left_buffer, - left_offset, - len_in_bits, + right_nulls.buffer(), + right_nulls.offset(), + left_values.inner(), + left_values.offset(), + left.len(), |a, b| a | b, )) } - (Some(left_null_buffer), Some(right_null_buffer)) => { + (Some(left_nulls), Some(right_nulls)) => { // Follow the same logic above. Both sides have null values. // Assume a is left null bits, b is left data bits, c is right null bits, // d is right data bits. 
// The final null bits are: // (a | (c & d)) & (c | (a & b)) Some(bitwise_quaternary_op_helper( - left_null_buffer, - left_offset, - left_buffer, - left_offset, - right_null_buffer, - right_offset, - right_buffer, - right_offset, - len_in_bits, + [ + left_nulls.buffer(), + left_values.inner(), + right_nulls.buffer(), + right_values.inner(), + ], + [ + left_nulls.offset(), + left_values.offset(), + right_nulls.offset(), + right_values.offset(), + ], + left.len(), |a, b, c, d| (a | (c & d)) & (c | (a & b)), )) } - } + }; + + let nulls = buffer.map(|b| NullBuffer::new(BooleanBuffer::new(b, 0, left.len()))); + Ok(BooleanArray::new(left_values | right_values, nulls)) } /// Helper function to implement binary kernels -pub(crate) fn binary_boolean_kernel( +pub(crate) fn binary_boolean_kernel( left: &BooleanArray, right: &BooleanArray, op: F, - null_op: U, -) -> Result +) -> Result where - F: Fn(&Buffer, usize, &Buffer, usize, usize) -> Buffer, - U: Fn(&ArrayData, usize, &ArrayData, usize, usize) -> Option, + F: Fn(&BooleanBuffer, &BooleanBuffer) -> BooleanBuffer, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -195,32 +233,9 @@ where )); } - let len = left.len(); - - let left_data = left.data_ref(); - let right_data = right.data_ref(); - - let left_buffer = &left_data.buffers()[0]; - let right_buffer = &right_data.buffers()[0]; - let left_offset = left.offset(); - let right_offset = right.offset(); - - let null_bit_buffer = null_op(left_data, left_offset, right_data, right_offset, len); - - let values = op(left_buffer, left_offset, right_buffer, right_offset, len); - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - null_bit_buffer, - 0, - vec![values], - vec![], - ) - }; - Ok(BooleanArray::from(data)) + let nulls = NullBuffer::union(left.nulls(), right.nulls()); + let values = op(left.values(), right.values()); + Ok(BooleanArray::new(values, nulls)) } /// Performs `AND` operation on two arrays. 
If either left or right value is null then the @@ -229,62 +244,15 @@ where /// This function errors when the arrays have different lengths. /// # Example /// ```rust -/// use arrow::array::BooleanArray; -/// use arrow::error::Result; -/// use arrow::compute::kernels::boolean::and; -/// # fn main() -> Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::and; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); /// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); -/// let and_ab = and(&a, &b)?; +/// let and_ab = and(&a, &b).unwrap(); /// assert_eq!(and_ab, BooleanArray::from(vec![Some(false), Some(true), None])); -/// # Ok(()) -/// # } -/// ``` -pub fn and(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel(left, right, buffer_bin_and, build_null_buffer_for_and_or) -} - -/// Logical 'and' boolean values with Kleene logic -/// -/// # Behavior -/// -/// This function behaves as follows with nulls: -/// -/// * `true` and `null` = `null` -/// * `null` and `true` = `null` -/// * `false` and `null` = `false` -/// * `null` and `false` = `false` -/// * `null` and `null` = `null` -/// -/// In other words, in this context a null value really means \"unknown\", -/// and an unknown value 'and' false is always false. -/// For a different null behavior, see function \"and\". 
-/// -/// # Example -/// -/// ```rust -/// use arrow::array::BooleanArray; -/// use arrow::error::Result; -/// use arrow::compute::kernels::boolean::and_kleene; -/// # fn main() -> Result<()> { -/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); -/// let b = BooleanArray::from(vec![None, None, None]); -/// let and_ab = and_kleene(&a, &b)?; -/// assert_eq!(and_ab, BooleanArray::from(vec![None, Some(false), None])); -/// # Ok(()) -/// # } /// ``` -/// -/// # Fails -/// -/// If the operands have different lengths -pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel( - left, - right, - buffer_bin_and, - build_null_buffer_for_and_kleene, - ) +pub fn and(left: &BooleanArray, right: &BooleanArray) -> Result { + binary_boolean_kernel(left, right, |a, b| a & b) } /// Performs `OR` operation on two arrays. If either left or right value is null then the @@ -293,57 +261,36 @@ pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::or; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); /// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); -/// let or_ab = or(&a, &b)?; +/// let or_ab = or(&a, &b).unwrap(); /// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None])); -/// # Ok(()) -/// # } /// ``` -pub fn or(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel(left, right, buffer_bin_or, build_null_buffer_for_and_or) +pub fn or(left: &BooleanArray, right: &BooleanArray) -> Result { + binary_boolean_kernel(left, right, |a, b| a | b) } -/// Logical 'or' boolean values with Kleene logic -/// -/// # Behavior -/// -/// This function behaves as follows with nulls: -/// -/// * `true` or `null` = `true` -/// * `null` or `true` = `true` -/// * `false` or `null` = `null` -/// * `null` or `false` = `null` -/// * `null` or `null` = `null` -/// -/// In 
other words, in this context a null value really means \"unknown\", -/// and an unknown value 'or' true is always true. -/// For a different null behavior, see function \"or\". -/// +/// Performs `AND_NOT` operation on two arrays. If either left or right value is null then the +/// result is also null. +/// # Error +/// This function errors when the arrays have different lengths. /// # Example -/// /// ```rust -/// use arrow::array::BooleanArray; -/// use arrow::error::Result; -/// use arrow::compute::kernels::boolean::or_kleene; -/// # fn main() -> Result<()> { -/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); -/// let b = BooleanArray::from(vec![None, None, None]); -/// let or_ab = or_kleene(&a, &b)?; -/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), None, None])); -/// # Ok(()) -/// # } -/// ``` -/// -/// # Fails -/// -/// If the operands have different lengths -pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel(left, right, buffer_bin_or, build_null_buffer_for_or_kleene) +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::{and, not, and_not}; +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); +/// let andn_ab = and_not(&a, &b).unwrap(); +/// assert_eq!(andn_ab, BooleanArray::from(vec![Some(false), Some(false), None])); +/// // It's equal to and(left, not(right)) +/// assert_eq!(andn_ab, and(&a, &not(&b).unwrap()).unwrap()); +pub fn and_not(left: &BooleanArray, right: &BooleanArray) -> Result { + binary_boolean_kernel(left, right, |a, b| { + let buffer = buffer_bin_and_not(a.inner(), b.offset(), b.inner(), a.offset(), a.len()); + BooleanBuffer::new(buffer, left.offset(), left.len()) + }) } /// Performs unary `NOT` operation on an arrays.
If value is null then the result is also @@ -352,40 +299,16 @@ pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::not; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); -/// let not_a = not(&a)?; +/// let not_a = not(&a).unwrap(); /// assert_eq!(not_a, BooleanArray::from(vec![Some(true), Some(false), None])); -/// # Ok(()) -/// # } /// ``` -pub fn not(left: &BooleanArray) -> Result { - let left_offset = left.offset(); - let len = left.len(); - - let data = left.data_ref(); - let null_bit_buffer = data - .null_bitmap() - .as_ref() - .map(|b| b.bits.bit_slice(left_offset, len)); - - let values = buffer_unary_not(&data.buffers()[0], left_offset, len); - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - null_bit_buffer, - 0, - vec![values], - vec![], - ) - }; - Ok(BooleanArray::from(data)) +pub fn not(left: &BooleanArray) -> Result { + let nulls = left.nulls().cloned(); + let values = !left.values(); + Ok(BooleanArray::new(values, nulls)) } /// Returns a non-null [BooleanArray] with whether each value of the array is null. @@ -393,40 +316,19 @@ pub fn not(left: &BooleanArray) -> Result { /// This function never errors. 
/// # Example /// ```rust -/// # use arrow::error::Result; -/// use arrow::array::BooleanArray; -/// use arrow::compute::kernels::boolean::is_null; -/// # fn main() -> Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::is_null; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); -/// let a_is_null = is_null(&a)?; +/// let a_is_null = is_null(&a).unwrap(); /// assert_eq!(a_is_null, BooleanArray::from(vec![false, false, true])); -/// # Ok(()) -/// # } /// ``` -pub fn is_null(input: &dyn Array) -> Result { - let len = input.len(); - - let output = match input.data_ref().null_buffer() { - None => { - let len_bytes = ceil(len, 8); - MutableBuffer::from_len_zeroed(len_bytes).into() - } - Some(buffer) => buffer_unary_not(buffer, input.offset(), len), +pub fn is_null(input: &dyn Array) -> Result { + let values = match input.logical_nulls() { + None => BooleanBuffer::new_unset(input.len()), + Some(nulls) => !nulls.inner(), }; - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - None, - 0, - vec![output], - vec![], - ) - }; - - Ok(BooleanArray::from(data)) + Ok(BooleanArray::new(values, None)) } /// Returns a non-null [BooleanArray] with whether each value of the array is not null. @@ -434,141 +336,23 @@ pub fn is_null(input: &dyn Array) -> Result { /// This function never errors. 
/// # Example /// ```rust -/// # use arrow::error::Result; -/// use arrow::array::BooleanArray; -/// use arrow::compute::kernels::boolean::is_not_null; -/// # fn main() -> Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::is_not_null; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); -/// let a_is_not_null = is_not_null(&a)?; +/// let a_is_not_null = is_not_null(&a).unwrap(); /// assert_eq!(a_is_not_null, BooleanArray::from(vec![true, true, false])); -/// # Ok(()) -/// # } /// ``` -pub fn is_not_null(input: &dyn Array) -> Result { - let len = input.len(); - - let output = match input.data_ref().null_buffer() { - None => { - let len_bytes = ceil(len, 8); - MutableBuffer::new(len_bytes) - .with_bitset(len_bytes, true) - .into() - } - Some(buffer) => buffer.bit_slice(input.offset(), len), - }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - None, - 0, - vec![output], - vec![], - ) +pub fn is_not_null(input: &dyn Array) -> Result { + let values = match input.logical_nulls() { + None => BooleanBuffer::new_set(input.len()), + Some(n) => n.inner().clone(), }; - - Ok(BooleanArray::from(data)) -} - -/// Copies original array, setting null bit to true if a secondary comparison boolean array is set to true. -/// Typically used to implement NULLIF. -// NOTE: For now this only supports Primitive Arrays. Although the code could be made generic, the issue -// is that currently the bitmap operations result in a final bitmap which is aligned to bit 0, and thus -// the left array's data needs to be sliced to a new offset, and for non-primitive arrays shifting the -// data might be too complicated. In the future, to avoid shifting left array's data, we could instead -// shift the final bitbuffer to the right, prepending with 0's instead. 
-pub fn nullif( - left: &PrimitiveArray, - right: &BooleanArray, -) -> Result> -where - T: ArrowNumericType, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - let left_data = left.data(); - let right_data = right.data(); - - // If left has no bitmap, create a new one with all values set for nullity op later - // left=0 (null) right=null output bitmap=null - // left=0 right=1 output bitmap=null - // left=1 (set) right=null output bitmap=set (passthrough) - // left=1 right=1 & comp=true output bitmap=null - // left=1 right=1 & comp=false output bitmap=set - // - // Thus: result = left null bitmap & (!right_values | !right_bitmap) - // OR left null bitmap & !(right_values & right_bitmap) - // - // Do the right expression !(right_values & right_bitmap) first since there are two steps - // TRICK: convert BooleanArray buffer as a bitmap for faster operation - let right_combo_buffer = match right.data().null_bitmap() { - Some(right_bitmap) => { - // NOTE: right values and bitmaps are combined and stay at bit offset right.offset() - (right.values() & &right_bitmap.bits).ok().map(|b| b.not()) - } - None => Some(!right.values()), - }; - - // AND of original left null bitmap with right expression - // Here we take care of the possible offsets of the left and right arrays all at once. 
- let modified_null_buffer = match left_data.null_bitmap() { - Some(left_null_bitmap) => match right_combo_buffer { - Some(rcb) => Some(buffer_bin_and( - &left_null_bitmap.bits, - left_data.offset(), - &rcb, - right_data.offset(), - left_data.len(), - )), - None => Some( - left_null_bitmap - .bits - .bit_slice(left_data.offset(), left.len()), - ), - }, - None => right_combo_buffer - .map(|rcb| rcb.bit_slice(right_data.offset(), right_data.len())), - }; - - // Align/shift left data on offset as needed, since new bitmaps are shifted and aligned to 0 already - // NOTE: this probably only works for primitive arrays. - let data_buffers = if left.offset() == 0 { - left_data.buffers().to_vec() - } else { - // Shift each data buffer by type's bit_width * offset. - left_data - .buffers() - .iter() - .map(|buf| buf.slice(left.offset() * T::get_byte_width())) - .collect::>() - }; - - // Construct new array with same values but modified null bitmap - // TODO: shift data buffer as needed - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - left.len(), - None, // force new to compute the number of null bits - modified_null_buffer, - 0, // No need for offset since left data has been shifted - data_buffers, - left_data.child_data().to_vec(), - ) - }; - Ok(PrimitiveArray::::from(data)) + Ok(BooleanArray::new(values, None)) } #[cfg(test)] mod tests { use super::*; - use crate::array::{ArrayRef, Int32Array}; use std::sync::Arc; #[test] @@ -593,6 +377,18 @@ mod tests { assert_eq!(c, expected); } + #[test] + fn test_bool_array_and_not() { + let a = BooleanArray::from(vec![false, false, true, true]); + let b = BooleanArray::from(vec![false, true, false, true]); + let c = and_not(&a, &b).unwrap(); + + let expected = BooleanArray::from(vec![false, false, true, false]); + + assert_eq!(c, expected); + assert_eq!(c, and(&a, &not(&b).unwrap()).unwrap()); + } + #[test] fn test_bool_array_or_nulls() { let a = BooleanArray::from(vec![ @@ -731,7 +527,7 @@ mod tests { let a =
BooleanArray::from(vec![false, false, false, true, true, true]); // ensure null bitmap of a is absent - assert!(a.data_ref().null_bitmap().is_none()); + assert!(a.nulls().is_none()); let b = BooleanArray::from(vec![ Some(true), @@ -743,7 +539,7 @@ mod tests { ]); // ensure null bitmap of b is present - assert!(b.data_ref().null_bitmap().is_some()); + assert!(b.nulls().is_some()); let c = or_kleene(&a, &b).unwrap(); @@ -771,12 +567,12 @@ mod tests { ]); // ensure null bitmap of b is absent - assert!(a.data_ref().null_bitmap().is_some()); + assert!(a.nulls().is_some()); let b = BooleanArray::from(vec![false, false, false, true, true, true]); // ensure null bitmap of a is present - assert!(b.data_ref().null_bitmap().is_none()); + assert!(b.nulls().is_none()); let c = or_kleene(&a, &b).unwrap(); @@ -809,8 +605,7 @@ mod tests { let a = a.as_any().downcast_ref::().unwrap(); let c = not(a).unwrap(); - let expected = - BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); + let expected = BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); assert_eq!(c, expected); } @@ -859,12 +654,10 @@ mod tests { #[test] fn test_bool_array_and_sliced_same_offset() { let a = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, false, true, - true, + false, false, false, false, false, false, false, false, false, false, true, true, ]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let a = a.slice(8, 4); @@ -882,12 +675,10 @@ mod tests { #[test] fn test_bool_array_and_sliced_same_offset_mod8() { let a = BooleanArray::from(vec![ - false, false, true, true, false, false, false, false, false, false, false, - false, + false, false, true, true, false, false, false, false, false, false, false, false, ]); let b = BooleanArray::from(vec![ - false, false, false, false, 
false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let a = a.slice(0, 4); @@ -905,8 +696,7 @@ mod tests { #[test] fn test_bool_array_and_sliced_offset1() { let a = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, false, true, - true, + false, false, false, false, false, false, false, false, false, false, true, true, ]); let b = BooleanArray::from(vec![false, true, false, true]); @@ -924,8 +714,7 @@ mod tests { fn test_bool_array_and_sliced_offset2() { let a = BooleanArray::from(vec![false, false, true, true]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let b = b.slice(8, 4); @@ -958,8 +747,7 @@ mod tests { let c = and(a, b).unwrap(); - let expected = - BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); + let expected = BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); assert_eq!(expected, c); } @@ -973,7 +761,7 @@ mod tests { let expected = BooleanArray::from(vec![false, false, false, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -981,12 +769,12 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); let a = a.slice(8, 4); - let res = is_null(a.as_ref()).unwrap(); + let res = is_null(&a).unwrap(); let expected = BooleanArray::from(vec![false, false, false, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -998,7 +786,7 @@ mod tests { let expected = BooleanArray::from(vec![true, true, true, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1006,12 
+794,12 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); let a = a.slice(8, 4); - let res = is_not_null(a.as_ref()).unwrap(); + let res = is_not_null(&a).unwrap(); let expected = BooleanArray::from(vec![true, true, true, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1023,7 +811,7 @@ mod tests { let expected = BooleanArray::from(vec![false, true, false, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1049,12 +837,12 @@ mod tests { ]); let a = a.slice(8, 4); - let res = is_null(a.as_ref()).unwrap(); + let res = is_null(&a).unwrap(); let expected = BooleanArray::from(vec![false, true, false, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1066,7 +854,7 @@ mod tests { let expected = BooleanArray::from(vec![true, false, true, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1092,57 +880,35 @@ mod tests { ]); let a = a.slice(8, 4); - let res = is_not_null(a.as_ref()).unwrap(); + let res = is_not_null(&a).unwrap(); let expected = BooleanArray::from(vec![true, false, true, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] - fn test_nullif_int_array() { - let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9)]); - let comp = - BooleanArray::from(vec![Some(false), None, Some(true), Some(false), None]); - let res = nullif(&a, &comp).unwrap(); + fn test_null_array_is_null() { + let a = NullArray::new(3); - let expected = Int32Array::from(vec![ - Some(15), - None, - None, // comp true, slot 2 turned into null - Some(1), - // Even though comp array / right is null, should still pass through original value 
- // comp true, slot 2 turned into null - Some(9), - ]); + let res = is_null(&a).unwrap(); + + let expected = BooleanArray::from(vec![true, true, true]); assert_eq!(expected, res); + assert!(res.nulls().is_none()); } #[test] - fn test_nullif_int_array_offset() { - let a = Int32Array::from(vec![None, Some(15), Some(8), Some(1), Some(9)]); - let a = a.slice(1, 3); // Some(15), Some(8), Some(1) - let a = a.as_any().downcast_ref::().unwrap(); - let comp = BooleanArray::from(vec![ - Some(false), - Some(false), - Some(false), - None, - Some(true), - Some(false), - None, - ]); - let comp = comp.slice(2, 3); // Some(false), None, Some(true) - let comp = comp.as_any().downcast_ref::().unwrap(); - let res = nullif(a, comp).unwrap(); - - let expected = Int32Array::from(vec![ - Some(15), // False => keep it - Some(8), // None => keep it - None, // true => None - ]); - assert_eq!(&expected, &res) + fn test_null_array_is_not_null() { + let a = NullArray::new(3); + + let res = is_not_null(&a).unwrap(); + + let expected = BooleanArray::from(vec![false, false, false]); + + assert_eq!(expected, res); + assert!(res.nulls().is_none()); } } diff --git a/arrow/src/ipc/compression/mod.rs b/arrow-arith/src/lib.rs similarity index 75% rename from arrow/src/ipc/compression/mod.rs rename to arrow-arith/src/lib.rs index 666fa6d86a27..c8b6412e5efc 100644 --- a/arrow/src/ipc/compression/mod.rs +++ b/arrow-arith/src/lib.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -#[cfg(feature = "ipc_compression")] -mod codec; -#[cfg(feature = "ipc_compression")] -pub(crate) use codec::CompressionCodec; +//! 
Arrow arithmetic and aggregation kernels -#[cfg(not(feature = "ipc_compression"))] -mod stub; -#[cfg(not(feature = "ipc_compression"))] -pub(crate) use stub::CompressionCodec; +#![warn(missing_docs)] +pub mod aggregate; +#[doc(hidden)] // Kernels to be removed in a future release +pub mod arithmetic; +pub mod arity; +pub mod bitwise; +pub mod boolean; +pub mod numeric; +pub mod temporal; diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs new file mode 100644 index 000000000000..b6af40f7d7c2 --- /dev/null +++ b/arrow-arith/src/numeric.rs @@ -0,0 +1,1523 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines numeric arithmetic kernels on [`PrimitiveArray`], such as [`add`] + +use std::cmp::Ordering; +use std::fmt::Formatter; +use std::sync::Arc; + +use arrow_array::cast::AsArray; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::{ArrowNativeType, IntervalDayTime, IntervalMonthDayNano}; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; + +use crate::arity::{binary, try_binary}; + +/// Perform `lhs + rhs`, returning an error on overflow +pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Add, lhs, rhs) +} + +/// Perform `lhs + rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::AddWrapping, lhs, rhs) +} + +/// Perform `lhs - rhs`, returning an error on overflow +pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Sub, lhs, rhs) +} + +/// Perform `lhs - rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::SubWrapping, lhs, rhs) +} + +/// Perform `lhs * rhs`, returning an error on overflow +pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Mul, lhs, rhs) +} + +/// Perform `lhs * rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::MulWrapping, lhs, rhs) +} + +/// Perform `lhs / rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Div, lhs, rhs) +} + +/// Perform `lhs % rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { 
+ arithmetic_op(Op::Rem, lhs, rhs) +} + +macro_rules! neg_checked { + ($t:ty, $a:ident) => {{ + let array = $a + .as_primitive::<$t>() + .try_unary::<_, $t, _>(|x| x.neg_checked())?; + Ok(Arc::new(array)) + }}; +} + +macro_rules! neg_wrapping { + ($t:ty, $a:ident) => {{ + let array = $a.as_primitive::<$t>().unary::<_, $t>(|x| x.neg_wrapping()); + Ok(Arc::new(array)) + }}; +} + +/// Negates each element of `array`, returning an error on overflow +/// +/// Note: negation of unsigned arrays is not supported and will return in an error, +/// for wrapping unsigned negation consider using [`neg_wrapping`][neg_wrapping()] +pub fn neg(array: &dyn Array) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + + match array.data_type() { + Int8 => neg_checked!(Int8Type, array), + Int16 => neg_checked!(Int16Type, array), + Int32 => neg_checked!(Int32Type, array), + Int64 => neg_checked!(Int64Type, array), + Float16 => neg_wrapping!(Float16Type, array), + Float32 => neg_wrapping!(Float32Type, array), + Float64 => neg_wrapping!(Float64Type, array), + Decimal128(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } + Decimal256(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } + Duration(Second) => neg_checked!(DurationSecondType, array), + Duration(Millisecond) => neg_checked!(DurationMillisecondType, array), + Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array), + Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array), + Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array), + Interval(DayTime) => { + let a = array + .as_primitive::() + .try_unary::<_, IntervalDayTimeType, ArrowError>(|x| { + let (days, ms) = IntervalDayTimeType::to_parts(x); + Ok(IntervalDayTimeType::make_value( + 
days.neg_checked()?, + ms.neg_checked()?, + )) + })?; + Ok(Arc::new(a)) + } + Interval(MonthDayNano) => { + let a = array + .as_primitive::() + .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x); + Ok(IntervalMonthDayNanoType::make_value( + months.neg_checked()?, + days.neg_checked()?, + nanos.neg_checked()?, + )) + })?; + Ok(Arc::new(a)) + } + t => Err(ArrowError::InvalidArgumentError(format!( + "Invalid arithmetic operation: !{t}" + ))), + } +} + +/// Negates each element of `array`, wrapping on overflow for [`DataType::is_integer`] +pub fn neg_wrapping(array: &dyn Array) -> Result { + downcast_integer! { + array.data_type() => (neg_wrapping, array), + _ => neg(array), + } +} + +/// An enumeration of arithmetic operations +/// +/// This allows sharing the type dispatch logic across the various kernels +#[derive(Debug, Copy, Clone)] +enum Op { + AddWrapping, + Add, + SubWrapping, + Sub, + MulWrapping, + Mul, + Div, + Rem, +} + +impl std::fmt::Display for Op { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Op::AddWrapping | Op::Add => write!(f, "+"), + Op::SubWrapping | Op::Sub => write!(f, "-"), + Op::MulWrapping | Op::Mul => write!(f, "*"), + Op::Div => write!(f, "/"), + Op::Rem => write!(f, "%"), + } + } +} + +impl Op { + fn commutative(&self) -> bool { + matches!(self, Self::Add | Self::AddWrapping) + } +} + +/// Dispatch the given `op` to the appropriate specialized kernel +fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + + macro_rules! integer_helper { + ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident, $r_scalar:ident) => { + integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar) + }; + } + + let (l, l_scalar) = lhs.get(); + let (r, r_scalar) = rhs.get(); + downcast_integer! 
{ + l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r, r_scalar), + (Float16, Float16) => float_op::(op, l, l_scalar, r, r_scalar), + (Float32, Float32) => float_op::(op, l, l_scalar, r, r_scalar), + (Float64, Float64) => float_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Second, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Millisecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Microsecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Nanosecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Duration(Second), Duration(Second)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Millisecond), Duration(Millisecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Microsecond), Duration(Microsecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Nanosecond), Duration(Nanosecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Interval(YearMonth), Interval(YearMonth)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(DayTime), Interval(DayTime)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Date32, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Date64, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Decimal128(_, _), Decimal128(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (Decimal256(_, _), Decimal256(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (l_t, r_t) => match (l_t, r_t) { + (Duration(_) | Interval(_), Date32 | Date64 | Timestamp(_, _)) if op.commutative() => { + arithmetic_op(op, rhs, lhs) + } + _ => Err(ArrowError::InvalidArgumentError( + format!("Invalid arithmetic operation: {l_t} {op} {r_t}") + )) + } + } +} + +/// Perform an infallible binary operation on potentially scalar inputs +macro_rules! 
op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.unary(|$r| $op), + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.unary(|$l| $op), + }, + } + }; +} + +/// Same as `op` but with a type hint for the returned array +macro_rules! op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform a fallible binary operation on potentially scalar inputs +macro_rules! try_op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => try_binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.try_unary(|$r| $op)?, + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.try_unary(|$l| $op)?, + }, + } + }; +} + +/// Same as `try_op` but with a type hint for the returned array +macro_rules! 
try_op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = try_op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform an arithmetic operation on integers +fn integer_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)), + Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)), + Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)), + Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)), + Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)), + }; + Ok(Arc::new(array)) +} + +/// Perform an arithmetic operation on floats +fn float_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)), + Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)), + }; + Ok(Arc::new(array)) +} + +/// Arithmetic trait for timestamp arrays +trait TimestampOp: ArrowTimestampType { + type Duration: ArrowPrimitiveType; + + fn add_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option; + fn add_day_time(timestamp: i64, delta: IntervalDayTime, tz: Tz) -> Option; + fn add_month_day_nano(timestamp: i64, delta: IntervalMonthDayNano, tz: Tz) -> Option; + + fn sub_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option; + fn sub_day_time(timestamp: i64, delta: 
IntervalDayTime, tz: Tz) -> Option; + fn sub_month_day_nano(timestamp: i64, delta: IntervalMonthDayNano, tz: Tz) -> Option; +} + +macro_rules! timestamp { + ($t:ty, $d:ty) => { + impl TimestampOp for $t { + type Duration = $d; + + fn add_year_month(left: i64, right: i32, tz: Tz) -> Option { + Self::add_year_months(left, right, tz) + } + + fn add_day_time(left: i64, right: IntervalDayTime, tz: Tz) -> Option { + Self::add_day_time(left, right, tz) + } + + fn add_month_day_nano(left: i64, right: IntervalMonthDayNano, tz: Tz) -> Option { + Self::add_month_day_nano(left, right, tz) + } + + fn sub_year_month(left: i64, right: i32, tz: Tz) -> Option { + Self::subtract_year_months(left, right, tz) + } + + fn sub_day_time(left: i64, right: IntervalDayTime, tz: Tz) -> Option { + Self::subtract_day_time(left, right, tz) + } + + fn sub_month_day_nano(left: i64, right: IntervalMonthDayNano, tz: Tz) -> Option { + Self::subtract_month_day_nano(left, right, tz) + } + } + }; +} +timestamp!(TimestampSecondType, DurationSecondType); +timestamp!(TimestampMillisecondType, DurationMillisecondType); +timestamp!(TimestampMicrosecondType, DurationMicrosecondType); +timestamp!(TimestampNanosecondType, DurationNanosecondType); + +/// Perform arithmetic operation on a timestamp array +fn timestamp_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + let l = l.as_primitive::(); + let l_tz: Tz = l.timezone().unwrap_or("+00:00").parse()?; + + let array: PrimitiveArray = match (op, r.data_type()) { + (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + return Ok(try_op_ref!(T::Duration, l, l_s, r, r_s, l.sub_checked(r))); + } + + (Op::Add | Op::AddWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.add_checked(r)) + } + (Op::Sub | Op::SubWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = 
r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.sub_checked(r)) + } + + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid timestamp arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))) + } + }; + Ok(Arc::new(array.with_timezone_opt(l.timezone()))) +} + +/// Arithmetic trait for date arrays +/// +/// Note: these should be fallible (#4456) +trait DateOp: ArrowTemporalType { + fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn add_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> 
Self::Native; + fn add_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native; + + fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn sub_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> Self::Native; + fn sub_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native; +} + +macro_rules! date { + ($t:ty) => { + impl DateOp for $t { + fn add_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::add_year_months(left, right) + } + + fn add_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native { + Self::add_day_time(left, right) + } + + fn add_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native { + Self::add_month_day_nano(left, right) + } + + fn sub_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::subtract_year_months(left, right) + } + + fn sub_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native { + Self::subtract_day_time(left, right) + } + + fn sub_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native { + Self::subtract_month_day_nano(left, right) + } + } + }; +} +date!(Date32Type); +date!(Date64Type); + +/// Arithmetic trait for interval arrays +trait IntervalOp: ArrowPrimitiveType { + fn add(left: Self::Native, right: Self::Native) -> Result; + fn sub(left: Self::Native, right: Self::Native) -> Result; +} + +impl IntervalOp for IntervalYearMonthType { + fn add(left: Self::Native, right: Self::Native) -> Result { + left.add_checked(right) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + left.sub_checked(right) + } +} + +impl IntervalOp for IntervalDayTimeType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.add_checked(r_days)?; + let ms = l_ms.add_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } + + fn 
sub(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.sub_checked(r_days)?; + let ms = l_ms.sub_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } +} + +impl IntervalOp for IntervalMonthDayNanoType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.add_checked(r_months)?; + let days = l_days.add_checked(r_days)?; + let nanos = l_nanos.add_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.sub_checked(r_months)?; + let days = l_days.sub_checked(r_days)?; + let nanos = l_nanos.sub_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } +} + +/// Perform arithmetic operation on an interval array +fn interval_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::add(l, r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub(l, r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid interval arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +fn duration_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.add_checked(r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.sub_checked(r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + 
"Invalid duration arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform arithmetic operation on a date array +fn date_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + const NUM_SECONDS_IN_DAY: i64 = 60 * 60 * 24; + + let r_t = r.data_type(); + match (T::DATA_TYPE, op, r_t) { + (Date32, Op::Sub | Op::SubWrapping, Date32) => { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + return Ok(op_ref!( + DurationSecondType, + l, + l_s, + r, + r_s, + ((l as i64) - (r as i64)) * NUM_SECONDS_IN_DAY + )); + } + (Date64, Op::Sub | Op::SubWrapping, Date64) => { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let result = try_op_ref!(DurationMillisecondType, l, l_s, r, r_s, l.sub_checked(r)); + return Ok(result); + } + _ => {} + } + + let l = l.as_primitive::(); + match (op, r_t) { + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r))) + } + + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid date arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform 
arithmetic operation on decimal arrays +fn decimal_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + + let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) { + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2), + (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2), + _ => unreachable!(), + }; + + // Follow the Hive decimal arithmetic rules + // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + let array: PrimitiveArray = match op { + Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => { + // max(s1, s2) + let result_scale = *s1.max(s2); + + // max(s1, s2) + max(p1-s1, p2-s2) + 1 + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).max(*p2 as i8 - s2)) as u8) + .saturating_add(1) + .min(T::MAX_PRECISION); + + let l_mul = T::Native::usize_as(10).pow_checked((result_scale - s1) as _)?; + let r_mul = T::Native::usize_as(10).pow_checked((result_scale - s2) as _)?; + + match op { + Op::Add | Op::AddWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.add_checked(r.mul_checked(r_mul)?) + ) + } + Op::Sub | Op::SubWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.sub_checked(r.mul_checked(r_mul)?) + ) + } + _ => unreachable!(), + } + .with_precision_and_scale(result_precision, result_scale)? 
+ } + Op::Mul | Op::MulWrapping => { + let result_precision = p1.saturating_add(p2 + 1).min(T::MAX_PRECISION); + let result_scale = s1.saturating_add(*s2); + if result_scale > T::MAX_SCALE { + // SQL standard says that if the resulting scale of a multiply operation goes + // beyond the maximum, rounding is not acceptable and thus an error occurs + return Err(ArrowError::InvalidArgumentError(format!( + "Output scale of {} {op} {} would exceed max scale of {}", + l.data_type(), + r.data_type(), + T::MAX_SCALE + ))); + } + + try_op!(l, l_s, r, r_s, l.mul_checked(r)) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Div => { + // Follow postgres and MySQL adding a fixed scale increment of 4 + // s1 + 4 + let result_scale = s1.saturating_add(4).min(T::MAX_SCALE); + let mul_pow = result_scale - s1 + s2; + + // p1 - s1 + s2 + result_scale + let result_precision = (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION); + + let (l_mul, r_mul) = match mul_pow.cmp(&0) { + Ordering::Greater => ( + T::Native::usize_as(10).pow_checked(mul_pow as _)?, + T::Native::ONE, + ), + Ordering::Equal => (T::Native::ONE, T::Native::ONE), + Ordering::Less => ( + T::Native::ONE, + T::Native::usize_as(10).pow_checked(mul_pow.neg_wrapping() as _)?, + ), + }; + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.div_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Rem => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // min(p1-s1, p2 -s2) + max( s1,s2 ) + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).min(*p2 as i8 - s2)) as u8) + .min(T::MAX_PRECISION); + + let l_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s1) as _); + let r_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s2) as _); + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.mod_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? 
+ } + }; + + Ok(Arc::new(array)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::temporal_conversions::{as_date, as_datetime}; + use arrow_buffer::{i256, ScalarBuffer}; + use chrono::{DateTime, NaiveDate}; + + fn test_neg_primitive( + input: &[T::Native], + out: Result<&[T::Native], &str>, + ) { + let a = PrimitiveArray::::new(ScalarBuffer::from(input.to_vec()), None); + match out { + Ok(expected) => { + let result = neg(&a).unwrap(); + assert_eq!(result.as_primitive::().values(), expected); + } + Err(e) => { + let err = neg(&a).unwrap_err().to_string(); + assert_eq!(e, err); + } + } + } + + #[test] + fn test_neg() { + let input = &[1, -5, 2, 693, 3929]; + let output = &[-1, 5, -2, -693, -3929]; + test_neg_primitive::(input, Ok(output)); + + let input = &[1, -5, 2, 693, 3929]; + let output = &[-1, 5, -2, -693, -3929]; + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + + let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5]; + let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5]; + test_neg_primitive::(input, Ok(output)); + + test_neg_primitive::( + &[i32::MIN], + Err("Arithmetic overflow: Overflow happened on: - -2147483648"), + ); + test_neg_primitive::( + &[i64::MIN], + Err("Arithmetic overflow: Overflow happened on: - -9223372036854775808"), + ); + test_neg_primitive::( + &[i64::MIN], + Err("Arithmetic overflow: Overflow happened on: - -9223372036854775808"), + ); + + let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap(); + assert_eq!(r.as_primitive::().value(0), i32::MIN); + + let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap(); + assert_eq!(r.as_primitive::().value(0), i64::MIN); + + let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN])) + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Arithmetic overflow: Overflow 
happened on: - -9223372036854775808" + ); + + let a = Decimal128Array::from(vec![1, 3, -44, 2, 4]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[-1, -3, 44, -2, -4] + ); + + let a = Decimal256Array::from(vec![ + i256::from_i128(342), + i256::from_i128(-4949), + i256::from_i128(3), + ]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[ + i256::from_i128(-342), + i256::from_i128(4949), + i256::from_i128(-3), + ] + ); + + let a = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(2, 4), + IntervalYearMonthType::make_value(2, -4), + IntervalYearMonthType::make_value(-3, -5), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalYearMonthType::make_value(-2, -4), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(3, 5), + ] + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(2, 4), + IntervalDayTimeType::make_value(2, -4), + IntervalDayTimeType::make_value(-3, -5), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalDayTimeType::make_value(-2, -4), + IntervalDayTimeType::make_value(-2, 4), + IntervalDayTimeType::make_value(3, 5), + ] + ); + + let a = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(2, 4, 5953394), + IntervalMonthDayNanoType::make_value(2, -4, -45839), + IntervalMonthDayNanoType::make_value(-3, -5, 6944), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalMonthDayNanoType::make_value(-2, -4, -5953394), + IntervalMonthDayNanoType::make_value(-2, 4, 45839), + IntervalMonthDayNanoType::make_value(3, 5, -6944), + ] + ); + } + + #[test] + fn test_integer() { + let a = 
Int32Array::from(vec![4, 3, 5, -6, 100]); + let b = Int32Array::from(vec![6, 2, 5, -7, 3]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Int32Array::from(vec![10, 5, 10, -13, 103]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![-2, 1, 0, 1, 97])); + let result = div(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![0, 1, 1, 0, 33])); + let result = mul(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![24, 6, 25, 42, 300])); + let result = rem(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![4, 1, 0, -6, 1])); + + let a = Int8Array::from(vec![Some(2), None, Some(45)]); + let b = Int8Array::from(vec![Some(5), Some(3), None]); + let result = add(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int8Array::from(vec![Some(7), None, None])); + + let a = UInt8Array::from(vec![56, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Arithmetic overflow: Overflow happened on: 56 + 200"); + let result = add_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![0, 7, 8])); + + let a = UInt8Array::from(vec![34, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = sub(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Arithmetic overflow: Overflow happened on: 34 - 200"); + let result = sub_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![90, 3, 254])); + + let a = UInt8Array::from(vec![34, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = mul(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Arithmetic overflow: Overflow happened on: 34 * 200"); + let result = mul_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![144, 10, 15])); + + let a = Int16Array::from(vec![i16::MIN]); + let b = Int16Array::from(vec![-1]); + let err = div(&a, 
&b).unwrap_err().to_string(); + assert_eq!( + err, + "Arithmetic overflow: Overflow happened on: -32768 / -1" + ); + + let a = Int16Array::from(vec![21]); + let b = Int16Array::from(vec![0]); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + + let a = Int16Array::from(vec![21]); + let b = Int16Array::from(vec![0]); + let err = rem(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + } + + #[test] + fn test_float() { + let a = Float32Array::from(vec![1., f32::MAX, 6., -4., -1., 0.]); + let b = Float32Array::from(vec![1., f32::MAX, f32::MAX, -3., 45., 0.]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![2., f32::INFINITY, f32::MAX, -7., 44.0, 0.]) + ); + + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![0., 0., f32::MIN, -1., -46., 0.]) + ); + + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![1., f32::INFINITY, f32::INFINITY, 12., -45., 0.]) + ); + + let result = div(&a, &b).unwrap(); + let r = result.as_primitive::(); + assert_eq!(r.value(0), 1.); + assert_eq!(r.value(1), 1.); + assert!(r.value(2) < f32::EPSILON); + assert_eq!(r.value(3), -4. 
/ -3.); + assert!(r.value(5).is_nan()); + + let result = rem(&a, &b).unwrap(); + let r = result.as_primitive::(); + assert_eq!(&r.values()[..5], &[0., 0., 6., -1., -1.]); + assert!(r.value(5).is_nan()); + } + + #[test] + fn test_decimal() { + // 0.015 7.842 -0.577 0.334 -0.078 0.003 + let a = Decimal128Array::from(vec![15, 0, -577, 334, -78, 3]) + .with_precision_and_scale(12, 3) + .unwrap(); + + // 5.4 0 -35.6 0.3 0.6 7.45 + let b = Decimal128Array::from(vec![54, 34, -356, 3, 6, 745]) + .with_precision_and_scale(12, 1) + .unwrap(); + + let result = add(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(15, 3)); + assert_eq!( + result.as_primitive::().values(), + &[5415, 3400, -36177, 634, 522, 74503] + ); + + let result = sub(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(15, 3)); + assert_eq!( + result.as_primitive::().values(), + &[-5385, -3400, 35023, 34, -678, -74497] + ); + + let result = mul(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(25, 4)); + assert_eq!( + result.as_primitive::().values(), + &[810, 0, 205412, 1002, -468, 2235] + ); + + let result = div(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(17, 7)); + assert_eq!( + result.as_primitive::().values(), + &[27777, 0, 162078, 11133333, -1300000, 402] + ); + + let result = rem(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(12, 3)); + assert_eq!( + result.as_primitive::().values(), + &[15, 0, -577, 34, -78, 3] + ); + + let a = Decimal128Array::from(vec![1]) + .with_precision_and_scale(3, 3) + .unwrap(); + let b = Decimal128Array::from(vec![1]) + .with_precision_and_scale(37, 37) + .unwrap(); + let err = mul(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Invalid argument error: Output scale of Decimal128(3, 3) * Decimal128(37, 37) would exceed max scale of 38"); + + let a = Decimal128Array::from(vec![1]) + .with_precision_and_scale(3, -2) + .unwrap(); + let err = add(&a, 
&b).unwrap_err().to_string(); + assert_eq!(err, "Arithmetic overflow: Overflow happened on: 10 ^ 39"); + + let a = Decimal128Array::from(vec![10]) + .with_precision_and_scale(3, -1) + .unwrap(); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Arithmetic overflow: Overflow happened on: 10 * 100000000000000000000000000000000000000" + ); + + let b = Decimal128Array::from(vec![0]) + .with_precision_and_scale(1, 1) + .unwrap(); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + let err = rem(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + } + + fn test_timestamp_impl() { + let a = PrimitiveArray::::new(vec![2000000, 434030324, 53943340].into(), None); + let b = PrimitiveArray::::new(vec![329593, 59349, 694994].into(), None); + + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[1670407, 433970975, 53248346] + ); + + let r2 = add(&b, &result.as_ref()).unwrap(); + assert_eq!(r2.as_ref(), &a); + + let r3 = add(&result.as_ref(), &b).unwrap(); + assert_eq!(r3.as_ref(), &a); + + let format_array = |x: &dyn Array| -> Vec { + x.as_primitive::() + .values() + .into_iter() + .map(|x| as_datetime::(*x).unwrap().to_string()) + .collect() + }; + + let values = vec![ + "1970-01-01T00:00:00Z", + "2010-04-01T04:00:20Z", + "1960-01-30T04:23:20Z", + ] + .into_iter() + .map(|x| T::make_value(DateTime::parse_from_rfc3339(x).unwrap().naive_utc()).unwrap()) + .collect(); + + let a = PrimitiveArray::::new(values, None); + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(5, 34), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(7, -4), + ]); + let r4 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r4.as_ref()), + &[ + "1977-11-01 00:00:00".to_string(), + "2008-08-01 04:00:20".to_string(), + "1966-09-30 04:23:20".to_string() + ] + ); + + let r5 = sub(&r4, &b).unwrap(); + 
assert_eq!(r5.as_ref(), &a); + + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(5, 454000), + IntervalDayTimeType::make_value(-34, 0), + IntervalDayTimeType::make_value(7, -4000), + ]); + let r6 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r6.as_ref()), + &[ + "1970-01-06 00:07:34".to_string(), + "2010-02-26 04:00:20".to_string(), + "1960-02-06 04:23:16".to_string() + ] + ); + + let r7 = sub(&r6, &b).unwrap(); + assert_eq!(r7.as_ref(), &a); + + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000), + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000), + IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000), + ]); + let r8 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r8.as_ref()), + &[ + "1998-10-04 23:59:17".to_string(), + "1960-09-29 04:00:33".to_string(), + "1960-07-02 04:31:33".to_string() + ] + ); + + let r9 = sub(&r8, &b).unwrap(); + // Note: subtraction is not the inverse of addition for intervals + assert_eq!( + &format_array(r9.as_ref()), + &[ + "1970-01-02 00:00:00".to_string(), + "2010-04-02 04:00:20".to_string(), + "1960-01-31 04:23:20".to_string() + ] + ); + } + + #[test] + fn test_timestamp() { + test_timestamp_impl::(); + test_timestamp_impl::(); + test_timestamp_impl::(); + test_timestamp_impl::(); + } + + #[test] + fn test_interval() { + let a = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(32, 4), + IntervalYearMonthType::make_value(32, 4), + ]); + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(-4, 6), + IntervalYearMonthType::make_value(-3, 23), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(28, 10), + IntervalYearMonthType::make_value(29, 27) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![ + 
IntervalYearMonthType::make_value(36, -2), + IntervalYearMonthType::make_value(35, -19) + ]) + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(32, 4), + IntervalDayTimeType::make_value(32, 4), + ]); + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(-4, 6), + IntervalDayTimeType::make_value(-3, 23), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(28, 10), + IntervalDayTimeType::make_value(29, 27) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(36, -2), + IntervalDayTimeType::make_value(35, -19) + ]) + ); + let a = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(32, 4, 4000000000000), + IntervalMonthDayNanoType::make_value(32, 4, 45463000000000000), + ]); + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(-4, 6, 46000000000000), + IntervalMonthDayNanoType::make_value(-3, 23, 3564000000000000), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(28, 10, 50000000000000), + IntervalMonthDayNanoType::make_value(29, 27, 49027000000000000) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(36, -2, -42000000000000), + IntervalMonthDayNanoType::make_value(35, -19, 41899000000000000) + ]) + ); + let a = IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNano::MAX]); + let b = IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNano::ONE]); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Arithmetic overflow: Overflow happened on: 2147483647 + 1" + ); + } + + fn test_duration_impl>() { + let a = 
PrimitiveArray::::new(vec![1000, 4394, -3944].into(), None); + let b = PrimitiveArray::::new(vec![4, -5, -243].into(), None); + + let result = add(&a, &b).unwrap(); + assert_eq!(result.as_primitive::().values(), &[1004, 4389, -4187]); + let result = sub(&a, &b).unwrap(); + assert_eq!(result.as_primitive::().values(), &[996, 4399, -3701]); + + let err = mul(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let err = div(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let err = rem(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let a = PrimitiveArray::::new(vec![i64::MAX].into(), None); + let b = PrimitiveArray::::new(vec![1].into(), None); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Arithmetic overflow: Overflow happened on: 9223372036854775807 + 1" + ); + } + + #[test] + fn test_duration() { + test_duration_impl::(); + test_duration_impl::(); + test_duration_impl::(); + test_duration_impl::(); + } + + fn test_date_impl(f: F) + where + F: Fn(NaiveDate) -> T::Native, + T::Native: TryInto, + { + let a = PrimitiveArray::::new( + vec![ + f(NaiveDate::from_ymd_opt(1979, 1, 30).unwrap()), + f(NaiveDate::from_ymd_opt(2010, 4, 3).unwrap()), + f(NaiveDate::from_ymd_opt(2008, 2, 29).unwrap()), + ] + .into(), + None, + ); + + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(34, 2), + IntervalYearMonthType::make_value(3, -3), + IntervalYearMonthType::make_value(-12, 4), + ]); + + let format_array = |x: &dyn Array| -> Vec { + x.as_primitive::() + .values() + .into_iter() + .map(|x| { + as_date::((*x).try_into().ok().unwrap()) + .unwrap() + .to_string() + }) + .collect() + }; + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "2013-03-30".to_string(), + 
"2013-01-03".to_string(), + "1996-06-29".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!(result.as_ref(), &a); + + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(34, 2), + IntervalDayTimeType::make_value(3, -3), + IntervalDayTimeType::make_value(-12, 4), + ]); + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1979-03-05".to_string(), + "2010-04-06".to_string(), + "2008-02-17".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!(result.as_ref(), &a); + + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(34, 2, -34353534), + IntervalMonthDayNanoType::make_value(3, -3, 2443), + IntervalMonthDayNanoType::make_value(-12, 4, 2323242423232), + ]); + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1981-12-02".to_string(), + "2010-06-30".to_string(), + "2007-03-04".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1979-01-31".to_string(), + "2010-04-02".to_string(), + "2008-02-29".to_string(), + ] + ); + } + + #[test] + fn test_date() { + test_date_impl::(Date32Type::from_naive_date); + test_date_impl::(Date64Type::from_naive_date); + + let a = Date32Array::from(vec![i32::MIN, i32::MAX, 23, 7684]); + let b = Date32Array::from(vec![i32::MIN, i32::MIN, -2, 45]); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[0, 371085174288000, 2160000, 660009600] + ); + + let a = Date64Array::from(vec![4343, 76676, 3434]); + let b = Date64Array::from(vec![3, -5, 5]); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[4340, 76681, 3429] + ); + + let a = Date64Array::from(vec![i64::MAX]); + let b = Date64Array::from(vec![-1]); + let err = sub(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Arithmetic overflow: Overflow happened 
on: 9223372036854775807 - -1" + ); + } +} diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs new file mode 100644 index 000000000000..09d690d3237c --- /dev/null +++ b/arrow-arith/src/temporal.rs @@ -0,0 +1,2088 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines temporal kernels for time and date related functions. + +use std::sync::Arc; + +use arrow_array::cast::AsArray; +use cast::as_primitive_array; +use chrono::{Datelike, NaiveDateTime, Offset, TimeZone, Timelike, Utc}; + +use arrow_array::temporal_conversions::{ + date32_to_datetime, date64_to_datetime, timestamp_ms_to_datetime, timestamp_ns_to_datetime, + timestamp_s_to_datetime, timestamp_us_to_datetime, MICROSECONDS, MICROSECONDS_IN_DAY, + MILLISECONDS, MILLISECONDS_IN_DAY, NANOSECONDS, NANOSECONDS_IN_DAY, SECONDS_IN_DAY, +}; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; + +/// Valid parts to extract from date/time/timestamp arrays. +/// +/// See [`date_part`]. +/// +/// Marked as non-exhaustive as may expand to support more types of +/// date parts in the future. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum DatePart { + /// Quarter of the year, in range `1..=4` + Quarter, + /// Calendar year + Year, + /// Month in the year, in range `1..=12` + Month, + /// ISO week of the year, in range `1..=53` + Week, + /// Day of the month, in range `1..=31` + Day, + /// Day of the week, in range `0..=6`, where Sunday is `0` + DayOfWeekSunday0, + /// Day of the week, in range `0..=6`, where Monday is `0` + DayOfWeekMonday0, + /// Day of year, in range `1..=366` + DayOfYear, + /// Hour of the day, in range `0..=23` + Hour, + /// Minute of the hour, in range `0..=59` + Minute, + /// Second of the minute, in range `0..=59` + Second, + /// Millisecond of the second + Millisecond, + /// Microsecond of the second + Microsecond, + /// Nanosecond of the second + Nanosecond, +} + +impl std::fmt::Display for DatePart { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +/// Returns function to extract relevant [`DatePart`] from types like a +/// [`NaiveDateTime`] or [`DateTime`]. 
+/// +/// [`DateTime`]: chrono::DateTime +fn get_date_time_part_extract_fn(part: DatePart) -> fn(T) -> i32 +where + T: ChronoDateExt + Datelike + Timelike, +{ + match part { + DatePart::Quarter => |d| d.quarter() as i32, + DatePart::Year => |d| d.year(), + DatePart::Month => |d| d.month() as i32, + DatePart::Week => |d| d.iso_week().week() as i32, + DatePart::Day => |d| d.day() as i32, + DatePart::DayOfWeekSunday0 => |d| d.num_days_from_sunday(), + DatePart::DayOfWeekMonday0 => |d| d.num_days_from_monday(), + DatePart::DayOfYear => |d| d.ordinal() as i32, + DatePart::Hour => |d| d.hour() as i32, + DatePart::Minute => |d| d.minute() as i32, + DatePart::Second => |d| d.second() as i32, + DatePart::Millisecond => |d| (d.nanosecond() / 1_000_000) as i32, + DatePart::Microsecond => |d| (d.nanosecond() / 1_000) as i32, + DatePart::Nanosecond => |d| (d.nanosecond()) as i32, + } +} + +/// Given an array, return a new array with the extracted [`DatePart`] as signed 32-bit +/// integer values. +/// +/// Currently only supports temporal types: +/// - Date32/Date64 +/// - Time32/Time64 +/// - Timestamp +/// - Interval +/// - Duration +/// +/// Returns an [`Int32Array`] unless input was a dictionary type, in which case returns +/// the dictionary but with this function applied onto its values. +/// +/// If array passed in is not of the above listed types (or is a dictionary array where the +/// values array isn't of the above listed types), then this function will return an error. 
+/// +/// # Examples +/// +/// ``` +/// # use arrow_array::{Int32Array, TimestampMicrosecondArray}; +/// # use arrow_arith::temporal::{DatePart, date_part}; +/// let input: TimestampMicrosecondArray = +/// vec![Some(1612025847000000), None, Some(1722015847000000)].into(); +/// +/// let actual = date_part(&input, DatePart::Week).unwrap(); +/// let expected: Int32Array = vec![Some(4), None, Some(30)].into(); +/// assert_eq!(actual.as_ref(), &expected); +/// ``` +pub fn date_part(array: &dyn Array, part: DatePart) -> Result { + downcast_temporal_array!( + array => { + let array = array.date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Interval(IntervalUnit::YearMonth) => { + let array = as_primitive_array::(array).date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Interval(IntervalUnit::DayTime) => { + let array = as_primitive_array::(array).date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let array = as_primitive_array::(array).date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Duration(TimeUnit::Second) => { + let array = as_primitive_array::(array).date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Duration(TimeUnit::Millisecond) => { + let array = as_primitive_array::(array).date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Duration(TimeUnit::Microsecond) => { + let array = as_primitive_array::(array).date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Duration(TimeUnit::Nanosecond) => { + let array = as_primitive_array::(array).date_part(part)?; + let array = Arc::new(array) as ArrayRef; + Ok(array) + } + DataType::Dictionary(_, _) => { + let array = array.as_any_dictionary(); + let values = date_part(array.values(), part)?; + let values = Arc::new(values) 
as ArrayRef; + let new_array = array.with_values(values); + Ok(new_array) + } + t => return_compute_error_with!(format!("{part} does not support"), t), + ) +} + +/// Used to integrate new [`date_part()`] method with deprecated shims such as +/// [`hour()`] and [`week()`]. +fn date_part_primitive( + array: &PrimitiveArray, + part: DatePart, +) -> Result { + let array = date_part(array, part)?; + Ok(array.as_primitive::().to_owned()) +} + +/// Extract optional [`Tz`] from timestamp data types, returning error +/// if called with a non-timestamp type. +fn get_tz(dt: &DataType) -> Result, ArrowError> { + match dt { + DataType::Timestamp(_, Some(tz)) => Ok(Some(tz.parse::()?)), + DataType::Timestamp(_, None) => Ok(None), + _ => Err(ArrowError::CastError(format!("Not a timestamp type: {dt}"))), + } +} + +/// Implement the specialized functions for extracting date part from temporal arrays. +trait ExtractDatePartExt { + fn date_part(&self, part: DatePart) -> Result; +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + #[inline] + fn range_check(s: i32) -> bool { + (0..SECONDS_IN_DAY as i32).contains(&s) + } + match part { + DatePart::Hour => Ok(self.unary_opt(|s| range_check(s).then_some(s / 3_600))), + DatePart::Minute => Ok(self.unary_opt(|s| range_check(s).then_some((s / 60) % 60))), + DatePart::Second => Ok(self.unary_opt(|s| range_check(s).then_some(s % 60))), + // Time32Second only encodes number of seconds, so these will always be 0 (if in valid range) + DatePart::Millisecond | DatePart::Microsecond | DatePart::Nanosecond => { + Ok(self.unary_opt(|s| range_check(s).then_some(0))) + } + _ => return_compute_error_with!(format!("{part} does not support"), self.data_type()), + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + #[inline] + fn range_check(ms: i32) -> bool { + (0..MILLISECONDS_IN_DAY as i32).contains(&ms) + } + let milliseconds = MILLISECONDS as 
i32; + match part { + DatePart::Hour => { + Ok(self.unary_opt(|ms| range_check(ms).then_some(ms / 3_600 / milliseconds))) + } + DatePart::Minute => { + Ok(self.unary_opt(|ms| range_check(ms).then_some((ms / 60 / milliseconds) % 60))) + } + DatePart::Second => { + Ok(self.unary_opt(|ms| range_check(ms).then_some((ms / milliseconds) % 60))) + } + DatePart::Millisecond => { + Ok(self.unary_opt(|ms| range_check(ms).then_some(ms % milliseconds))) + } + DatePart::Microsecond => { + Ok(self.unary_opt(|ms| range_check(ms).then_some((ms % milliseconds) * 1_000))) + } + DatePart::Nanosecond => { + Ok(self.unary_opt(|ms| range_check(ms).then_some((ms % milliseconds) * 1_000_000))) + } + _ => return_compute_error_with!(format!("{part} does not support"), self.data_type()), + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + #[inline] + fn range_check(us: i64) -> bool { + (0..MICROSECONDS_IN_DAY).contains(&us) + } + match part { + DatePart::Hour => { + Ok(self + .unary_opt(|us| range_check(us).then_some((us / 3_600 / MICROSECONDS) as i32))) + } + DatePart::Minute => Ok(self + .unary_opt(|us| range_check(us).then_some(((us / 60 / MICROSECONDS) % 60) as i32))), + DatePart::Second => { + Ok(self + .unary_opt(|us| range_check(us).then_some(((us / MICROSECONDS) % 60) as i32))) + } + DatePart::Millisecond => Ok(self + .unary_opt(|us| range_check(us).then_some(((us % MICROSECONDS) / 1_000) as i32))), + DatePart::Microsecond => { + Ok(self.unary_opt(|us| range_check(us).then_some((us % MICROSECONDS) as i32))) + } + DatePart::Nanosecond => Ok(self + .unary_opt(|us| range_check(us).then_some(((us % MICROSECONDS) * 1_000) as i32))), + _ => return_compute_error_with!(format!("{part} does not support"), self.data_type()), + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + #[inline] + fn range_check(ns: i64) -> bool { + (0..NANOSECONDS_IN_DAY).contains(&ns) + } + match part 
{ + DatePart::Hour => { + Ok(self + .unary_opt(|ns| range_check(ns).then_some((ns / 3_600 / NANOSECONDS) as i32))) + } + DatePart::Minute => Ok(self + .unary_opt(|ns| range_check(ns).then_some(((ns / 60 / NANOSECONDS) % 60) as i32))), + DatePart::Second => Ok( + self.unary_opt(|ns| range_check(ns).then_some(((ns / NANOSECONDS) % 60) as i32)) + ), + DatePart::Millisecond => Ok(self.unary_opt(|ns| { + range_check(ns).then_some(((ns % NANOSECONDS) / 1_000_000) as i32) + })), + DatePart::Microsecond => { + Ok(self + .unary_opt(|ns| range_check(ns).then_some(((ns % NANOSECONDS) / 1_000) as i32))) + } + DatePart::Nanosecond => { + Ok(self.unary_opt(|ns| range_check(ns).then_some((ns % NANOSECONDS) as i32))) + } + _ => return_compute_error_with!(format!("{part} does not support"), self.data_type()), + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + // Date32 only encodes number of days, so these will always be 0 + if let DatePart::Hour + | DatePart::Minute + | DatePart::Second + | DatePart::Millisecond + | DatePart::Microsecond + | DatePart::Nanosecond = part + { + Ok(Int32Array::new( + vec![0; self.len()].into(), + self.nulls().cloned(), + )) + } else { + let map_func = get_date_time_part_extract_fn(part); + Ok(self.unary_opt(|d| date32_to_datetime(d).map(map_func))) + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + let map_func = get_date_time_part_extract_fn(part); + Ok(self.unary_opt(|d| date64_to_datetime(d).map(map_func))) + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + // TimestampSecond only encodes number of seconds, so these will always be 0 + let array = + if let DatePart::Millisecond | DatePart::Microsecond | DatePart::Nanosecond = part { + Int32Array::new(vec![0; self.len()].into(), self.nulls().cloned()) + } else if let Some(tz) = get_tz(self.data_type())? 
{ + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| { + timestamp_s_to_datetime(d) + .map(|c| Utc.from_utc_datetime(&c).with_timezone(&tz)) + .map(map_func) + }) + } else { + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| timestamp_s_to_datetime(d).map(map_func)) + }; + Ok(array) + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + let array = if let Some(tz) = get_tz(self.data_type())? { + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| { + timestamp_ms_to_datetime(d) + .map(|c| Utc.from_utc_datetime(&c).with_timezone(&tz)) + .map(map_func) + }) + } else { + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| timestamp_ms_to_datetime(d).map(map_func)) + }; + Ok(array) + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + let array = if let Some(tz) = get_tz(self.data_type())? { + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| { + timestamp_us_to_datetime(d) + .map(|c| Utc.from_utc_datetime(&c).with_timezone(&tz)) + .map(map_func) + }) + } else { + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| timestamp_us_to_datetime(d).map(map_func)) + }; + Ok(array) + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + let array = if let Some(tz) = get_tz(self.data_type())? 
{ + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| { + timestamp_ns_to_datetime(d) + .map(|c| Utc.from_utc_datetime(&c).with_timezone(&tz)) + .map(map_func) + }) + } else { + let map_func = get_date_time_part_extract_fn(part); + self.unary_opt(|d| timestamp_ns_to_datetime(d).map(map_func)) + }; + Ok(array) + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + match part { + DatePart::Year => Ok(self.unary_opt(|d| Some(d / 12))), + DatePart::Month => Ok(self.unary_opt(|d| Some(d % 12))), + + DatePart::Quarter + | DatePart::Week + | DatePart::Day + | DatePart::DayOfWeekSunday0 + | DatePart::DayOfWeekMonday0 + | DatePart::DayOfYear + | DatePart::Hour + | DatePart::Minute + | DatePart::Second + | DatePart::Millisecond + | DatePart::Microsecond + | DatePart::Nanosecond => { + return_compute_error_with!(format!("{part} does not support"), self.data_type()) + } + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + match part { + DatePart::Week => Ok(self.unary_opt(|d| Some(d.days / 7))), + DatePart::Day => Ok(self.unary_opt(|d| Some(d.days))), + DatePart::Hour => Ok(self.unary_opt(|d| Some(d.milliseconds / (60 * 60 * 1_000)))), + DatePart::Minute => Ok(self.unary_opt(|d| Some(d.milliseconds / (60 * 1_000)))), + DatePart::Second => Ok(self.unary_opt(|d| Some(d.milliseconds / 1_000))), + DatePart::Millisecond => Ok(self.unary_opt(|d| Some(d.milliseconds))), + DatePart::Microsecond => Ok(self.unary_opt(|d| d.milliseconds.checked_mul(1_000))), + DatePart::Nanosecond => Ok(self.unary_opt(|d| d.milliseconds.checked_mul(1_000_000))), + + DatePart::Quarter + | DatePart::Year + | DatePart::Month + | DatePart::DayOfWeekSunday0 + | DatePart::DayOfWeekMonday0 + | DatePart::DayOfYear => { + return_compute_error_with!(format!("{part} does not support"), self.data_type()) + } + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn 
date_part(&self, part: DatePart) -> Result { + match part { + DatePart::Year => Ok(self.unary_opt(|d: IntervalMonthDayNano| Some(d.months / 12))), + DatePart::Month => Ok(self.unary_opt(|d: IntervalMonthDayNano| Some(d.months))), + DatePart::Week => Ok(self.unary_opt(|d: IntervalMonthDayNano| Some(d.days / 7))), + DatePart::Day => Ok(self.unary_opt(|d: IntervalMonthDayNano| Some(d.days))), + DatePart::Hour => { + Ok(self.unary_opt(|d| (d.nanoseconds / (60 * 60 * 1_000_000_000)).try_into().ok())) + } + DatePart::Minute => { + Ok(self.unary_opt(|d| (d.nanoseconds / (60 * 1_000_000_000)).try_into().ok())) + } + DatePart::Second => { + Ok(self.unary_opt(|d| (d.nanoseconds / 1_000_000_000).try_into().ok())) + } + DatePart::Millisecond => { + Ok(self.unary_opt(|d| (d.nanoseconds / 1_000_000).try_into().ok())) + } + DatePart::Microsecond => { + Ok(self.unary_opt(|d| (d.nanoseconds / 1_000).try_into().ok())) + } + DatePart::Nanosecond => Ok(self.unary_opt(|d| d.nanoseconds.try_into().ok())), + + DatePart::Quarter + | DatePart::DayOfWeekSunday0 + | DatePart::DayOfWeekMonday0 + | DatePart::DayOfYear => { + return_compute_error_with!(format!("{part} does not support"), self.data_type()) + } + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + match part { + DatePart::Week => Ok(self.unary_opt(|d| (d / (60 * 60 * 24 * 7)).try_into().ok())), + DatePart::Day => Ok(self.unary_opt(|d| (d / (60 * 60 * 24)).try_into().ok())), + DatePart::Hour => Ok(self.unary_opt(|d| (d / (60 * 60)).try_into().ok())), + DatePart::Minute => Ok(self.unary_opt(|d| (d / 60).try_into().ok())), + DatePart::Second => Ok(self.unary_opt(|d| d.try_into().ok())), + DatePart::Millisecond => { + Ok(self.unary_opt(|d| d.checked_mul(1_000).and_then(|d| d.try_into().ok()))) + } + DatePart::Microsecond => { + Ok(self.unary_opt(|d| d.checked_mul(1_000_000).and_then(|d| d.try_into().ok()))) + } + DatePart::Nanosecond => Ok( + self.unary_opt(|d| 
d.checked_mul(1_000_000_000).and_then(|d| d.try_into().ok())) + ), + + DatePart::Year + | DatePart::Quarter + | DatePart::Month + | DatePart::DayOfWeekSunday0 + | DatePart::DayOfWeekMonday0 + | DatePart::DayOfYear => { + return_compute_error_with!(format!("{part} does not support"), self.data_type()) + } + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + match part { + DatePart::Week => { + Ok(self.unary_opt(|d| (d / (1_000 * 60 * 60 * 24 * 7)).try_into().ok())) + } + DatePart::Day => Ok(self.unary_opt(|d| (d / (1_000 * 60 * 60 * 24)).try_into().ok())), + DatePart::Hour => Ok(self.unary_opt(|d| (d / (1_000 * 60 * 60)).try_into().ok())), + DatePart::Minute => Ok(self.unary_opt(|d| (d / (1_000 * 60)).try_into().ok())), + DatePart::Second => Ok(self.unary_opt(|d| (d / 1_000).try_into().ok())), + DatePart::Millisecond => Ok(self.unary_opt(|d| d.try_into().ok())), + DatePart::Microsecond => { + Ok(self.unary_opt(|d| d.checked_mul(1_000).and_then(|d| d.try_into().ok()))) + } + DatePart::Nanosecond => { + Ok(self.unary_opt(|d| d.checked_mul(1_000_000).and_then(|d| d.try_into().ok()))) + } + + DatePart::Year + | DatePart::Quarter + | DatePart::Month + | DatePart::DayOfWeekSunday0 + | DatePart::DayOfWeekMonday0 + | DatePart::DayOfYear => { + return_compute_error_with!(format!("{part} does not support"), self.data_type()) + } + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + match part { + DatePart::Week => { + Ok(self.unary_opt(|d| (d / (1_000_000 * 60 * 60 * 24 * 7)).try_into().ok())) + } + DatePart::Day => { + Ok(self.unary_opt(|d| (d / (1_000_000 * 60 * 60 * 24)).try_into().ok())) + } + DatePart::Hour => Ok(self.unary_opt(|d| (d / (1_000_000 * 60 * 60)).try_into().ok())), + DatePart::Minute => Ok(self.unary_opt(|d| (d / (1_000_000 * 60)).try_into().ok())), + DatePart::Second => Ok(self.unary_opt(|d| (d / 1_000_000).try_into().ok())), + 
DatePart::Millisecond => Ok(self.unary_opt(|d| (d / 1_000).try_into().ok())), + DatePart::Microsecond => Ok(self.unary_opt(|d| d.try_into().ok())), + DatePart::Nanosecond => { + Ok(self.unary_opt(|d| d.checked_mul(1_000).and_then(|d| d.try_into().ok()))) + } + + DatePart::Year + | DatePart::Quarter + | DatePart::Month + | DatePart::DayOfWeekSunday0 + | DatePart::DayOfWeekMonday0 + | DatePart::DayOfYear => { + return_compute_error_with!(format!("{part} does not support"), self.data_type()) + } + } + } +} + +impl ExtractDatePartExt for PrimitiveArray { + fn date_part(&self, part: DatePart) -> Result { + match part { + DatePart::Week => { + Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60 * 60 * 24 * 7)).try_into().ok())) + } + DatePart::Day => { + Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60 * 60 * 24)).try_into().ok())) + } + DatePart::Hour => { + Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60 * 60)).try_into().ok())) + } + DatePart::Minute => Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60)).try_into().ok())), + DatePart::Second => Ok(self.unary_opt(|d| (d / 1_000_000_000).try_into().ok())), + DatePart::Millisecond => Ok(self.unary_opt(|d| (d / 1_000_000).try_into().ok())), + DatePart::Microsecond => Ok(self.unary_opt(|d| (d / 1_000).try_into().ok())), + DatePart::Nanosecond => Ok(self.unary_opt(|d| d.try_into().ok())), + + DatePart::Year + | DatePart::Quarter + | DatePart::Month + | DatePart::DayOfWeekSunday0 + | DatePart::DayOfWeekMonday0 + | DatePart::DayOfYear => { + return_compute_error_with!(format!("{part} does not support"), self.data_type()) + } + } + } +} + +macro_rules! 
return_compute_error_with { + ($msg:expr, $param:expr) => { + return { Err(ArrowError::ComputeError(format!("{}: {:?}", $msg, $param))) } + }; +} + +pub(crate) use return_compute_error_with; + +// Internal trait, which is used for mapping values from DateLike structures +trait ChronoDateExt { + /// Returns a value in range `1..=4` indicating the quarter this date falls into + fn quarter(&self) -> u32; + + /// Returns a value in range `0..=3` indicating the quarter (zero-based) this date falls into + fn quarter0(&self) -> u32; + + /// Returns the day of week; Monday is encoded as `0`, Tuesday as `1`, etc. + fn num_days_from_monday(&self) -> i32; + + /// Returns the day of week; Sunday is encoded as `0`, Monday as `1`, etc. + fn num_days_from_sunday(&self) -> i32; +} + +impl ChronoDateExt for T { + fn quarter(&self) -> u32 { + self.quarter0() + 1 + } + + fn quarter0(&self) -> u32 { + self.month0() / 3 + } + + fn num_days_from_monday(&self) -> i32 { + self.weekday().num_days_from_monday() as i32 + } + + fn num_days_from_sunday(&self) -> i32 { + self.weekday().num_days_from_sunday() as i32 + } +} + +/// Parse the given string into a string representing fixed-offset that is correct as of the given +/// UTC NaiveDateTime. +/// +/// Note that the offset is function of time and can vary depending on whether daylight savings is +/// in effect or not. e.g. Australia/Sydney is +10:00 or +11:00 depending on DST. +#[deprecated(note = "Use arrow_array::timezone::Tz instead")] +pub fn using_chrono_tz_and_utc_naive_date_time( + tz: &str, + utc: NaiveDateTime, +) -> Option { + let tz: Tz = tz.parse().ok()?; + Some(tz.offset_from_utc_datetime(&utc).fix()) +} + +/// Extracts the hours of a given array as an array of integers within +/// the range of [0, 23]. If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. 
+#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn hour_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Hour) +} + +/// Extracts the hours of a given temporal primitive array as an array of integers within +/// the range of [0, 23]. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn hour(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Hour) +} + +/// Extracts the years of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn year_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Year) +} + +/// Extracts the years of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn year(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Year) +} + +/// Extracts the quarter of a given temporal array as an array of integersa within +/// the range of [1, 4]. If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn quarter_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Quarter) +} + +/// Extracts the quarter of a given temporal primitive array as an array of integers within +/// the range of [1, 4]. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn quarter(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Quarter) +} + +/// Extracts the month of a given temporal array as an array of integers. 
+/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn month_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Month) +} + +/// Extracts the month of a given temporal primitive array as an array of integers within +/// the range of [1, 12]. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn month(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Month) +} + +/// Extracts the day of week of a given temporal array as an array of +/// integers. +/// +/// Monday is encoded as `0`, Tuesday as `1`, etc. +/// +/// See also [`num_days_from_sunday`] which starts at Sunday. +/// +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn num_days_from_monday_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::DayOfWeekMonday0) +} + +/// Extracts the day of week of a given temporal primitive array as an array of +/// integers. +/// +/// Monday is encoded as `0`, Tuesday as `1`, etc. +/// +/// See also [`num_days_from_sunday`] which starts at Sunday. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn num_days_from_monday(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::DayOfWeekMonday0) +} + +/// Extracts the day of week of a given temporal array as an array of +/// integers, starting at Sunday. +/// +/// Sunday is encoded as `0`, Monday as `1`, etc. +/// +/// See also [`num_days_from_monday`] which starts at Monday. +/// +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. 
+#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn num_days_from_sunday_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::DayOfWeekSunday0) +} + +/// Extracts the day of week of a given temporal primitive array as an array of +/// integers, starting at Sunday. +/// +/// Sunday is encoded as `0`, Monday as `1`, etc. +/// +/// See also [`num_days_from_monday`] which starts at Monday. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn num_days_from_sunday(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::DayOfWeekSunday0) +} + +/// Extracts the day of a given temporal array as an array of integers. +/// +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn day_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Day) +} + +/// Extracts the day of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn day(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Day) +} + +/// Extracts the day of year of a given temporal array as an array of integers. +/// +/// The day of year that ranges from 1 to 366. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn doy_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::DayOfYear) +} + +/// Extracts the day of year of a given temporal primitive array as an array of integers. 
+/// +/// The day of year that ranges from 1 to 366 +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn doy(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + T::Native: ArrowNativeType, + i64: From, +{ + date_part_primitive(array, DatePart::DayOfYear) +} + +/// Extracts the minutes of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn minute(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Minute) +} + +/// Extracts the week of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn week_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Week) +} + +/// Extracts the week of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn week(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Week) +} + +/// Extracts the seconds of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn second(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Second) +} + +/// Extracts the nanoseconds of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn nanosecond(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Nanosecond) +} + +/// Extracts the nanoseconds of a given temporal primitive array 
as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn nanosecond_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Nanosecond) +} + +/// Extracts the microseconds of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn microsecond(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Microsecond) +} + +/// Extracts the microseconds of a given temporal primitive array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn microsecond_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Microsecond) +} + +/// Extracts the milliseconds of a given temporal primitive array as an array of integers +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn millisecond(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + date_part_primitive(array, DatePart::Millisecond) +} + +/// Extracts the milliseconds of a given temporal primitive array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn millisecond_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Millisecond) +} + +/// Extracts the minutes of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. 
+#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn minute_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Minute) +} + +/// Extracts the seconds of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] +pub fn second_dyn(array: &dyn Array) -> Result { + date_part(array, DatePart::Second) +} + +#[cfg(test)] +#[allow(deprecated)] +mod tests { + use super::*; + + #[test] + fn test_temporal_array_date64_hour() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = hour(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(4, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_hour() { + let a: PrimitiveArray = vec![Some(15147), None, Some(15148)].into(); + + let b = hour(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(0, b.value(2)); + } + + #[test] + fn test_temporal_array_time32_second_hour() { + let a: PrimitiveArray = vec![37800, 86339].into(); + + let b = hour(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(23, b.value(1)); + } + + #[test] + fn test_temporal_array_time64_micro_hour() { + let a: PrimitiveArray = vec![37800000000, 86339000000].into(); + + let b = hour(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(23, b.value(1)); + } + + #[test] + fn test_temporal_array_timestamp_micro_hour() { + let a: TimestampMicrosecondArray = vec![37800000000, 86339000000].into(); + + let b = hour(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(23, b.value(1)); + } + + #[test] + fn test_temporal_array_date64_year() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = year(&a).unwrap(); + assert_eq!(2018, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2019, b.value(2)); + } + + 
#[test] + fn test_temporal_array_date32_year() { + let a: PrimitiveArray = vec![Some(15147), None, Some(15448)].into(); + + let b = year(&a).unwrap(); + assert_eq!(2011, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2012, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_quarter() { + //1514764800000 -> 2018-01-01 + //1566275025000 -> 2019-08-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1566275025000)].into(); + + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(3, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_quarter() { + let a: PrimitiveArray = vec![Some(1), None, Some(300)].into(); + + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(4, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_quarter_with_timezone() { + // 24 * 60 * 60 = 86400 + let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("+00:00".to_string()); + let b = quarter(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("-10:00".to_string()); + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_date64_month() { + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_month() { + let a: PrimitiveArray = vec![Some(1), None, Some(31)].into(); + + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_month_with_timezone() { + // 24 * 60 * 60 = 86400 + let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("+00:00".to_string()); + let b = 
month(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("-10:00".to_string()); + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_day_with_timezone() { + // 24 * 60 * 60 = 86400 + let a = TimestampSecondArray::from(vec![86400]).with_timezone("+00:00".to_string()); + let b = day(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = TimestampSecondArray::from(vec![86400]).with_timezone("-10:00".to_string()); + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_date64_weekday() { + //1514764800000 -> 2018-01-01 (Monday) + //1550636625000 -> 2019-02-20 (Wednesday) + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = num_days_from_monday(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_weekday0() { + //1483228800000 -> 2017-01-01 (Sunday) + //1514764800000 -> 2018-01-01 (Monday) + //1550636625000 -> 2019-02-20 (Wednesday) + let a: PrimitiveArray = vec![ + Some(1483228800000), + None, + Some(1514764800000), + Some(1550636625000), + ] + .into(); + + let b = num_days_from_sunday(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + assert_eq!(3, b.value(3)); + } + + #[test] + fn test_temporal_array_date64_day() { + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(20, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_day() { + let a: PrimitiveArray = vec![Some(0), None, Some(31)].into(); + + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + } + + 
#[test] + fn test_temporal_array_date64_doy() { + //1483228800000 -> 2017-01-01 (Sunday) + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = vec![ + Some(1483228800000), + Some(1514764800000), + None, + Some(1550636625000), + ] + .into(); + + let b = doy(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert_eq!(1, b.value(1)); + assert!(!b.is_valid(2)); + assert_eq!(51, b.value(3)); + } + + #[test] + fn test_temporal_array_timestamp_micro_year() { + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = year(&a).unwrap(); + assert_eq!(2021, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2024, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_minute() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = minute(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(23, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_micro_minute() { + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = minute(&a).unwrap(); + assert_eq!(57, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(44, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_week() { + let a: PrimitiveArray = vec![Some(0), None, Some(7)].into(); + + let b = week(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_week() { + // 1646116175000 -> 2022.03.01 , 1641171600000 -> 2022.01.03 + // 1640998800000 -> 2022.01.01 + let a: PrimitiveArray = vec![ + Some(1646116175000), + None, + Some(1641171600000), + Some(1640998800000), + ] + .into(); + + let b = week(&a).unwrap(); + assert_eq!(9, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + assert_eq!(52, b.value(3)); + } + + #[test] + fn test_temporal_array_timestamp_micro_week() 
{ + //1612025847000000 -> 2021.1.30 + //1722015847000000 -> 2024.7.27 + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = week(&a).unwrap(); + assert_eq!(4, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(30, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_second() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = second(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(45, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_micro_second() { + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = second(&a).unwrap(); + assert_eq!(27, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(7, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_second_with_timezone() { + let a = TimestampSecondArray::from(vec![10, 20]).with_timezone("+00:00".to_string()); + let b = second(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(20, b.value(1)); + } + + #[test] + fn test_temporal_array_timestamp_minute_with_timezone() { + let a = TimestampSecondArray::from(vec![0, 60]).with_timezone("+00:50".to_string()); + let b = minute(&a).unwrap(); + assert_eq!(50, b.value(0)); + assert_eq!(51, b.value(1)); + } + + #[test] + fn test_temporal_array_timestamp_minute_with_negative_timezone() { + let a = TimestampSecondArray::from(vec![60 * 55]).with_timezone("-00:50".to_string()); + let b = minute(&a).unwrap(); + assert_eq!(5, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01:00".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(11, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_without_colon() { + let a = TimestampSecondArray::from(vec![60 * 60 * 
10]).with_timezone("+0100".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(11, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_without_minutes() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(11, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_without_initial_sign() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("0100".to_string()); + let err = hour(&a).unwrap_err().to_string(); + assert!(err.contains("Invalid timezone"), "{}", err); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_with_only_colon() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("01:00".to_string()); + let err = hour(&a).unwrap_err().to_string(); + assert!(err.contains("Invalid timezone"), "{}", err); + } + + #[test] + fn test_temporal_array_timestamp_week_without_timezone() { + // 1970-01-01T00:00:00 -> 1970-01-01T00:00:00 Thursday (week 1) + // 1970-01-01T00:00:00 + 4 days -> 1970-01-05T00:00:00 Monday (week 2) + // 1970-01-01T00:00:00 + 4 days - 1 second -> 1970-01-04T23:59:59 Sunday (week 1) + let a = TimestampSecondArray::from(vec![0, 86400 * 4, 86400 * 4 - 1]); + let b = week(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert_eq!(2, b.value(1)); + assert_eq!(1, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_week_with_timezone() { + // 1970-01-01T01:00:00+01:00 -> 1970-01-01T01:00:00+01:00 Thursday (week 1) + // 1970-01-01T01:00:00+01:00 + 4 days -> 1970-01-05T01:00:00+01:00 Monday (week 2) + // 1970-01-01T01:00:00+01:00 + 4 days - 1 second -> 1970-01-05T00:59:59+01:00 Monday (week 2) + let a = TimestampSecondArray::from(vec![0, 86400 * 4, 86400 * 4 - 1]) + .with_timezone("+01:00".to_string()); + let b = week(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert_eq!(2, b.value(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn 
test_hour_minute_second_dictionary_array() { + let a = TimestampSecondArray::from(vec![ + 60 * 60 * 10 + 61, + 60 * 60 * 20 + 122, + 60 * 60 * 30 + 183, + ]) + .with_timezone("+01:00".to_string()); + + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 1]); + let dict = DictionaryArray::try_new(keys.clone(), Arc::new(a)).unwrap(); + + let b = hour_dyn(&dict).unwrap(); + + let expected_dict = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![11, 21, 7]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + + let b = date_part(&dict, DatePart::Minute).unwrap(); + + let b_old = minute_dyn(&dict).unwrap(); + + let expected_dict = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 2, 3]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + assert_eq!(&expected, &b_old); + + let b = date_part(&dict, DatePart::Second).unwrap(); + + let b_old = second_dyn(&dict).unwrap(); + + let expected_dict = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 2, 3]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + assert_eq!(&expected, &b_old); + + let b = date_part(&dict, DatePart::Nanosecond).unwrap(); + + let expected_dict = + DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_year_dictionary_array() { + let a: PrimitiveArray = vec![Some(1514764800000), Some(1550636625000)].into(); + + let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + + let b = year_dyn(&dict).unwrap(); + + let expected_dict = DictionaryArray::new( + keys, + Arc::new(Int32Array::from(vec![2018, 2019, 2019, 2018])), + ); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn 
test_quarter_month_dictionary_array() { + //1514764800000 -> 2018-01-01 + //1566275025000 -> 2019-08-20 + let a: PrimitiveArray = vec![Some(1514764800000), Some(1566275025000)].into(); + + let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + + let b = quarter_dyn(&dict).unwrap(); + + let expected = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 3, 3, 1]))); + assert_eq!(b.as_ref(), &expected); + + let b = month_dyn(&dict).unwrap(); + + let expected = DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![1, 8, 8, 1]))); + assert_eq!(b.as_ref(), &expected); + } + + #[test] + fn test_num_days_from_monday_sunday_day_doy_week_dictionary_array() { + //1514764800000 -> 2018-01-01 (Monday) + //1550636625000 -> 2019-02-20 (Wednesday) + let a: PrimitiveArray = vec![Some(1514764800000), Some(1550636625000)].into(); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), Some(0), None]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + + let b = num_days_from_monday_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(0), Some(2), Some(2), Some(0), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = num_days_from_sunday_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(3), Some(3), Some(1), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = day_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(20), Some(20), Some(1), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = doy_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(51), Some(51), Some(1), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = 
week_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(8), Some(8), Some(1), None]); + let expected = DictionaryArray::new(keys, Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + } + + #[test] + fn test_temporal_array_date64_nanosecond() { + // new Date(1667328721453) + // Tue Nov 01 2022 11:52:01 GMT-0700 (Pacific Daylight Time) + // + // new Date(1667328721453).getMilliseconds() + // 453 + + let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); + + let b = nanosecond(&a).unwrap(); + assert!(!b.is_valid(0)); + assert_eq!(453_000_000, b.value(1)); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + let b = nanosecond_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![None, Some(453_000_000)]); + let expected_dict = DictionaryArray::new(keys, Arc::new(a)); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_temporal_array_date64_microsecond() { + let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); + + let b = microsecond(&a).unwrap(); + assert!(!b.is_valid(0)); + assert_eq!(453_000, b.value(1)); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + let b = microsecond_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![None, Some(453_000)]); + let expected_dict = DictionaryArray::new(keys, Arc::new(a)); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_temporal_array_date64_millisecond() { + let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); + + let b = millisecond(&a).unwrap(); + assert!(!b.is_valid(0)); + assert_eq!(453, b.value(1)); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + let b = millisecond_dyn(&dict).unwrap(); + + let a = 
Int32Array::from(vec![None, Some(453)]); + let expected_dict = DictionaryArray::new(keys, Arc::new(a)); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_temporal_array_time64_nanoseconds() { + // 23:32:50.123456789 + let input: Time64NanosecondArray = vec![Some(84_770_123_456_789)].into(); + + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(23, actual.value(0)); + + let actual = date_part(&input, DatePart::Minute).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(32, actual.value(0)); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(50, actual.value(0)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123, actual.value(0)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123_456, actual.value(0)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123_456_789, actual.value(0)); + + // invalid values should turn into null + let input: Time64NanosecondArray = vec![ + Some(-1), + Some(86_400_000_000_000), + Some(86_401_000_000_000), + None, + ] + .into(); + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + let expected: Int32Array = vec![None, None, None, None].into(); + assert_eq!(&expected, actual); + } + + #[test] + fn test_temporal_array_time64_microseconds() { + // 23:32:50.123456 + let input: Time64MicrosecondArray = vec![Some(84_770_123_456)].into(); + + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(23, actual.value(0)); + + let actual = date_part(&input, DatePart::Minute).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(32, 
actual.value(0)); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(50, actual.value(0)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123, actual.value(0)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123_456, actual.value(0)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123_456_000, actual.value(0)); + + // invalid values should turn into null + let input: Time64MicrosecondArray = + vec![Some(-1), Some(86_400_000_000), Some(86_401_000_000), None].into(); + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + let expected: Int32Array = vec![None, None, None, None].into(); + assert_eq!(&expected, actual); + } + + #[test] + fn test_temporal_array_time32_milliseconds() { + // 23:32:50.123 + let input: Time32MillisecondArray = vec![Some(84_770_123)].into(); + + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(23, actual.value(0)); + + let actual = date_part(&input, DatePart::Minute).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(32, actual.value(0)); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(50, actual.value(0)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123, actual.value(0)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123_000, actual.value(0)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(123_000_000, actual.value(0)); + + // invalid values should turn into 
null + let input: Time32MillisecondArray = + vec![Some(-1), Some(86_400_000), Some(86_401_000), None].into(); + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + let expected: Int32Array = vec![None, None, None, None].into(); + assert_eq!(&expected, actual); + } + + #[test] + fn test_temporal_array_time32_seconds() { + // 23:32:50 + let input: Time32SecondArray = vec![84_770].into(); + + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(23, actual.value(0)); + + let actual = date_part(&input, DatePart::Minute).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(32, actual.value(0)); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(50, actual.value(0)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + + // invalid values should turn into null + let input: Time32SecondArray = vec![Some(-1), Some(86_400), Some(86_401), None].into(); + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + let expected: Int32Array = vec![None, None, None, None].into(); + assert_eq!(&expected, actual); + } + + #[test] + fn test_temporal_array_time_invalid_parts() { + fn ensure_returns_error(array: &dyn Array) { + let invalid_parts = [ + DatePart::Quarter, + DatePart::Year, + DatePart::Month, + DatePart::Week, + DatePart::Day, + DatePart::DayOfWeekSunday0, + DatePart::DayOfWeekMonday0, + DatePart::DayOfYear, + ]; + + for part in invalid_parts { + let err = date_part(array, 
part).unwrap_err(); + let expected = format!( + "Compute error: {part} does not support: {}", + array.data_type() + ); + assert_eq!(expected, err.to_string()); + } + } + + ensure_returns_error(&Time32SecondArray::from(vec![0])); + ensure_returns_error(&Time32MillisecondArray::from(vec![0])); + ensure_returns_error(&Time64MicrosecondArray::from(vec![0])); + ensure_returns_error(&Time64NanosecondArray::from(vec![0])); + } + + #[test] + fn test_interval_year_month_array() { + let input: IntervalYearMonthArray = vec![0, 5, 24].into(); + + let actual = date_part(&input, DatePart::Year).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(2, actual.value(2)); + + let actual = date_part(&input, DatePart::Month).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(5, actual.value(1)); + assert_eq!(0, actual.value(2)); + + assert!(date_part(&input, DatePart::Day).is_err()); + assert!(date_part(&input, DatePart::Week).is_err()); + } + + // IntervalDayTimeType week, day, hour, minute, second, milli, u, nano; + // invalid month, year; ignores the other part + #[test] + fn test_interval_day_time_array() { + let input: IntervalDayTimeArray = vec![ + IntervalDayTime::ZERO, + IntervalDayTime::new(10, 42), + IntervalDayTime::new(10, 1042), + IntervalDayTime::new(10, MILLISECONDS_IN_DAY as i32 + 1), + ] + .into(); + + // Time doesn't affect days. + let actual = date_part(&input, DatePart::Day).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(10, actual.value(1)); + assert_eq!(10, actual.value(2)); + assert_eq!(10, actual.value(3)); + + let actual = date_part(&input, DatePart::Week).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(1, actual.value(1)); + assert_eq!(1, actual.value(2)); + assert_eq!(1, actual.value(3)); + + // Days doesn't affect time. 
+ let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42_000_000, actual.value(1)); + assert_eq!(1_042_000_000, actual.value(2)); + // Overflow returns zero. + assert_eq!(0, actual.value(3)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42_000, actual.value(1)); + assert_eq!(1_042_000, actual.value(2)); + // Overflow returns zero. + assert_eq!(0, actual.value(3)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42, actual.value(1)); + assert_eq!(1042, actual.value(2)); + assert_eq!(MILLISECONDS_IN_DAY as i32 + 1, actual.value(3)); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(1, actual.value(2)); + assert_eq!(24 * 60 * 60, actual.value(3)); + + let actual = date_part(&input, DatePart::Minute).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(0, actual.value(2)); + assert_eq!(24 * 60, actual.value(3)); + + let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(0, actual.value(2)); + assert_eq!(24, actual.value(3)); + + // Month and year are not valid (since days in month varies). 
+ assert!(date_part(&input, DatePart::Month).is_err()); + assert!(date_part(&input, DatePart::Year).is_err()); + } + + // IntervalMonthDayNanoType year -> nano; + // days don't affect months, time doesn't affect days, time doesn't affect months (and vice versa) + #[test] + fn test_interval_month_day_nano_array() { + let input: IntervalMonthDayNanoArray = vec![ + IntervalMonthDayNano::ZERO, + IntervalMonthDayNano::new(5, 10, 42), + IntervalMonthDayNano::new(16, 35, MILLISECONDS_IN_DAY * 1_000_000 + 1), + ] + .into(); + + // Year and month follow from month, but are not affected by days or nanos. + let actual = date_part(&input, DatePart::Year).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(1, actual.value(2)); + + let actual = date_part(&input, DatePart::Month).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(5, actual.value(1)); + assert_eq!(16, actual.value(2)); + + // Week and day follow from day, but are not affected by months or nanos. + let actual = date_part(&input, DatePart::Week).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(1, actual.value(1)); + assert_eq!(5, actual.value(2)); + + let actual = date_part(&input, DatePart::Day).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(10, actual.value(1)); + assert_eq!(35, actual.value(2)); + + // Times follow from nanos, but are not affected by months or days. 
+ let actual = date_part(&input, DatePart::Hour).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(24, actual.value(2)); + + let actual = date_part(&input, DatePart::Minute).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(24 * 60, actual.value(2)); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(24 * 60 * 60, actual.value(2)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(24 * 60 * 60 * 1_000, actual.value(2)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + // Overflow gives zero. + assert_eq!(0, actual.value(2)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42, actual.value(1)); + // Overflow gives zero. 
+ assert_eq!(0, actual.value(2)); + } + + #[test] + fn test_interval_array_invalid_parts() { + fn ensure_returns_error(array: &dyn Array) { + let invalid_parts = [ + DatePart::Quarter, + DatePart::DayOfWeekSunday0, + DatePart::DayOfWeekMonday0, + DatePart::DayOfYear, + ]; + + for part in invalid_parts { + let err = date_part(array, part).unwrap_err(); + let expected = format!( + "Compute error: {part} does not support: {}", + array.data_type() + ); + assert_eq!(expected, err.to_string()); + } + } + + ensure_returns_error(&IntervalYearMonthArray::from(vec![0])); + ensure_returns_error(&IntervalDayTimeArray::from(vec![IntervalDayTime::ZERO])); + ensure_returns_error(&IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::ZERO, + ])); + } + + #[test] + fn test_duration_second() { + let input: DurationSecondArray = vec![0, 42, 60 * 60 * 24 + 1].into(); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42, actual.value(1)); + assert_eq!(60 * 60 * 24 + 1, actual.value(2)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42_000, actual.value(1)); + assert_eq!((60 * 60 * 24 + 1) * 1_000, actual.value(2)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42_000_000, actual.value(1)); + assert_eq!(0, actual.value(2)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(0, actual.value(2)); + } + + #[test] + fn test_duration_millisecond() { + let input: DurationMillisecondArray = vec![0, 42, 60 * 60 * 24 + 1].into(); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + 
assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!((60 * 60 * 24 + 1) / 1_000, actual.value(2)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42, actual.value(1)); + assert_eq!(60 * 60 * 24 + 1, actual.value(2)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42_000, actual.value(1)); + assert_eq!((60 * 60 * 24 + 1) * 1_000, actual.value(2)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42_000_000, actual.value(1)); + assert_eq!(0, actual.value(2)); + } + + #[test] + fn test_duration_microsecond() { + let input: DurationMicrosecondArray = vec![0, 42, 60 * 60 * 24 + 1].into(); + + let actual = date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(0, actual.value(2)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!((60 * 60 * 24 + 1) / 1_000, actual.value(2)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42, actual.value(1)); + assert_eq!(60 * 60 * 24 + 1, actual.value(2)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42_000, actual.value(1)); + assert_eq!((60 * 60 * 24 + 1) * 1_000, actual.value(2)); + } + + #[test] + fn test_duration_nanosecond() { + let input: DurationNanosecondArray = vec![0, 42, 60 * 60 * 24 + 1].into(); + + let actual = 
date_part(&input, DatePart::Second).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(0, actual.value(2)); + + let actual = date_part(&input, DatePart::Millisecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!(0, actual.value(2)); + + let actual = date_part(&input, DatePart::Microsecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(0, actual.value(1)); + assert_eq!((60 * 60 * 24 + 1) / 1_000, actual.value(2)); + + let actual = date_part(&input, DatePart::Nanosecond).unwrap(); + let actual = actual.as_primitive::(); + assert_eq!(0, actual.value(0)); + assert_eq!(42, actual.value(1)); + assert_eq!(60 * 60 * 24 + 1, actual.value(2)); + } + + #[test] + fn test_duration_invalid_parts() { + fn ensure_returns_error(array: &dyn Array) { + let invalid_parts = [ + DatePart::Year, + DatePart::Quarter, + DatePart::Month, + DatePart::DayOfWeekSunday0, + DatePart::DayOfWeekMonday0, + DatePart::DayOfYear, + ]; + + for part in invalid_parts { + let err = date_part(array, part).unwrap_err(); + let expected = format!( + "Compute error: {part} does not support: {}", + array.data_type() + ); + assert_eq!(expected, err.to_string()); + } + } + + ensure_returns_error(&DurationSecondArray::from(vec![0])); + ensure_returns_error(&DurationMillisecondArray::from(vec![0])); + ensure_returns_error(&DurationMicrosecondArray::from(vec![0])); + ensure_returns_error(&DurationNanosecondArray::from(vec![0])); + } +} diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml new file mode 100644 index 000000000000..d993d36b8d74 --- /dev/null +++ b/arrow-array/Cargo.toml @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-array" +version = { workspace = true } +description = "Array abstractions for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_array" +path = "src/lib.rs" +bench = false + + +[target.'cfg(target_arch = "wasm32")'.dependencies] +ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } + +[dependencies] +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } +arrow-data = { workspace = true } +chrono = { workspace = true } +chrono-tz = { version = "0.10", optional = true } +num = { version = "0.4.1", default-features = false, features = ["std"] } +half = { version = "2.1", default-features = false, features = ["num-traits"] } +hashbrown = { version = "0.14.2", default-features = false } + +[features] +ffi = ["arrow-schema/ffi", "arrow-data/ffi"] +force_validate = [] + +[dev-dependencies] +rand = { version = "0.8", 
default-features = false, features = ["std", "std_rng"] } +criterion = { version = "0.5", default-features = false } + +[build-dependencies] + +[[bench]] +name = "occupancy" +harness = false + +[[bench]] +name = "gc_view_types" +harness = false + +[[bench]] +name = "fixed_size_list_array" +harness = false + +[[bench]] +name = "decimal_overflow" +harness = false diff --git a/arrow-array/benches/decimal_overflow.rs b/arrow-array/benches/decimal_overflow.rs new file mode 100644 index 000000000000..8f22b4b47c31 --- /dev/null +++ b/arrow-array/benches/decimal_overflow.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; +use arrow_buffer::i256; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let len = 8192; + let mut builder_128 = Decimal128Builder::with_capacity(len); + let mut builder_256 = Decimal256Builder::with_capacity(len); + for i in 0..len { + if i % 10 == 0 { + builder_128.append_value(i128::MAX); + builder_256.append_value(i256::from_i128(i128::MAX)); + } else { + builder_128.append_value(i as i128); + builder_256.append_value(i256::from_i128(i as i128)); + } + } + let array_128 = builder_128.finish(); + let array_256 = builder_256.finish(); + + c.bench_function("validate_decimal_precision_128", |b| { + b.iter(|| black_box(array_128.validate_decimal_precision(8))); + }); + c.bench_function("null_if_overflow_precision_128", |b| { + b.iter(|| black_box(array_128.null_if_overflow_precision(8))); + }); + c.bench_function("validate_decimal_precision_256", |b| { + b.iter(|| black_box(array_256.validate_decimal_precision(8))); + }); + c.bench_function("null_if_overflow_precision_256", |b| { + b.iter(|| black_box(array_256.null_if_overflow_precision(8))); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/benches/fixed_size_list_array.rs b/arrow-array/benches/fixed_size_list_array.rs new file mode 100644 index 000000000000..5f001a4f3d3a --- /dev/null +++ b/arrow-array/benches/fixed_size_list_array.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{Array, FixedSizeListArray, Int32Array}; +use arrow_schema::Field; +use criterion::*; +use rand::{thread_rng, Rng}; +use std::sync::Arc; + +fn gen_fsl(len: usize, value_len: usize) -> FixedSizeListArray { + let mut rng = thread_rng(); + let values = Arc::new(Int32Array::from( + (0..len).map(|_| rng.gen::()).collect::>(), + )); + let field = Arc::new(Field::new("item", values.data_type().clone(), true)); + FixedSizeListArray::new(field, value_len as i32, values, None) +} + +fn criterion_benchmark(c: &mut Criterion) { + let len = 4096; + for value_len in [1, 32, 1024] { + let fsl = gen_fsl(len, value_len); + c.bench_function( + &format!("fixed_size_list_array(len: {len}, value_len: {value_len})"), + |b| { + b.iter(|| { + for i in 0..len / value_len { + black_box(fsl.value(i)); + } + }); + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/benches/gc_view_types.rs b/arrow-array/benches/gc_view_types.rs new file mode 100644 index 000000000000..4b74a8f60b06 --- /dev/null +++ b/arrow-array/benches/gc_view_types.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::StringViewArray; +use criterion::*; + +fn gen_view_array(size: usize) -> StringViewArray { + StringViewArray::from_iter((0..size).map(|v| match v % 3 { + 0 => Some("small"), + 1 => Some("larger than 12 bytes array"), + 2 => None, + _ => unreachable!("unreachable"), + })) +} + +fn criterion_benchmark(c: &mut Criterion) { + let array = gen_view_array(100_000); + + c.bench_function("gc view types all", |b| { + b.iter(|| { + black_box(array.gc()); + }); + }); + + let sliced = array.slice(0, 100_000 / 2); + c.bench_function("gc view types slice half", |b| { + b.iter(|| { + black_box(sliced.gc()); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/benches/occupancy.rs b/arrow-array/benches/occupancy.rs new file mode 100644 index 000000000000..ed4b94351c28 --- /dev/null +++ b/arrow-array/benches/occupancy.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::types::Int32Type; +use arrow_array::{DictionaryArray, Int32Array}; +use arrow_buffer::NullBuffer; +use criterion::*; +use rand::{thread_rng, Rng}; +use std::sync::Arc; + +fn gen_dict( + len: usize, + values_len: usize, + occupancy: f64, + null_percent: f64, +) -> DictionaryArray { + let mut rng = thread_rng(); + let values = Int32Array::from(vec![0; values_len]); + let max_key = (values_len as f64 * occupancy) as i32; + let keys = (0..len).map(|_| rng.gen_range(0..max_key)).collect(); + let nulls = (0..len).map(|_| !rng.gen_bool(null_percent)).collect(); + + let keys = Int32Array::new(keys, Some(NullBuffer::new(nulls))); + DictionaryArray::new(keys, Arc::new(values)) +} + +fn criterion_benchmark(c: &mut Criterion) { + for values in [10, 100, 512] { + for occupancy in [1., 0.5, 0.1] { + for null_percent in [0.0, 0.1, 0.5, 0.9] { + let dict = gen_dict(1024, values, occupancy, null_percent); + c.bench_function(&format!("occupancy(values: {values}, occupancy: {occupancy}, null_percent: {null_percent})"), |b| { + b.iter(|| { + black_box(&dict).occupancy() + }); + }); + } + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/src/arithmetic.rs b/arrow-array/src/arithmetic.rs new file mode 100644 index 000000000000..fb9c868fb6c0 --- /dev/null +++ b/arrow-array/src/arithmetic.rs @@ -0,0 +1,867 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::{i256, ArrowNativeType, IntervalDayTime, IntervalMonthDayNano}; +use arrow_schema::ArrowError; +use half::f16; +use num::complex::ComplexFloat; +use std::cmp::Ordering; + +/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations, +/// and totally ordered comparison operations +/// +/// The APIs with `_wrapping` suffix do not perform overflow-checking. For integer +/// types they will wrap around the boundary of the type. For floating point types they +/// will overflow to INF or -INF preserving the expected sign value +/// +/// Note `div_wrapping` and `mod_wrapping` will panic for integer types if `rhs` is zero +/// although this may be subject to change +/// +/// The APIs with `_checked` suffix perform overflow-checking. For integer types +/// these will return `Err` instead of wrapping. 
For floating point types they will +/// overflow to INF or -INF preserving the expected sign value +/// +/// Comparison of integer types is as per normal integer comparison rules, floating +/// point values are compared as per IEEE 754's totalOrder predicate see [`f32::total_cmp`] +/// +pub trait ArrowNativeTypeOp: ArrowNativeType { + /// The additive identity + const ZERO: Self; + + /// The multiplicative identity + const ONE: Self; + + /// The minimum value and identity for the `max` aggregation. + /// Note that the aggregation uses the total order predicate for floating point values, + /// which means that this value is a negative NaN. + const MIN_TOTAL_ORDER: Self; + + /// The maximum value and identity for the `min` aggregation. + /// Note that the aggregation uses the total order predicate for floating point values, + /// which means that this value is a positive NaN. + const MAX_TOTAL_ORDER: Self; + + /// Checked addition operation + fn add_checked(self, rhs: Self) -> Result; + + /// Wrapping addition operation + fn add_wrapping(self, rhs: Self) -> Self; + + /// Checked subtraction operation + fn sub_checked(self, rhs: Self) -> Result; + + /// Wrapping subtraction operation + fn sub_wrapping(self, rhs: Self) -> Self; + + /// Checked multiplication operation + fn mul_checked(self, rhs: Self) -> Result; + + /// Wrapping multiplication operation + fn mul_wrapping(self, rhs: Self) -> Self; + + /// Checked division operation + fn div_checked(self, rhs: Self) -> Result; + + /// Wrapping division operation + fn div_wrapping(self, rhs: Self) -> Self; + + /// Checked remainder operation + fn mod_checked(self, rhs: Self) -> Result; + + /// Wrapping remainder operation + fn mod_wrapping(self, rhs: Self) -> Self; + + /// Checked negation operation + fn neg_checked(self) -> Result; + + /// Wrapping negation operation + fn neg_wrapping(self) -> Self; + + /// Checked exponentiation operation + fn pow_checked(self, exp: u32) -> Result; + + /// Wrapping exponentiation 
operation + fn pow_wrapping(self, exp: u32) -> Self; + + /// Returns true if zero else false + fn is_zero(self) -> bool; + + /// Compare operation + fn compare(self, rhs: Self) -> Ordering; + + /// Equality operation + fn is_eq(self, rhs: Self) -> bool; + + /// Not equal operation + #[inline] + fn is_ne(self, rhs: Self) -> bool { + !self.is_eq(rhs) + } + + /// Less than operation + #[inline] + fn is_lt(self, rhs: Self) -> bool { + self.compare(rhs).is_lt() + } + + /// Less than equals operation + #[inline] + fn is_le(self, rhs: Self) -> bool { + self.compare(rhs).is_le() + } + + /// Greater than operation + #[inline] + fn is_gt(self, rhs: Self) -> bool { + self.compare(rhs).is_gt() + } + + /// Greater than equals operation + #[inline] + fn is_ge(self, rhs: Self) -> bool { + self.compare(rhs).is_ge() + } +} + +macro_rules! native_type_op { + ($t:tt) => { + native_type_op!($t, 0, 1); + }; + ($t:tt, $zero:expr, $one: expr) => { + native_type_op!($t, $zero, $one, $t::MIN, $t::MAX); + }; + ($t:tt, $zero:expr, $one: expr, $min: expr, $max: expr) => { + impl ArrowNativeTypeOp for $t { + const ZERO: Self = $zero; + const ONE: Self = $one; + const MIN_TOTAL_ORDER: Self = $min; + const MAX_TOTAL_ORDER: Self = $max; + + #[inline] + fn add_checked(self, rhs: Self) -> Result { + self.checked_add(rhs).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Overflow happened on: {:?} + {:?}", + self, rhs + )) + }) + } + + #[inline] + fn add_wrapping(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline] + fn sub_checked(self, rhs: Self) -> Result { + self.checked_sub(rhs).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Overflow happened on: {:?} - {:?}", + self, rhs + )) + }) + } + + #[inline] + fn sub_wrapping(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline] + fn mul_checked(self, rhs: Self) -> Result { + self.checked_mul(rhs).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Overflow happened on: {:?} * {:?}", + 
self, rhs + )) + }) + } + + #[inline] + fn mul_wrapping(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline] + fn div_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + self.checked_div(rhs).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Overflow happened on: {:?} / {:?}", + self, rhs + )) + }) + } + } + + #[inline] + fn div_wrapping(self, rhs: Self) -> Self { + self.wrapping_div(rhs) + } + + #[inline] + fn mod_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + self.checked_rem(rhs).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Overflow happened on: {:?} % {:?}", + self, rhs + )) + }) + } + } + + #[inline] + fn mod_wrapping(self, rhs: Self) -> Self { + self.wrapping_rem(rhs) + } + + #[inline] + fn neg_checked(self) -> Result { + self.checked_neg().ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!("Overflow happened on: - {:?}", self)) + }) + } + + #[inline] + fn pow_checked(self, exp: u32) -> Result { + self.checked_pow(exp).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Overflow happened on: {:?} ^ {exp:?}", + self + )) + }) + } + + #[inline] + fn pow_wrapping(self, exp: u32) -> Self { + self.wrapping_pow(exp) + } + + #[inline] + fn neg_wrapping(self) -> Self { + self.wrapping_neg() + } + + #[inline] + fn is_zero(self) -> bool { + self == Self::ZERO + } + + #[inline] + fn compare(self, rhs: Self) -> Ordering { + self.cmp(&rhs) + } + + #[inline] + fn is_eq(self, rhs: Self) -> bool { + self == rhs + } + } + }; +} + +native_type_op!(i8); +native_type_op!(i16); +native_type_op!(i32); +native_type_op!(i64); +native_type_op!(i128); +native_type_op!(u8); +native_type_op!(u16); +native_type_op!(u32); +native_type_op!(u64); +native_type_op!(i256, i256::ZERO, i256::ONE, i256::MIN, i256::MAX); + +native_type_op!(IntervalDayTime, IntervalDayTime::ZERO, IntervalDayTime::ONE); +native_type_op!( + 
IntervalMonthDayNano, + IntervalMonthDayNano::ZERO, + IntervalMonthDayNano::ONE +); + +macro_rules! native_type_float_op { + ($t:tt, $zero:expr, $one:expr, $min:expr, $max:expr) => { + impl ArrowNativeTypeOp for $t { + const ZERO: Self = $zero; + const ONE: Self = $one; + const MIN_TOTAL_ORDER: Self = $min; + const MAX_TOTAL_ORDER: Self = $max; + + #[inline] + fn add_checked(self, rhs: Self) -> Result { + Ok(self + rhs) + } + + #[inline] + fn add_wrapping(self, rhs: Self) -> Self { + self + rhs + } + + #[inline] + fn sub_checked(self, rhs: Self) -> Result { + Ok(self - rhs) + } + + #[inline] + fn sub_wrapping(self, rhs: Self) -> Self { + self - rhs + } + + #[inline] + fn mul_checked(self, rhs: Self) -> Result { + Ok(self * rhs) + } + + #[inline] + fn mul_wrapping(self, rhs: Self) -> Self { + self * rhs + } + + #[inline] + fn div_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self / rhs) + } + } + + #[inline] + fn div_wrapping(self, rhs: Self) -> Self { + self / rhs + } + + #[inline] + fn mod_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self % rhs) + } + } + + #[inline] + fn mod_wrapping(self, rhs: Self) -> Self { + self % rhs + } + + #[inline] + fn neg_checked(self) -> Result { + Ok(-self) + } + + #[inline] + fn neg_wrapping(self) -> Self { + -self + } + + #[inline] + fn pow_checked(self, exp: u32) -> Result { + Ok(self.powi(exp as i32)) + } + + #[inline] + fn pow_wrapping(self, exp: u32) -> Self { + self.powi(exp as i32) + } + + #[inline] + fn is_zero(self) -> bool { + self == $zero + } + + #[inline] + fn compare(self, rhs: Self) -> Ordering { + <$t>::total_cmp(&self, &rhs) + } + + #[inline] + fn is_eq(self, rhs: Self) -> bool { + // Equivalent to `self.total_cmp(&rhs).is_eq()` + // but LLVM isn't able to realise this is bitwise equality + // https://rust.godbolt.org/z/347nWGxoW + self.to_bits() == rhs.to_bits() + } + } + }; +} + +// the 
smallest/largest bit patterns for floating point numbers are NaN, but differ from the canonical NAN constants. +// See test_float_total_order_min_max for details. +native_type_float_op!( + f16, + f16::ZERO, + f16::ONE, + f16::from_bits(-1 as _), + f16::from_bits(i16::MAX as _) +); +// from_bits is not yet stable as const fn, see https://github.com/rust-lang/rust/issues/72447 +native_type_float_op!( + f32, + 0., + 1., + unsafe { std::mem::transmute(-1_i32) }, + unsafe { std::mem::transmute(i32::MAX) } +); +native_type_float_op!( + f64, + 0., + 1., + unsafe { std::mem::transmute(-1_i64) }, + unsafe { std::mem::transmute(i64::MAX) } +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_native_type_is_zero() { + assert!(0_i8.is_zero()); + assert!(0_i16.is_zero()); + assert!(0_i32.is_zero()); + assert!(0_i64.is_zero()); + assert!(0_i128.is_zero()); + assert!(i256::ZERO.is_zero()); + assert!(0_u8.is_zero()); + assert!(0_u16.is_zero()); + assert!(0_u32.is_zero()); + assert!(0_u64.is_zero()); + assert!(f16::ZERO.is_zero()); + assert!(0.0_f32.is_zero()); + assert!(0.0_f64.is_zero()); + } + + #[test] + fn test_native_type_comparison() { + // is_eq + assert!(8_i8.is_eq(8_i8)); + assert!(8_i16.is_eq(8_i16)); + assert!(8_i32.is_eq(8_i32)); + assert!(8_i64.is_eq(8_i64)); + assert!(8_i128.is_eq(8_i128)); + assert!(i256::from_parts(8, 0).is_eq(i256::from_parts(8, 0))); + assert!(8_u8.is_eq(8_u8)); + assert!(8_u16.is_eq(8_u16)); + assert!(8_u32.is_eq(8_u32)); + assert!(8_u64.is_eq(8_u64)); + assert!(f16::from_f32(8.0).is_eq(f16::from_f32(8.0))); + assert!(8.0_f32.is_eq(8.0_f32)); + assert!(8.0_f64.is_eq(8.0_f64)); + + // is_ne + assert!(8_i8.is_ne(1_i8)); + assert!(8_i16.is_ne(1_i16)); + assert!(8_i32.is_ne(1_i32)); + assert!(8_i64.is_ne(1_i64)); + assert!(8_i128.is_ne(1_i128)); + assert!(i256::from_parts(8, 0).is_ne(i256::from_parts(1, 0))); + assert!(8_u8.is_ne(1_u8)); + assert!(8_u16.is_ne(1_u16)); + assert!(8_u32.is_ne(1_u32)); + assert!(8_u64.is_ne(1_u64)); + 
assert!(f16::from_f32(8.0).is_ne(f16::from_f32(1.0))); + assert!(8.0_f32.is_ne(1.0_f32)); + assert!(8.0_f64.is_ne(1.0_f64)); + + // is_lt + assert!(8_i8.is_lt(10_i8)); + assert!(8_i16.is_lt(10_i16)); + assert!(8_i32.is_lt(10_i32)); + assert!(8_i64.is_lt(10_i64)); + assert!(8_i128.is_lt(10_i128)); + assert!(i256::from_parts(8, 0).is_lt(i256::from_parts(10, 0))); + assert!(8_u8.is_lt(10_u8)); + assert!(8_u16.is_lt(10_u16)); + assert!(8_u32.is_lt(10_u32)); + assert!(8_u64.is_lt(10_u64)); + assert!(f16::from_f32(8.0).is_lt(f16::from_f32(10.0))); + assert!(8.0_f32.is_lt(10.0_f32)); + assert!(8.0_f64.is_lt(10.0_f64)); + + // is_gt + assert!(8_i8.is_gt(1_i8)); + assert!(8_i16.is_gt(1_i16)); + assert!(8_i32.is_gt(1_i32)); + assert!(8_i64.is_gt(1_i64)); + assert!(8_i128.is_gt(1_i128)); + assert!(i256::from_parts(8, 0).is_gt(i256::from_parts(1, 0))); + assert!(8_u8.is_gt(1_u8)); + assert!(8_u16.is_gt(1_u16)); + assert!(8_u32.is_gt(1_u32)); + assert!(8_u64.is_gt(1_u64)); + assert!(f16::from_f32(8.0).is_gt(f16::from_f32(1.0))); + assert!(8.0_f32.is_gt(1.0_f32)); + assert!(8.0_f64.is_gt(1.0_f64)); + } + + #[test] + fn test_native_type_add() { + // add_wrapping + assert_eq!(8_i8.add_wrapping(2_i8), 10_i8); + assert_eq!(8_i16.add_wrapping(2_i16), 10_i16); + assert_eq!(8_i32.add_wrapping(2_i32), 10_i32); + assert_eq!(8_i64.add_wrapping(2_i64), 10_i64); + assert_eq!(8_i128.add_wrapping(2_i128), 10_i128); + assert_eq!( + i256::from_parts(8, 0).add_wrapping(i256::from_parts(2, 0)), + i256::from_parts(10, 0) + ); + assert_eq!(8_u8.add_wrapping(2_u8), 10_u8); + assert_eq!(8_u16.add_wrapping(2_u16), 10_u16); + assert_eq!(8_u32.add_wrapping(2_u32), 10_u32); + assert_eq!(8_u64.add_wrapping(2_u64), 10_u64); + assert_eq!( + f16::from_f32(8.0).add_wrapping(f16::from_f32(2.0)), + f16::from_f32(10.0) + ); + assert_eq!(8.0_f32.add_wrapping(2.0_f32), 10_f32); + assert_eq!(8.0_f64.add_wrapping(2.0_f64), 10_f64); + + // add_checked + assert_eq!(8_i8.add_checked(2_i8).unwrap(), 10_i8); + 
assert_eq!(8_i16.add_checked(2_i16).unwrap(), 10_i16); + assert_eq!(8_i32.add_checked(2_i32).unwrap(), 10_i32); + assert_eq!(8_i64.add_checked(2_i64).unwrap(), 10_i64); + assert_eq!(8_i128.add_checked(2_i128).unwrap(), 10_i128); + assert_eq!( + i256::from_parts(8, 0) + .add_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(10, 0) + ); + assert_eq!(8_u8.add_checked(2_u8).unwrap(), 10_u8); + assert_eq!(8_u16.add_checked(2_u16).unwrap(), 10_u16); + assert_eq!(8_u32.add_checked(2_u32).unwrap(), 10_u32); + assert_eq!(8_u64.add_checked(2_u64).unwrap(), 10_u64); + assert_eq!( + f16::from_f32(8.0).add_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(10.0) + ); + assert_eq!(8.0_f32.add_checked(2.0_f32).unwrap(), 10_f32); + assert_eq!(8.0_f64.add_checked(2.0_f64).unwrap(), 10_f64); + } + + #[test] + fn test_native_type_sub() { + // sub_wrapping + assert_eq!(8_i8.sub_wrapping(2_i8), 6_i8); + assert_eq!(8_i16.sub_wrapping(2_i16), 6_i16); + assert_eq!(8_i32.sub_wrapping(2_i32), 6_i32); + assert_eq!(8_i64.sub_wrapping(2_i64), 6_i64); + assert_eq!(8_i128.sub_wrapping(2_i128), 6_i128); + assert_eq!( + i256::from_parts(8, 0).sub_wrapping(i256::from_parts(2, 0)), + i256::from_parts(6, 0) + ); + assert_eq!(8_u8.sub_wrapping(2_u8), 6_u8); + assert_eq!(8_u16.sub_wrapping(2_u16), 6_u16); + assert_eq!(8_u32.sub_wrapping(2_u32), 6_u32); + assert_eq!(8_u64.sub_wrapping(2_u64), 6_u64); + assert_eq!( + f16::from_f32(8.0).sub_wrapping(f16::from_f32(2.0)), + f16::from_f32(6.0) + ); + assert_eq!(8.0_f32.sub_wrapping(2.0_f32), 6_f32); + assert_eq!(8.0_f64.sub_wrapping(2.0_f64), 6_f64); + + // sub_checked + assert_eq!(8_i8.sub_checked(2_i8).unwrap(), 6_i8); + assert_eq!(8_i16.sub_checked(2_i16).unwrap(), 6_i16); + assert_eq!(8_i32.sub_checked(2_i32).unwrap(), 6_i32); + assert_eq!(8_i64.sub_checked(2_i64).unwrap(), 6_i64); + assert_eq!(8_i128.sub_checked(2_i128).unwrap(), 6_i128); + assert_eq!( + i256::from_parts(8, 0) + .sub_checked(i256::from_parts(2, 0)) + .unwrap(), + 
i256::from_parts(6, 0) + ); + assert_eq!(8_u8.sub_checked(2_u8).unwrap(), 6_u8); + assert_eq!(8_u16.sub_checked(2_u16).unwrap(), 6_u16); + assert_eq!(8_u32.sub_checked(2_u32).unwrap(), 6_u32); + assert_eq!(8_u64.sub_checked(2_u64).unwrap(), 6_u64); + assert_eq!( + f16::from_f32(8.0).sub_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(6.0) + ); + assert_eq!(8.0_f32.sub_checked(2.0_f32).unwrap(), 6_f32); + assert_eq!(8.0_f64.sub_checked(2.0_f64).unwrap(), 6_f64); + } + + #[test] + fn test_native_type_mul() { + // mul_wrapping + assert_eq!(8_i8.mul_wrapping(2_i8), 16_i8); + assert_eq!(8_i16.mul_wrapping(2_i16), 16_i16); + assert_eq!(8_i32.mul_wrapping(2_i32), 16_i32); + assert_eq!(8_i64.mul_wrapping(2_i64), 16_i64); + assert_eq!(8_i128.mul_wrapping(2_i128), 16_i128); + assert_eq!( + i256::from_parts(8, 0).mul_wrapping(i256::from_parts(2, 0)), + i256::from_parts(16, 0) + ); + assert_eq!(8_u8.mul_wrapping(2_u8), 16_u8); + assert_eq!(8_u16.mul_wrapping(2_u16), 16_u16); + assert_eq!(8_u32.mul_wrapping(2_u32), 16_u32); + assert_eq!(8_u64.mul_wrapping(2_u64), 16_u64); + assert_eq!( + f16::from_f32(8.0).mul_wrapping(f16::from_f32(2.0)), + f16::from_f32(16.0) + ); + assert_eq!(8.0_f32.mul_wrapping(2.0_f32), 16_f32); + assert_eq!(8.0_f64.mul_wrapping(2.0_f64), 16_f64); + + // mul_checked + assert_eq!(8_i8.mul_checked(2_i8).unwrap(), 16_i8); + assert_eq!(8_i16.mul_checked(2_i16).unwrap(), 16_i16); + assert_eq!(8_i32.mul_checked(2_i32).unwrap(), 16_i32); + assert_eq!(8_i64.mul_checked(2_i64).unwrap(), 16_i64); + assert_eq!(8_i128.mul_checked(2_i128).unwrap(), 16_i128); + assert_eq!( + i256::from_parts(8, 0) + .mul_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(16, 0) + ); + assert_eq!(8_u8.mul_checked(2_u8).unwrap(), 16_u8); + assert_eq!(8_u16.mul_checked(2_u16).unwrap(), 16_u16); + assert_eq!(8_u32.mul_checked(2_u32).unwrap(), 16_u32); + assert_eq!(8_u64.mul_checked(2_u64).unwrap(), 16_u64); + assert_eq!( + 
f16::from_f32(8.0).mul_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(16.0) + ); + assert_eq!(8.0_f32.mul_checked(2.0_f32).unwrap(), 16_f32); + assert_eq!(8.0_f64.mul_checked(2.0_f64).unwrap(), 16_f64); + } + + #[test] + fn test_native_type_div() { + // div_wrapping + assert_eq!(8_i8.div_wrapping(2_i8), 4_i8); + assert_eq!(8_i16.div_wrapping(2_i16), 4_i16); + assert_eq!(8_i32.div_wrapping(2_i32), 4_i32); + assert_eq!(8_i64.div_wrapping(2_i64), 4_i64); + assert_eq!(8_i128.div_wrapping(2_i128), 4_i128); + assert_eq!( + i256::from_parts(8, 0).div_wrapping(i256::from_parts(2, 0)), + i256::from_parts(4, 0) + ); + assert_eq!(8_u8.div_wrapping(2_u8), 4_u8); + assert_eq!(8_u16.div_wrapping(2_u16), 4_u16); + assert_eq!(8_u32.div_wrapping(2_u32), 4_u32); + assert_eq!(8_u64.div_wrapping(2_u64), 4_u64); + assert_eq!( + f16::from_f32(8.0).div_wrapping(f16::from_f32(2.0)), + f16::from_f32(4.0) + ); + assert_eq!(8.0_f32.div_wrapping(2.0_f32), 4_f32); + assert_eq!(8.0_f64.div_wrapping(2.0_f64), 4_f64); + + // div_checked + assert_eq!(8_i8.div_checked(2_i8).unwrap(), 4_i8); + assert_eq!(8_i16.div_checked(2_i16).unwrap(), 4_i16); + assert_eq!(8_i32.div_checked(2_i32).unwrap(), 4_i32); + assert_eq!(8_i64.div_checked(2_i64).unwrap(), 4_i64); + assert_eq!(8_i128.div_checked(2_i128).unwrap(), 4_i128); + assert_eq!( + i256::from_parts(8, 0) + .div_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(4, 0) + ); + assert_eq!(8_u8.div_checked(2_u8).unwrap(), 4_u8); + assert_eq!(8_u16.div_checked(2_u16).unwrap(), 4_u16); + assert_eq!(8_u32.div_checked(2_u32).unwrap(), 4_u32); + assert_eq!(8_u64.div_checked(2_u64).unwrap(), 4_u64); + assert_eq!( + f16::from_f32(8.0).div_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(4.0) + ); + assert_eq!(8.0_f32.div_checked(2.0_f32).unwrap(), 4_f32); + assert_eq!(8.0_f64.div_checked(2.0_f64).unwrap(), 4_f64); + } + + #[test] + fn test_native_type_mod() { + // mod_wrapping + assert_eq!(9_i8.mod_wrapping(2_i8), 1_i8); + 
assert_eq!(9_i16.mod_wrapping(2_i16), 1_i16); + assert_eq!(9_i32.mod_wrapping(2_i32), 1_i32); + assert_eq!(9_i64.mod_wrapping(2_i64), 1_i64); + assert_eq!(9_i128.mod_wrapping(2_i128), 1_i128); + assert_eq!( + i256::from_parts(9, 0).mod_wrapping(i256::from_parts(2, 0)), + i256::from_parts(1, 0) + ); + assert_eq!(9_u8.mod_wrapping(2_u8), 1_u8); + assert_eq!(9_u16.mod_wrapping(2_u16), 1_u16); + assert_eq!(9_u32.mod_wrapping(2_u32), 1_u32); + assert_eq!(9_u64.mod_wrapping(2_u64), 1_u64); + assert_eq!( + f16::from_f32(9.0).mod_wrapping(f16::from_f32(2.0)), + f16::from_f32(1.0) + ); + assert_eq!(9.0_f32.mod_wrapping(2.0_f32), 1_f32); + assert_eq!(9.0_f64.mod_wrapping(2.0_f64), 1_f64); + + // mod_checked + assert_eq!(9_i8.mod_checked(2_i8).unwrap(), 1_i8); + assert_eq!(9_i16.mod_checked(2_i16).unwrap(), 1_i16); + assert_eq!(9_i32.mod_checked(2_i32).unwrap(), 1_i32); + assert_eq!(9_i64.mod_checked(2_i64).unwrap(), 1_i64); + assert_eq!(9_i128.mod_checked(2_i128).unwrap(), 1_i128); + assert_eq!( + i256::from_parts(9, 0) + .mod_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(1, 0) + ); + assert_eq!(9_u8.mod_checked(2_u8).unwrap(), 1_u8); + assert_eq!(9_u16.mod_checked(2_u16).unwrap(), 1_u16); + assert_eq!(9_u32.mod_checked(2_u32).unwrap(), 1_u32); + assert_eq!(9_u64.mod_checked(2_u64).unwrap(), 1_u64); + assert_eq!( + f16::from_f32(9.0).mod_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(1.0) + ); + assert_eq!(9.0_f32.mod_checked(2.0_f32).unwrap(), 1_f32); + assert_eq!(9.0_f64.mod_checked(2.0_f64).unwrap(), 1_f64); + } + + #[test] + fn test_native_type_neg() { + // neg_wrapping + assert_eq!(8_i8.neg_wrapping(), -8_i8); + assert_eq!(8_i16.neg_wrapping(), -8_i16); + assert_eq!(8_i32.neg_wrapping(), -8_i32); + assert_eq!(8_i64.neg_wrapping(), -8_i64); + assert_eq!(8_i128.neg_wrapping(), -8_i128); + assert_eq!(i256::from_parts(8, 0).neg_wrapping(), i256::from_i128(-8)); + assert_eq!(8_u8.neg_wrapping(), u8::MAX - 7_u8); + assert_eq!(8_u16.neg_wrapping(), 
u16::MAX - 7_u16); + assert_eq!(8_u32.neg_wrapping(), u32::MAX - 7_u32); + assert_eq!(8_u64.neg_wrapping(), u64::MAX - 7_u64); + assert_eq!(f16::from_f32(8.0).neg_wrapping(), f16::from_f32(-8.0)); + assert_eq!(8.0_f32.neg_wrapping(), -8_f32); + assert_eq!(8.0_f64.neg_wrapping(), -8_f64); + + // neg_checked + assert_eq!(8_i8.neg_checked().unwrap(), -8_i8); + assert_eq!(8_i16.neg_checked().unwrap(), -8_i16); + assert_eq!(8_i32.neg_checked().unwrap(), -8_i32); + assert_eq!(8_i64.neg_checked().unwrap(), -8_i64); + assert_eq!(8_i128.neg_checked().unwrap(), -8_i128); + assert_eq!( + i256::from_parts(8, 0).neg_checked().unwrap(), + i256::from_i128(-8) + ); + assert!(8_u8.neg_checked().is_err()); + assert!(8_u16.neg_checked().is_err()); + assert!(8_u32.neg_checked().is_err()); + assert!(8_u64.neg_checked().is_err()); + assert_eq!( + f16::from_f32(8.0).neg_checked().unwrap(), + f16::from_f32(-8.0) + ); + assert_eq!(8.0_f32.neg_checked().unwrap(), -8_f32); + assert_eq!(8.0_f64.neg_checked().unwrap(), -8_f64); + } + + #[test] + fn test_native_type_pow() { + // pow_wrapping + assert_eq!(8_i8.pow_wrapping(2_u32), 64_i8); + assert_eq!(8_i16.pow_wrapping(2_u32), 64_i16); + assert_eq!(8_i32.pow_wrapping(2_u32), 64_i32); + assert_eq!(8_i64.pow_wrapping(2_u32), 64_i64); + assert_eq!(8_i128.pow_wrapping(2_u32), 64_i128); + assert_eq!( + i256::from_parts(8, 0).pow_wrapping(2_u32), + i256::from_parts(64, 0) + ); + assert_eq!(8_u8.pow_wrapping(2_u32), 64_u8); + assert_eq!(8_u16.pow_wrapping(2_u32), 64_u16); + assert_eq!(8_u32.pow_wrapping(2_u32), 64_u32); + assert_eq!(8_u64.pow_wrapping(2_u32), 64_u64); + assert_eq!(f16::from_f32(8.0).pow_wrapping(2_u32), f16::from_f32(64.0)); + assert_eq!(8.0_f32.pow_wrapping(2_u32), 64_f32); + assert_eq!(8.0_f64.pow_wrapping(2_u32), 64_f64); + + // pow_checked + assert_eq!(8_i8.pow_checked(2_u32).unwrap(), 64_i8); + assert_eq!(8_i16.pow_checked(2_u32).unwrap(), 64_i16); + assert_eq!(8_i32.pow_checked(2_u32).unwrap(), 64_i32); + 
assert_eq!(8_i64.pow_checked(2_u32).unwrap(), 64_i64); + assert_eq!(8_i128.pow_checked(2_u32).unwrap(), 64_i128); + assert_eq!( + i256::from_parts(8, 0).pow_checked(2_u32).unwrap(), + i256::from_parts(64, 0) + ); + assert_eq!(8_u8.pow_checked(2_u32).unwrap(), 64_u8); + assert_eq!(8_u16.pow_checked(2_u32).unwrap(), 64_u16); + assert_eq!(8_u32.pow_checked(2_u32).unwrap(), 64_u32); + assert_eq!(8_u64.pow_checked(2_u32).unwrap(), 64_u64); + assert_eq!( + f16::from_f32(8.0).pow_checked(2_u32).unwrap(), + f16::from_f32(64.0) + ); + assert_eq!(8.0_f32.pow_checked(2_u32).unwrap(), 64_f32); + assert_eq!(8.0_f64.pow_checked(2_u32).unwrap(), 64_f64); + } + + #[test] + fn test_float_total_order_min_max() { + assert!(::MIN_TOTAL_ORDER.is_lt(f64::NEG_INFINITY)); + assert!(::MAX_TOTAL_ORDER.is_gt(f64::INFINITY)); + + assert!(::MIN_TOTAL_ORDER.is_nan()); + assert!(::MIN_TOTAL_ORDER.is_sign_negative()); + assert!(::MIN_TOTAL_ORDER.is_lt(-f64::NAN)); + + assert!(::MAX_TOTAL_ORDER.is_nan()); + assert!(::MAX_TOTAL_ORDER.is_sign_positive()); + assert!(::MAX_TOTAL_ORDER.is_gt(f64::NAN)); + + assert!(::MIN_TOTAL_ORDER.is_lt(f32::NEG_INFINITY)); + assert!(::MAX_TOTAL_ORDER.is_gt(f32::INFINITY)); + + assert!(::MIN_TOTAL_ORDER.is_nan()); + assert!(::MIN_TOTAL_ORDER.is_sign_negative()); + assert!(::MIN_TOTAL_ORDER.is_lt(-f32::NAN)); + + assert!(::MAX_TOTAL_ORDER.is_nan()); + assert!(::MAX_TOTAL_ORDER.is_sign_positive()); + assert!(::MAX_TOTAL_ORDER.is_gt(f32::NAN)); + + assert!(::MIN_TOTAL_ORDER.is_lt(f16::NEG_INFINITY)); + assert!(::MAX_TOTAL_ORDER.is_gt(f16::INFINITY)); + + assert!(::MIN_TOTAL_ORDER.is_nan()); + assert!(::MIN_TOTAL_ORDER.is_sign_negative()); + assert!(::MIN_TOTAL_ORDER.is_lt(-f16::NAN)); + + assert!(::MAX_TOTAL_ORDER.is_nan()); + assert!(::MAX_TOTAL_ORDER.is_sign_positive()); + assert!(::MAX_TOTAL_ORDER.is_gt(f16::NAN)); + } +} diff --git a/arrow/src/array/array_binary.rs b/arrow-array/src/array/binary_array.rs similarity index 60% rename from 
arrow/src/array/array_binary.rs rename to arrow-array/src/array/binary_array.rs index 1c63e8e24b29..8f8a39b2093f 100644 --- a/arrow/src/array/array_binary.rs +++ b/arrow-array/src/array/binary_array.rs @@ -15,120 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::convert::From; -use std::fmt; -use std::{any::Any, iter::FromIterator}; - -use super::{ - array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, GenericBinaryIter, - GenericListArray, OffsetSizeTrait, -}; -use crate::array::array::ArrayAccessor; -use crate::buffer::Buffer; -use crate::util::bit_util; -use crate::{buffer::MutableBuffer, datatypes::DataType}; - -/// See [`BinaryArray`] and [`LargeBinaryArray`] for storing -/// binary data. -pub struct GenericBinaryArray { - data: ArrayData, - value_offsets: RawPtrBox, - value_data: RawPtrBox, -} +use crate::types::{ByteArrayType, GenericBinaryType}; +use crate::{Array, GenericByteArray, GenericListArray, GenericStringArray, OffsetSizeTrait}; +use arrow_data::ArrayData; +use arrow_schema::DataType; -impl GenericBinaryArray { - /// Data type of the array. - pub const DATA_TYPE: DataType = if OffsetSize::IS_LARGE { - DataType::LargeBinary - } else { - DataType::Binary - }; +/// A [`GenericBinaryArray`] for storing `[u8]` +pub type GenericBinaryArray = GenericByteArray>; +impl GenericBinaryArray { /// Get the data type of the array. #[deprecated(note = "please use `Self::DATA_TYPE` instead")] pub const fn get_data_type() -> DataType { Self::DATA_TYPE } - /// Returns the length for value at index `i`. 
- #[inline] - pub fn value_length(&self, i: usize) -> OffsetSize { - let offsets = self.value_offsets(); - offsets[i + 1] - offsets[i] - } - - /// Returns a clone of the value data buffer - pub fn value_data(&self) -> Buffer { - self.data.buffers()[1].clone() - } - - /// Returns the offset values in the offsets buffer - #[inline] - pub fn value_offsets(&self) -> &[OffsetSize] { - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the ArrayData instance. - unsafe { - std::slice::from_raw_parts( - self.value_offsets.as_ptr().add(self.data.offset()), - self.len() + 1, - ) - } - } - - /// Returns the element at index `i` as bytes slice - /// # Safety - /// Caller is responsible for ensuring that the index is within the bounds of the array - pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { - let end = *self.value_offsets().get_unchecked(i + 1); - let start = *self.value_offsets().get_unchecked(i); - - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the value_offset invariants - - // Safety of `to_isize().unwrap()` - // `start` and `end` are &OffsetSize, which is a generic type that implements the - // OffsetSizeTrait. Currently, only i32 and i64 implement OffsetSizeTrait, - // both of which should cleanly cast to isize on an architecture that supports - // 32/64-bit offsets - std::slice::from_raw_parts( - self.value_data.as_ptr().offset(start.to_isize().unwrap()), - (end - start).to_usize().unwrap(), - ) - } - - /// Returns the element at index `i` as bytes slice - /// # Panics - /// Panics if index `i` is out of bounds. 
- pub fn value(&self, i: usize) -> &[u8] { - assert!( - i < self.data.len(), - "Trying to access an element at index {} from a BinaryArray of length {}", - i, - self.len() - ); - //Soundness: length checked above, offset buffer length is 1 larger than logical array length - let end = unsafe { self.value_offsets().get_unchecked(i + 1) }; - let start = unsafe { self.value_offsets().get_unchecked(i) }; - - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the value_offset invariants - - // Safety of `to_isize().unwrap()` - // `start` and `end` are &OffsetSize, which is a generic type that implements the - // OffsetSizeTrait. Currently, only i32 and i64 implement OffsetSizeTrait, - // both of which should cleanly cast to isize on an architecture that supports - // 32/64-bit offsets - unsafe { - std::slice::from_raw_parts( - self.value_data.as_ptr().offset(start.to_isize().unwrap()), - (*end - *start).to_usize().unwrap(), - ) - } - } - /// Creates a [GenericBinaryArray] from a vector of byte slices /// /// See also [`Self::from_iter_values`] @@ -142,13 +43,14 @@ impl GenericBinaryArray { } fn from_list(v: GenericListArray) -> Self { + let v = v.into_data(); assert_eq!( - v.data_ref().child_data().len(), + v.child_data().len(), 1, "BinaryArray can only be created from list array of u8 values \ (i.e. List>)." 
); - let child_data = &v.data_ref().child_data()[0]; + let child_data = &v.child_data()[0]; assert_eq!( child_data.child_data().len(), @@ -170,55 +72,19 @@ impl GenericBinaryArray { let builder = ArrayData::builder(Self::DATA_TYPE) .len(v.len()) .offset(v.offset()) - .add_buffer(v.data_ref().buffers()[0].clone()) + .add_buffer(v.buffers()[0].clone()) .add_buffer(child_data.buffers()[0].slice(child_data.offset())) - .null_bit_buffer(v.data_ref().null_buffer().cloned()); + .nulls(v.nulls().cloned()); let data = unsafe { builder.build_unchecked() }; Self::from(data) } - /// Creates a [`GenericBinaryArray`] based on an iterator of values without nulls - pub fn from_iter_values(iter: I) -> Self - where - Ptr: AsRef<[u8]>, - I: IntoIterator, - { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - - let mut offsets = - MutableBuffer::new((data_len + 1) * std::mem::size_of::()); - let mut values = MutableBuffer::new(0); - - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - for s in iter { - let s = s.as_ref(); - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - offsets.push(length_so_far); - values.extend_from_slice(s); - } - - // iterator size hint may not be correct so compute the actual number of offsets - assert!(!offsets.is_empty()); // wrote at least one - let actual_len = (offsets.len() / std::mem::size_of::()) - 1; - - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(actual_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) - } - /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` pub fn take_iter<'a>( &'a self, indexes: impl Iterator> + 'a, - ) -> impl Iterator> + 'a { + ) -> impl Iterator> { indexes.map(|opt_index| opt_index.map(|index| self.value(index))) 
} @@ -229,87 +95,12 @@ impl GenericBinaryArray { pub unsafe fn take_iter_unchecked<'a>( &'a self, indexes: impl Iterator> + 'a, - ) -> impl Iterator> + 'a { + ) -> impl Iterator> { indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) } - - /// constructs a new iterator - pub fn iter(&self) -> GenericBinaryIter<'_, OffsetSize> { - GenericBinaryIter::<'_, OffsetSize>::new(self) - } -} - -impl fmt::Debug for GenericBinaryArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = OffsetSize::PREFIX; - - write!(f, "{}BinaryArray\n[\n", prefix)?; - print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) - })?; - write!(f, "]") - } -} - -impl Array for GenericBinaryArray { - fn as_any(&self) -> &dyn Any { - self - } - - fn data(&self) -> &ArrayData { - &self.data - } - - fn into_data(self) -> ArrayData { - self.into() - } } -impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor - for &'a GenericBinaryArray -{ - type Item = &'a [u8]; - - fn value(&self, index: usize) -> Self::Item { - GenericBinaryArray::value(self, index) - } - - unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - GenericBinaryArray::value_unchecked(self, index) - } -} - -impl From for GenericBinaryArray { - fn from(data: ArrayData) -> Self { - assert_eq!( - data.data_type(), - &Self::DATA_TYPE, - "[Large]BinaryArray expects Datatype::[Large]Binary" - ); - assert_eq!( - data.buffers().len(), - 2, - "BinaryArray data should contain 2 buffers only (offsets and values)" - ); - let offsets = data.buffers()[0].as_ptr(); - let values = data.buffers()[1].as_ptr(); - Self { - data, - value_offsets: unsafe { RawPtrBox::new(offsets) }, - value_data: unsafe { RawPtrBox::new(values) }, - } - } -} - -impl From> for ArrayData { - fn from(array: GenericBinaryArray) -> Self { - array.data - } -} - -impl From>> - for GenericBinaryArray -{ +impl From>> for GenericBinaryArray { fn from(v: Vec>) -> Self { Self::from_opt_vec(v) } @@ -327,59 
+118,23 @@ impl From> for GenericBinaryArray { } } -impl FromIterator> +impl From> for GenericBinaryArray -where - Ptr: AsRef<[u8]>, { - fn from_iter>>(iter: I) -> Self { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - - let mut offsets = Vec::with_capacity(data_len + 1); - let mut values = Vec::new(); - let mut null_buf = MutableBuffer::new_null(data_len); - let mut length_so_far: OffsetSize = OffsetSize::zero(); - offsets.push(length_so_far); - - { - let null_slice = null_buf.as_slice_mut(); - - for (i, s) in iter.enumerate() { - if let Some(s) = s { - let s = s.as_ref(); - bit_util::set_bit(null_slice, i); - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - values.extend_from_slice(s); - } - // always add an element in offsets - offsets.push(length_so_far); - } - } + fn from(value: GenericStringArray) -> Self { + let builder = value + .into_data() + .into_builder() + .data_type(GenericBinaryType::::DATA_TYPE); - // calculate actual data_len, which may be different from the iterator's upper bound - let data_len = offsets.len() - 1; - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(data_len) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) - .null_bit_buffer(Some(null_buf.into())); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) + // Safety: + // A StringArray is a valid BinaryArray + Self::from(unsafe { builder.build_unchecked() }) } } -impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { - type Item = Option<&'a [u8]>; - type IntoIter = GenericBinaryIter<'a, T>; - - fn into_iter(self) -> Self::IntoIter { - GenericBinaryIter::<'a, T>::new(self) - } -} - -/// An array where each element contains 0 or more bytes. 
+/// A [`GenericBinaryArray`] of `[u8]` using `i32` offsets +/// /// The byte length of each element is represented by an i32. /// /// # Examples @@ -387,7 +142,7 @@ impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { /// Create a BinaryArray from a vector of byte slices. /// /// ``` -/// use arrow::array::{Array, BinaryArray}; +/// use arrow_array::{Array, BinaryArray}; /// let values: Vec<&[u8]> = /// vec![b"one", b"two", b"", b"three"]; /// let array = BinaryArray::from_vec(values); @@ -401,7 +156,7 @@ impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { /// Create a BinaryArray from a vector of Optional (null) byte slices. /// /// ``` -/// use arrow::array::{Array, BinaryArray}; +/// use arrow_array::{Array, BinaryArray}; /// let values: Vec> = /// vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")]; /// let array = BinaryArray::from_opt_vec(values); @@ -417,17 +172,17 @@ impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { /// assert!(!array.is_null(4)); /// ``` /// +/// See [`GenericByteArray`] for more information and examples pub type BinaryArray = GenericBinaryArray; -/// An array where each element contains 0 or more bytes. -/// The byte length of each element is represented by an i64. +/// A [`GenericBinaryArray`] of `[u8]` using `i64` offsets /// /// # Examples /// /// Create a LargeBinaryArray from a vector of byte slices. /// /// ``` -/// use arrow::array::{Array, LargeBinaryArray}; +/// use arrow_array::{Array, LargeBinaryArray}; /// let values: Vec<&[u8]> = /// vec![b"one", b"two", b"", b"three"]; /// let array = LargeBinaryArray::from_vec(values); @@ -441,7 +196,7 @@ pub type BinaryArray = GenericBinaryArray; /// Create a LargeBinaryArray from a vector of Optional (null) byte slices. 
/// /// ``` -/// use arrow::array::{Array, LargeBinaryArray}; +/// use arrow_array::{Array, LargeBinaryArray}; /// let values: Vec> = /// vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")]; /// let array = LargeBinaryArray::from_opt_vec(values); @@ -457,12 +212,16 @@ pub type BinaryArray = GenericBinaryArray; /// assert!(!array.is_null(4)); /// ``` /// +/// See [`GenericByteArray`] for more information and examples pub type LargeBinaryArray = GenericBinaryArray; #[cfg(test)] mod tests { use super::*; - use crate::{array::ListArray, datatypes::Field}; + use crate::{ListArray, StringArray}; + use arrow_buffer::Buffer; + use arrow_schema::Field; + use std::sync::Arc; #[test] fn test_binary_array() { @@ -474,8 +233,8 @@ mod tests { // Array data: ["hello", "", "parquet"] let array_data = ArrayData::builder(DataType::Binary) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = BinaryArray::from(array_data); @@ -513,8 +272,8 @@ mod tests { let array_data = ArrayData::builder(DataType::Binary) .len(2) .offset(1) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = BinaryArray::from(array_data); @@ -538,8 +297,8 @@ mod tests { // Array data: ["hello", "", "parquet"] let array_data = ArrayData::builder(DataType::LargeBinary) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = LargeBinaryArray::from(array_data); @@ -577,8 +336,8 @@ mod tests { let array_data = ArrayData::builder(DataType::LargeBinary) .len(2) .offset(1) - 
.add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = LargeBinaryArray::from(array_data); @@ -607,28 +366,27 @@ mod tests { // Array data: ["hello", "", "parquet"] let array_data1 = ArrayData::builder(GenericBinaryArray::::DATA_TYPE) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array1 = GenericBinaryArray::::from(array_data1); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); let array_data2 = ArrayData::builder(data_type) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(child_data) .build() .unwrap(); let list_array = GenericListArray::::from(array_data2); let binary_array2 = GenericBinaryArray::::from(list_array); - assert_eq!(2, binary_array2.data().buffers().len()); - assert_eq!(0, binary_array2.data().child_data().len()); - assert_eq!(binary_array1.len(), binary_array2.len()); assert_eq!(binary_array1.null_count(), binary_array2.null_count()); assert_eq!(binary_array1.value_offsets(), binary_array2.value_offsets()); @@ -662,16 +420,18 @@ mod tests { .unwrap(); let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); - let null_buffer = Buffer::from_slice_ref(&[0b101]); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + let null_buffer = Buffer::from_slice_ref([0b101]); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + 
DataType::UInt8, + false, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) .len(2) .offset(1) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .null_bit_buffer(Some(null_buffer)) .add_child_data(child_data) .build() @@ -696,26 +456,26 @@ mod tests { _test_generic_binary_array_from_list_array_with_offset::(); } - fn _test_generic_binary_array_from_list_array_with_child_nulls_failed< - O: OffsetSizeTrait, - >() { + fn _test_generic_binary_array_from_list_array_with_child_nulls_failed() { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt8) .len(10) .add_buffer(Buffer::from(&values[..])) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b1010101010]))) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b1010101010]))) .build() .unwrap(); let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + true, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) .len(2) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(child_data) .build() .unwrap(); @@ -768,7 +528,7 @@ mod tests { .scan(0usize, |pos, i| { if *pos < 10 { *pos += 1; - Some(Some(format!("value {}", i))) + Some(Some(format!("value {i}"))) } else { // actually returns up to 10 values None @@ -787,24 +547,21 @@ mod tests { #[test] #[should_panic( - expected = "assertion failed: `(left == right)`\n left: `UInt32`,\n \ - right: `UInt8`: BinaryArray can only be created from List arrays, \ - mismatched data types." + expected = "BinaryArray can only be created from List arrays, mismatched data types." 
)] fn test_binary_array_from_incorrect_list_array() { let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt32) .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let offsets: [i32; 4] = [0, 5, 5, 12]; - let data_type = - DataType::List(Box::new(Field::new("item", DataType::UInt32, false))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::UInt32, false))); let array_data = ArrayData::builder(data_type) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(values_data) .build() .unwrap(); @@ -817,25 +574,31 @@ mod tests { expected = "Trying to access an element at index 4 from a BinaryArray of length 3" )] fn test_binary_array_get_value_index_out_of_bound() { - let values: [u8; 12] = - [104, 101, 108, 108, 111, 112, 97, 114, 113, 117, 101, 116]; + let values: [u8; 12] = [104, 101, 108, 108, 111, 112, 97, 114, 113, 117, 101, 116]; let offsets: [i32; 4] = [0, 5, 5, 12]; let array_data = ArrayData::builder(DataType::Binary) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = BinaryArray::from(array_data); binary_array.value(4); } + #[test] + #[should_panic(expected = "LargeBinaryArray expects DataType::LargeBinary")] + fn test_binary_array_validation() { + let array = BinaryArray::from_iter_values([&[1, 2]]); + let _ = LargeBinaryArray::from(array.into_data()); + } + #[test] fn test_binary_array_all_null() { let data = vec![None]; let array = BinaryArray::from(data); array - .data() + .into_data() .validate_full() .expect("All null array has valid array data"); } @@ -845,8 +608,36 @@ mod tests { let data = vec![None]; let array = LargeBinaryArray::from(data); array - 
.data() + .into_data() .validate_full() .expect("All null array has valid array data"); } + + #[test] + fn test_empty_offsets() { + let string = BinaryArray::from( + ArrayData::builder(DataType::Binary) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.value_offsets(), &[0]); + let string = LargeBinaryArray::from( + ArrayData::builder(DataType::LargeBinary) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + } + + #[test] + fn test_to_from_string() { + let s = StringArray::from_iter_values(["a", "b", "c", "d"]); + let b = BinaryArray::from(s.clone()); + let sa = StringArray::from(b); // Performs UTF-8 validation again + + assert_eq!(s, sa); + } } diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs new file mode 100644 index 000000000000..2bf8129fd007 --- /dev/null +++ b/arrow-array/src/array/boolean_array.rs @@ -0,0 +1,700 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::array::print_long_array; +use crate::builder::BooleanBuilder; +use crate::iterator::BooleanIter; +use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; +use arrow_buffer::{bit_util, BooleanBuffer, Buffer, MutableBuffer, NullBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::DataType; +use std::any::Any; +use std::sync::Arc; + +/// An array of [boolean values](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout) +/// +/// # Example: From a Vec +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// let arr: BooleanArray = vec![true, true, false].into(); +/// ``` +/// +/// # Example: From an optional Vec +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// let arr: BooleanArray = vec![Some(true), None, Some(false)].into(); +/// ``` +/// +/// # Example: From an iterator +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// let arr: BooleanArray = (0..5).map(|x| (x % 2 == 0).then(|| x % 3 == 0)).collect(); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(&values, &[Some(true), None, Some(false), None, Some(false)]) +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::Array; +/// # use arrow_array::builder::BooleanBuilder; +/// let mut builder = BooleanBuilder::new(); +/// builder.append_value(true); +/// builder.append_null(); +/// builder.append_value(false); +/// let array = builder.finish(); +/// let values: Vec<_> = array.iter().collect(); +/// assert_eq!(&values, &[Some(true), None, Some(false)]) +/// ``` +/// +#[derive(Clone)] +pub struct BooleanArray { + values: BooleanBuffer, + nulls: Option, +} + +impl std::fmt::Debug for BooleanArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "BooleanArray\n[\n")?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl BooleanArray { + /// Create a new [`BooleanArray`] 
from the provided values and nulls + /// + /// # Panics + /// + /// Panics if `values.len() != nulls.len()` + pub fn new(values: BooleanBuffer, nulls: Option) -> Self { + if let Some(n) = nulls.as_ref() { + assert_eq!(values.len(), n.len()); + } + Self { values, nulls } + } + + /// Create a new [`BooleanArray`] with length `len` consisting only of nulls + pub fn new_null(len: usize) -> Self { + Self { + values: BooleanBuffer::new_unset(len), + nulls: Some(NullBuffer::new_null(len)), + } + } + + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: bool) -> Scalar { + let values = match value { + true => BooleanBuffer::new_set(1), + false => BooleanBuffer::new_unset(1), + }; + Scalar::new(Self::new(values, None)) + } + + /// Create a new [`BooleanArray`] from a [`Buffer`] specified by `offset` and `len`, the `offset` and `len` in bits + /// Logically convert each bit in [`Buffer`] to boolean and use it to build [`BooleanArray`]. + /// using this method will make the following points self-evident: + /// * there is no `null` in the constructed [`BooleanArray`]; + /// * without considering `buffer.into()`, this method is efficient because there is no need to perform pack and unpack operations on boolean; + pub fn new_from_packed(buffer: impl Into, offset: usize, len: usize) -> Self { + BooleanBuffer::new(buffer.into(), offset, len).into() + } + + /// Create a new [`BooleanArray`] from `&[u8]` + /// This method uses `new_from_packed` and constructs a [`Buffer`] using `value`, and offset is set to 0 and len is set to `value.len() * 8` + /// using this method will make the following points self-evident: + /// * there is no `null` in the constructed [`BooleanArray`]; + /// * the length of the constructed [`BooleanArray`] is always a multiple of 8; + pub fn new_from_u8(value: &[u8]) -> Self { + BooleanBuffer::new(Buffer::from(value), 0, value.len() * 8).into() + } + + /// Returns the length of this array. 
+ pub fn len(&self) -> usize { + self.values.len() + } + + /// Returns whether this array is empty. + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + values: self.values.slice(offset, length), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + } + } + + /// Returns a new boolean array builder + pub fn builder(capacity: usize) -> BooleanBuilder { + BooleanBuilder::with_capacity(capacity) + } + + /// Returns the underlying [`BooleanBuffer`] holding all the values of this array + pub fn values(&self) -> &BooleanBuffer { + &self.values + } + + /// Returns the number of non null, true values within this array + pub fn true_count(&self) -> usize { + match self.nulls() { + Some(nulls) => { + let null_chunks = nulls.inner().bit_chunks().iter_padded(); + let value_chunks = self.values().bit_chunks().iter_padded(); + null_chunks + .zip(value_chunks) + .map(|(a, b)| (a & b).count_ones() as usize) + .sum() + } + None => self.values().count_set_bits(), + } + } + + /// Returns the number of non null, false values within this array + pub fn false_count(&self) -> usize { + self.len() - self.null_count() - self.true_count() + } + + /// Returns the boolean value at index `i`. + /// + /// # Safety + /// This doesn't check bounds, the caller must ensure that index < self.len() + pub unsafe fn value_unchecked(&self, i: usize) -> bool { + self.values.value_unchecked(i) + } + + /// Returns the boolean value at index `i`. 
+ /// # Panics + /// Panics if index `i` is out of bounds + pub fn value(&self, i: usize) -> bool { + assert!( + i < self.len(), + "Trying to access an element at index {} from a BooleanArray of length {}", + i, + self.len() + ); + // Safety: + // `i < self.len() + unsafe { self.value_unchecked(i) } + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } + + /// Create a [`BooleanArray`] by evaluating the operation for + /// each element of the provided array + /// + /// ``` + /// # use arrow_array::{BooleanArray, Int32Array}; + /// + /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// let r = BooleanArray::from_unary(&array, |x| x > 2); + /// assert_eq!(&r, &BooleanArray::from(vec![false, false, true, true, true])); + /// ``` + pub fn from_unary(left: T, mut op: F) -> Self + where + F: FnMut(T::Item) -> bool, + { + let nulls = left.logical_nulls(); + let values = BooleanBuffer::collect_bool(left.len(), |i| unsafe { + // SAFETY: i in range 0..len + op(left.value_unchecked(i)) + }); + Self::new(values, nulls) + } + + /// Create a [`BooleanArray`] by evaluating the binary operation for + /// each element of the provided arrays + /// + /// ``` + /// # use arrow_array::{BooleanArray, Int32Array}; + /// + /// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// let b = Int32Array::from(vec![1, 2, 0, 2, 5]); + /// let r = 
BooleanArray::from_binary(&a, &b, |a, b| a == b); + /// assert_eq!(&r, &BooleanArray::from(vec![true, true, false, false, true])); + /// ``` + /// + /// # Panics + /// + /// This function panics if left and right are not the same length + /// + pub fn from_binary(left: T, right: S, mut op: F) -> Self + where + F: FnMut(T::Item, S::Item) -> bool, + { + assert_eq!(left.len(), right.len()); + + let nulls = NullBuffer::union( + left.logical_nulls().as_ref(), + right.logical_nulls().as_ref(), + ); + let values = BooleanBuffer::collect_bool(left.len(), |i| unsafe { + // SAFETY: i in range 0..len + op(left.value_unchecked(i), right.value_unchecked(i)) + }); + Self::new(values, nulls) + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (BooleanBuffer, Option) { + (self.values, self.nulls) + } +} + +impl Array for BooleanArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &DataType::Boolean + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn is_empty(&self) -> bool { + self.values.is_empty() + } + + fn offset(&self) -> usize { + self.values.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.values.inner().capacity(); + if let Some(x) = &self.nulls { + sum += x.buffer().capacity() + } + sum + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } +} + +impl<'a> ArrayAccessor for &'a BooleanArray { + type Item = bool; + + fn value(&self, index: usize) -> Self::Item { + BooleanArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + BooleanArray::value_unchecked(self, 
index) + } +} + +impl From> for BooleanArray { + fn from(data: Vec) -> Self { + let mut mut_buf = MutableBuffer::new_null(data.len()); + { + let mut_slice = mut_buf.as_slice_mut(); + for (i, b) in data.iter().enumerate() { + if *b { + bit_util::set_bit(mut_slice, i); + } + } + } + let array_data = ArrayData::builder(DataType::Boolean) + .len(data.len()) + .add_buffer(mut_buf.into()); + + let array_data = unsafe { array_data.build_unchecked() }; + BooleanArray::from(array_data) + } +} + +impl From>> for BooleanArray { + fn from(data: Vec>) -> Self { + data.iter().collect() + } +} + +impl From for BooleanArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.data_type(), + &DataType::Boolean, + "BooleanArray expected ArrayData with type {} got {}", + DataType::Boolean, + data.data_type() + ); + assert_eq!( + data.buffers().len(), + 1, + "BooleanArray data should contain a single buffer only (values buffer)" + ); + let values = BooleanBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); + + Self { + values, + nulls: data.nulls().cloned(), + } + } +} + +impl From for ArrayData { + fn from(array: BooleanArray) -> Self { + let builder = ArrayDataBuilder::new(DataType::Boolean) + .len(array.values.len()) + .offset(array.values.offset()) + .nulls(array.nulls) + .buffers(vec![array.values.into_inner()]); + + unsafe { builder.build_unchecked() } + } +} + +impl<'a> IntoIterator for &'a BooleanArray { + type Item = Option; + type IntoIter = BooleanIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BooleanIter::<'a>::new(self) + } +} + +impl<'a> BooleanArray { + /// constructs a new iterator + pub fn iter(&'a self) -> BooleanIter<'a> { + BooleanIter::<'a>::new(self) + } +} + +impl>> FromIterator for BooleanArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (_, data_len) = iter.size_hint(); + let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. 
+ + let num_bytes = bit_util::ceil(data_len, 8); + let mut null_builder = MutableBuffer::from_len_zeroed(num_bytes); + let mut val_builder = MutableBuffer::from_len_zeroed(num_bytes); + + let data = val_builder.as_slice_mut(); + + let null_slice = null_builder.as_slice_mut(); + iter.enumerate().for_each(|(i, item)| { + if let Some(a) = item.borrow() { + bit_util::set_bit(null_slice, i); + if *a { + bit_util::set_bit(data, i); + } + } + }); + + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + data_len, + None, + Some(null_builder.into()), + 0, + vec![val_builder.into()], + vec![], + ) + }; + BooleanArray::from(data) + } +} + +impl From for BooleanArray { + fn from(values: BooleanBuffer) -> Self { + Self { + values, + nulls: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::Buffer; + use rand::{thread_rng, Rng}; + + #[test] + fn test_boolean_fmt_debug() { + let arr = BooleanArray::from(vec![true, false, false]); + assert_eq!( + "BooleanArray\n[\n true,\n false,\n false,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_boolean_with_null_fmt_debug() { + let mut builder = BooleanArray::builder(3); + builder.append_value(true); + builder.append_null(); + builder.append_value(false); + let arr = builder.finish(); + assert_eq!( + "BooleanArray\n[\n true,\n null,\n false,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_boolean_array_from_vec() { + let buf = Buffer::from([10_u8]); + let arr = BooleanArray::from(vec![false, true, false, true]); + assert_eq!(&buf, arr.values().inner()); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + for i in 0..4 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i == 1 || i == 3, arr.value(i), "failed at {i}") + } + } + + #[test] + fn test_boolean_array_from_vec_option() { + let buf = Buffer::from([10_u8]); + let arr = BooleanArray::from(vec![Some(false), Some(true), None, Some(true)]); + assert_eq!(&buf, 
arr.values().inner()); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + for i in 0..4 { + if i == 2 { + assert!(arr.is_null(i)); + assert!(!arr.is_valid(i)); + } else { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i == 1 || i == 3, arr.value(i), "failed at {i}") + } + } + } + + #[test] + fn test_boolean_array_from_packed() { + let v = [1_u8, 2_u8, 3_u8]; + let arr = BooleanArray::new_from_packed(v, 0, 24); + assert_eq!(24, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert!(arr.nulls.is_none()); + for i in 0..24 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!( + i == 0 || i == 9 || i == 16 || i == 17, + arr.value(i), + "failed t {i}" + ) + } + } + + #[test] + fn test_boolean_array_from_slice_u8() { + let v: Vec = vec![1, 2, 3]; + let slice = &v[..]; + let arr = BooleanArray::new_from_u8(slice); + assert_eq!(24, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert!(arr.nulls().is_none()); + for i in 0..24 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!( + i == 0 || i == 9 || i == 16 || i == 17, + arr.value(i), + "failed t {i}" + ) + } + } + + #[test] + fn test_boolean_array_from_iter() { + let v = vec![Some(false), Some(true), Some(false), Some(true)]; + let arr = v.into_iter().collect::(); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert!(arr.nulls().is_none()); + for i in 0..3 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i == 1 || i == 3, arr.value(i), "failed at {i}") + } + } + + #[test] + fn test_boolean_array_from_nullable_iter() { + let v = vec![Some(true), None, Some(false), None]; + let arr = v.into_iter().collect::(); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(2, arr.null_count()); + assert!(arr.nulls().is_some()); + + assert!(arr.is_valid(0)); + 
assert!(arr.is_null(1)); + assert!(arr.is_valid(2)); + assert!(arr.is_null(3)); + + assert!(arr.value(0)); + assert!(!arr.value(2)); + } + + #[test] + fn test_boolean_array_builder() { + // Test building a boolean array with ArrayData builder and offset + // 000011011 + let buf = Buffer::from([27_u8]); + let buf2 = buf.clone(); + let data = ArrayData::builder(DataType::Boolean) + .len(5) + .offset(2) + .add_buffer(buf) + .build() + .unwrap(); + let arr = BooleanArray::from(data); + assert_eq!(&buf2, arr.values().inner()); + assert_eq!(5, arr.len()); + assert_eq!(2, arr.offset()); + assert_eq!(0, arr.null_count()); + for i in 0..3 { + assert_eq!(i != 0, arr.value(i), "failed at {i}"); + } + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a BooleanArray of length 3" + )] + fn test_fixed_size_binary_array_get_value_index_out_of_bound() { + let v = vec![Some(true), None, Some(false)]; + let array = v.into_iter().collect::(); + + array.value(4); + } + + #[test] + #[should_panic(expected = "BooleanArray data should contain a single buffer only \ + (values buffer)")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_boolean_array_invalid_buffer_len() { + let data = unsafe { + ArrayData::builder(DataType::Boolean) + .len(5) + .build_unchecked() + }; + drop(BooleanArray::from(data)); + } + + #[test] + #[should_panic(expected = "BooleanArray expected ArrayData with type Boolean got Int32")] + fn test_from_array_data_validation() { + let _ = BooleanArray::from(ArrayData::new_empty(&DataType::Int32)); + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_true_false_count() { + let mut rng = thread_rng(); + + for _ in 0..10 { + // No nulls + let d: Vec<_> = (0..2000).map(|_| rng.gen_bool(0.5)).collect(); + let b = BooleanArray::from(d.clone()); + + let expected_true = d.iter().filter(|x| **x).count(); + 
assert_eq!(b.true_count(), expected_true); + assert_eq!(b.false_count(), d.len() - expected_true); + + // With nulls + let d: Vec<_> = (0..2000) + .map(|_| rng.gen_bool(0.5).then(|| rng.gen_bool(0.5))) + .collect(); + let b = BooleanArray::from(d.clone()); + + let expected_true = d.iter().filter(|x| matches!(x, Some(true))).count(); + assert_eq!(b.true_count(), expected_true); + + let expected_false = d.iter().filter(|x| matches!(x, Some(false))).count(); + assert_eq!(b.false_count(), expected_false); + } + } + + #[test] + fn test_into_parts() { + let boolean_array = [Some(true), None, Some(false)] + .into_iter() + .collect::(); + let (values, nulls) = boolean_array.into_parts(); + assert_eq!(values.values(), &[0b0000_0001]); + assert!(nulls.is_some()); + assert_eq!(nulls.unwrap().buffer().as_slice(), &[0b0000_0101]); + + let boolean_array = + BooleanArray::from(vec![false, false, false, false, false, false, false, true]); + let (values, nulls) = boolean_array.into_parts(); + assert_eq!(values.values(), &[0b1000_0000]); + assert!(nulls.is_none()); + } +} diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs new file mode 100644 index 000000000000..a57abc5b1e71 --- /dev/null +++ b/arrow-array/src/array/byte_array.rs @@ -0,0 +1,617 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::{get_offsets, print_long_array}; +use crate::builder::GenericByteBuilder; +use crate::iterator::ArrayIter; +use crate::types::bytes::ByteArrayNativeType; +use crate::types::ByteArrayType; +use crate::{Array, ArrayAccessor, ArrayRef, OffsetSizeTrait, Scalar}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{NullBuffer, OffsetBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType}; +use std::any::Any; +use std::sync::Arc; + +/// An array of [variable length byte arrays](https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout) +/// +/// See [`StringArray`] and [`LargeStringArray`] for storing utf8 encoded string data +/// +/// See [`BinaryArray`] and [`LargeBinaryArray`] for storing arbitrary bytes +/// +/// # Example: From a Vec +/// +/// ``` +/// # use arrow_array::{Array, GenericByteArray, types::Utf8Type}; +/// let arr: GenericByteArray = vec!["hello", "world", ""].into(); +/// assert_eq!(arr.value_data(), b"helloworld"); +/// assert_eq!(arr.value_offsets(), &[0, 5, 10, 10]); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(values, &[Some("hello"), Some("world"), Some("")]); +/// ``` +/// +/// # Example: From an optional Vec +/// +/// ``` +/// # use arrow_array::{Array, GenericByteArray, types::Utf8Type}; +/// let arr: GenericByteArray = vec![Some("hello"), Some("world"), Some(""), None].into(); +/// assert_eq!(arr.value_data(), b"helloworld"); +/// assert_eq!(arr.value_offsets(), &[0, 5, 10, 10, 10]); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(values, &[Some("hello"), Some("world"), Some(""), None]); +/// ``` +/// +/// # Example: From an iterator of option +/// +/// ``` +/// # use arrow_array::{Array, GenericByteArray, types::Utf8Type}; +/// let arr: GenericByteArray = (0..5).map(|x| (x % 2 == 0).then(|| 
x.to_string())).collect(); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(values, &[Some("0"), None, Some("2"), None, Some("4")]); +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::Array; +/// # use arrow_array::builder::GenericByteBuilder; +/// # use arrow_array::types::Utf8Type; +/// let mut builder = GenericByteBuilder::::new(); +/// builder.append_value("hello"); +/// builder.append_null(); +/// builder.append_value("world"); +/// let array = builder.finish(); +/// let values: Vec<_> = array.iter().collect(); +/// assert_eq!(values, &[Some("hello"), None, Some("world")]); +/// ``` +/// +/// [`StringArray`]: crate::StringArray +/// [`LargeStringArray`]: crate::LargeStringArray +/// [`BinaryArray`]: crate::BinaryArray +/// [`LargeBinaryArray`]: crate::LargeBinaryArray +pub struct GenericByteArray { + data_type: DataType, + value_offsets: OffsetBuffer, + value_data: Buffer, + nulls: Option, +} + +impl Clone for GenericByteArray { + fn clone(&self) -> Self { + Self { + data_type: T::DATA_TYPE, + value_offsets: self.value_offsets.clone(), + value_data: self.value_data.clone(), + nulls: self.nulls.clone(), + } + } +} + +impl GenericByteArray { + /// Data type of the array. 
+ pub const DATA_TYPE: DataType = T::DATA_TYPE; + + /// Create a new [`GenericByteArray`] from the provided parts, panicking on failure + /// + /// # Panics + /// + /// Panics if [`GenericByteArray::try_new`] returns an error + pub fn new( + offsets: OffsetBuffer, + values: Buffer, + nulls: Option, + ) -> Self { + Self::try_new(offsets, values, nulls).unwrap() + } + + /// Create a new [`GenericByteArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// * `offsets.len() - 1 != nulls.len()` + /// * Any consecutive pair of `offsets` does not denote a valid slice of `values` + pub fn try_new( + offsets: OffsetBuffer, + values: Buffer, + nulls: Option, + ) -> Result { + let len = offsets.len() - 1; + + // Verify that each pair of offsets is a valid slices of values + T::validate(&offsets, &values)?; + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for {}{}Array, expected {len} got {}", + T::Offset::PREFIX, + T::PREFIX, + n.len(), + ))); + } + } + + Ok(Self { + data_type: T::DATA_TYPE, + value_offsets: offsets, + value_data: values, + nulls, + }) + } + + /// Create a new [`GenericByteArray`] from the provided parts, without validation + /// + /// # Safety + /// + /// Safe if [`Self::try_new`] would not error + pub unsafe fn new_unchecked( + offsets: OffsetBuffer, + values: Buffer, + nulls: Option, + ) -> Self { + Self { + data_type: T::DATA_TYPE, + value_offsets: offsets, + value_data: values, + nulls, + } + } + + /// Create a new [`GenericByteArray`] of length `len` where all values are null + pub fn new_null(len: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + value_offsets: OffsetBuffer::new_zeroed(len), + value_data: MutableBuffer::new(0).into(), + nulls: Some(NullBuffer::new_null(len)), + } + } + + /// Create a new [`Scalar`] from `v` + pub fn new_scalar(value: impl AsRef) -> Scalar { + 
Scalar::new(Self::from_iter_values(std::iter::once(value))) + } + + /// Creates a [`GenericByteArray`] based on an iterator of values without nulls + pub fn from_iter_values(iter: I) -> Self + where + Ptr: AsRef, + I: IntoIterator, + { + let iter = iter.into_iter(); + let (_, data_len) = iter.size_hint(); + let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. + + let mut offsets = MutableBuffer::new((data_len + 1) * std::mem::size_of::()); + offsets.push(T::Offset::usize_as(0)); + + let mut values = MutableBuffer::new(0); + for s in iter { + let s: &[u8] = s.as_ref().as_ref(); + values.extend_from_slice(s); + offsets.push(T::Offset::usize_as(values.len())); + } + + T::Offset::from_usize(values.len()).expect("offset overflow"); + let offsets = Buffer::from(offsets); + + // Safety: valid by construction + let value_offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Self { + data_type: T::DATA_TYPE, + value_data: values.into(), + value_offsets, + nulls: None, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (OffsetBuffer, Buffer, Option) { + (self.value_offsets, self.value_data, self.nulls) + } + + /// Returns the length for value at index `i`. + /// # Panics + /// Panics if index `i` is out of bounds. 
+ #[inline] + pub fn value_length(&self, i: usize) -> T::Offset { + let offsets = self.value_offsets(); + offsets[i + 1] - offsets[i] + } + + /// Returns a reference to the offsets of this array + /// + /// Unlike [`Self::value_offsets`] this returns the [`OffsetBuffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn offsets(&self) -> &OffsetBuffer { + &self.value_offsets + } + + /// Returns the values of this array + /// + /// Unlike [`Self::value_data`] this returns the [`Buffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn values(&self) -> &Buffer { + &self.value_data + } + + /// Returns the raw value data + pub fn value_data(&self) -> &[u8] { + self.value_data.as_slice() + } + + /// Returns true if all data within this array is ASCII + pub fn is_ascii(&self) -> bool { + let offsets = self.value_offsets(); + let start = offsets.first().unwrap(); + let end = offsets.last().unwrap(); + self.value_data()[start.as_usize()..end.as_usize()].is_ascii() + } + + /// Returns the offset values in the offsets buffer + #[inline] + pub fn value_offsets(&self) -> &[T::Offset] { + &self.value_offsets + } + + /// Returns the element at index `i` + /// # Safety + /// Caller is responsible for ensuring that the index is within the bounds of the array + pub unsafe fn value_unchecked(&self, i: usize) -> &T::Native { + let end = *self.value_offsets().get_unchecked(i + 1); + let start = *self.value_offsets().get_unchecked(i); + + // Soundness + // pointer alignment & location is ensured by RawPtrBox + // buffer bounds/offset is ensured by the value_offset invariants + + // Safety of `to_isize().unwrap()` + // `start` and `end` are &OffsetSize, which is a generic type that implements the + // OffsetSizeTrait. 
Currently, only i32 and i64 implement OffsetSizeTrait, + // both of which should cleanly cast to isize on an architecture that supports + // 32/64-bit offsets + let b = std::slice::from_raw_parts( + self.value_data.as_ptr().offset(start.to_isize().unwrap()), + (end - start).to_usize().unwrap(), + ); + + // SAFETY: + // ArrayData is valid + T::Native::from_bytes_unchecked(b) + } + + /// Returns the element at index `i` + /// # Panics + /// Panics if index `i` is out of bounds. + pub fn value(&self, i: usize) -> &T::Native { + assert!( + i < self.len(), + "Trying to access an element at index {} from a {}{}Array of length {}", + i, + T::Offset::PREFIX, + T::PREFIX, + self.len() + ); + // SAFETY: + // Verified length above + unsafe { self.value_unchecked(i) } + } + + /// constructs a new iterator + pub fn iter(&self) -> ArrayIter<&Self> { + ArrayIter::new(self) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + value_offsets: self.value_offsets.slice(offset, length), + value_data: self.value_data.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + } + } + + /// Returns `GenericByteBuilder` of this byte array for mutating its values if the underlying + /// offset and data buffers are not shared by others. 
+ pub fn into_builder(self) -> Result, Self> { + let len = self.len(); + let value_len = T::Offset::as_usize(self.value_offsets()[len] - self.value_offsets()[0]); + + let data = self.into_data(); + let null_bit_buffer = data.nulls().map(|b| b.inner().sliced()); + + let element_len = std::mem::size_of::(); + let offset_buffer = data.buffers()[0] + .slice_with_length(data.offset() * element_len, (len + 1) * element_len); + + let element_len = std::mem::size_of::(); + let value_buffer = data.buffers()[1] + .slice_with_length(data.offset() * element_len, value_len * element_len); + + drop(data); + + let try_mutable_null_buffer = match null_bit_buffer { + None => Ok(None), + Some(null_buffer) => { + // Null buffer exists, tries to make it mutable + null_buffer.into_mutable().map(Some) + } + }; + + let try_mutable_buffers = match try_mutable_null_buffer { + Ok(mutable_null_buffer) => { + // Got mutable null buffer, tries to get mutable value buffer + let try_mutable_offset_buffer = offset_buffer.into_mutable(); + let try_mutable_value_buffer = value_buffer.into_mutable(); + + // try_mutable_offset_buffer.map(...).map_err(...) doesn't work as the compiler complains + // mutable_null_buffer is moved into map closure. 
+ match (try_mutable_offset_buffer, try_mutable_value_buffer) { + (Ok(mutable_offset_buffer), Ok(mutable_value_buffer)) => unsafe { + Ok(GenericByteBuilder::::new_from_buffer( + mutable_offset_buffer, + mutable_value_buffer, + mutable_null_buffer, + )) + }, + (Ok(mutable_offset_buffer), Err(value_buffer)) => Err(( + mutable_offset_buffer.into(), + value_buffer, + mutable_null_buffer.map(|b| b.into()), + )), + (Err(offset_buffer), Ok(mutable_value_buffer)) => Err(( + offset_buffer, + mutable_value_buffer.into(), + mutable_null_buffer.map(|b| b.into()), + )), + (Err(offset_buffer), Err(value_buffer)) => Err(( + offset_buffer, + value_buffer, + mutable_null_buffer.map(|b| b.into()), + )), + } + } + Err(mutable_null_buffer) => { + // Unable to get mutable null buffer + Err((offset_buffer, value_buffer, Some(mutable_null_buffer))) + } + }; + + match try_mutable_buffers { + Ok(builder) => Ok(builder), + Err((offset_buffer, value_buffer, null_bit_buffer)) => { + let builder = ArrayData::builder(T::DATA_TYPE) + .len(len) + .add_buffer(offset_buffer) + .add_buffer(value_buffer) + .null_bit_buffer(null_bit_buffer); + + let array_data = unsafe { builder.build_unchecked() }; + let array = GenericByteArray::::from(array_data); + + Err(array) + } + } + } +} + +impl std::fmt::Debug for GenericByteArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}{}Array\n[\n", T::Offset::PREFIX, T::PREFIX)?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl Array for GenericByteArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.value_offsets.len() - 1 + } + 
+ fn is_empty(&self) -> bool { + self.value_offsets.len() <= 1 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.value_offsets.inner().inner().capacity(); + sum += self.value_data.capacity(); + if let Some(x) = &self.nulls { + sum += x.buffer().capacity() + } + sum + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } +} + +impl<'a, T: ByteArrayType> ArrayAccessor for &'a GenericByteArray { + type Item = &'a T::Native; + + fn value(&self, index: usize) -> Self::Item { + GenericByteArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + GenericByteArray::value_unchecked(self, index) + } +} + +impl From for GenericByteArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.data_type(), + &Self::DATA_TYPE, + "{}{}Array expects DataType::{}", + T::Offset::PREFIX, + T::PREFIX, + Self::DATA_TYPE + ); + assert_eq!( + data.buffers().len(), + 2, + "{}{}Array data should contain 2 buffers only (offsets and values)", + T::Offset::PREFIX, + T::PREFIX, + ); + // SAFETY: + // ArrayData is valid, and verified type above + let value_offsets = unsafe { get_offsets(&data) }; + let value_data = data.buffers()[1].clone(); + Self { + value_offsets, + value_data, + data_type: T::DATA_TYPE, + nulls: data.nulls().cloned(), + } + } +} + +impl From> for ArrayData { + fn from(array: GenericByteArray) -> Self { + let len = array.len(); + + let offsets = array.value_offsets.into_inner().into_inner(); + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .buffers(vec![offsets, array.value_data]) + .nulls(array.nulls); + + unsafe { builder.build_unchecked() } + } +} + +impl<'a, T: ByteArrayType> IntoIterator for &'a GenericByteArray { + type Item = Option<&'a T::Native>; + type IntoIter = ArrayIter; + + fn into_iter(self) -> Self::IntoIter { 
+ ArrayIter::new(self) + } +} + +impl<'a, Ptr, T: ByteArrayType> FromIterator<&'a Option> for GenericByteArray +where + Ptr: AsRef + 'a, +{ + fn from_iter>>(iter: I) -> Self { + iter.into_iter() + .map(|o| o.as_ref().map(|p| p.as_ref())) + .collect() + } +} + +impl FromIterator> for GenericByteArray +where + Ptr: AsRef, +{ + fn from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let mut builder = GenericByteBuilder::with_capacity(iter.size_hint().0, 1024); + builder.extend(iter); + builder.finish() + } +} + +#[cfg(test)] +mod tests { + use crate::{BinaryArray, StringArray}; + use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer}; + + #[test] + fn try_new() { + let data = Buffer::from_slice_ref("helloworld"); + let offsets = OffsetBuffer::new(vec![0, 5, 10].into()); + StringArray::new(offsets.clone(), data.clone(), None); + + let nulls = NullBuffer::new_null(3); + let err = + StringArray::try_new(offsets.clone(), data.clone(), Some(nulls.clone())).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for StringArray, expected 2 got 3"); + + let err = BinaryArray::try_new(offsets.clone(), data.clone(), Some(nulls)).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for BinaryArray, expected 2 got 3"); + + let non_utf8_data = Buffer::from_slice_ref(b"he\xFFloworld"); + let err = StringArray::try_new(offsets.clone(), non_utf8_data.clone(), None).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2"); + + BinaryArray::new(offsets, non_utf8_data, None); + + let offsets = OffsetBuffer::new(vec![0, 5, 11].into()); + let err = StringArray::try_new(offsets.clone(), data.clone(), None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Offset of 11 exceeds length of values 10" + ); + + let err = BinaryArray::try_new(offsets.clone(), data, None).unwrap_err(); 
+ assert_eq!( + err.to_string(), + "Invalid argument error: Maximum offset of 11 is larger than values of length 10" + ); + + let non_ascii_data = Buffer::from_slice_ref("heìloworld"); + StringArray::new(offsets.clone(), non_ascii_data.clone(), None); + BinaryArray::new(offsets, non_ascii_data.clone(), None); + + let offsets = OffsetBuffer::new(vec![0, 3, 10].into()); + let err = StringArray::try_new(offsets.clone(), non_ascii_data.clone(), None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Split UTF-8 codepoint at offset 3" + ); + + BinaryArray::new(offsets, non_ascii_data, None); + } +} diff --git a/arrow-array/src/array/byte_view_array.rs b/arrow-array/src/array/byte_view_array.rs new file mode 100644 index 000000000000..c53478d8b057 --- /dev/null +++ b/arrow-array/src/array/byte_view_array.rs @@ -0,0 +1,999 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::array::print_long_array; +use crate::builder::{ArrayBuilder, GenericByteViewBuilder}; +use crate::iterator::ArrayIter; +use crate::types::bytes::ByteArrayNativeType; +use crate::types::{BinaryViewType, ByteViewType, StringViewType}; +use crate::{Array, ArrayAccessor, ArrayRef, GenericByteArray, OffsetSizeTrait, Scalar}; +use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer, ScalarBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder, ByteView}; +use arrow_schema::{ArrowError, DataType}; +use core::str; +use num::ToPrimitive; +use std::any::Any; +use std::fmt::Debug; +use std::marker::PhantomData; +use std::sync::Arc; + +use super::ByteArrayType; + +/// [Variable-size Binary View Layout]: An array of variable length bytes view arrays. +/// +/// [Variable-size Binary View Layout]: https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-view-layout +/// +/// This is different from [`GenericByteArray`] as it stores both an offset and +/// length meaning that take / filter operations can be implemented without +/// copying the underlying data. In addition, it stores an inlined prefix which +/// can be used to speed up comparisons. +/// +/// # See Also +/// +/// See [`StringViewArray`] for storing utf8 encoded string data and +/// [`BinaryViewArray`] for storing bytes. +/// +/// # Notes +/// +/// Comparing two `GenericByteViewArray` using PartialEq compares by structure, +/// not by value, as there are many different buffer layouts to represent the +/// same data (e.g. different offsets, different buffer sizes, etc). +/// +/// # Layout: "views" and buffers +/// +/// A `GenericByteViewArray` stores variable length byte strings. An array of +/// `N` elements is stored as `N` fixed length "views" and a variable number +/// of variable length "buffers".
+/// +/// Each view is a `u128` value whose layout is different depending on the +/// length of the string stored at that location: +/// +/// ```text +/// ┌──────┬────────────────────────┐ +/// │length│ string value │ +/// Strings (len <= 12) │ │ (padded with 0) │ +/// └──────┴────────────────────────┘ +/// 0 31 127 +/// +/// ┌───────┬───────┬───────┬───────┐ +/// │length │prefix │ buf │offset │ +/// Strings (len > 12) │ │ │ index │ │ +/// └───────┴───────┴───────┴───────┘ +/// 0 31 63 95 127 +/// ``` +/// +/// * Strings with length <= 12 are stored directly in the view. See +/// [`Self::inline_value`] to access the inlined prefix from a short view. +/// +/// * Strings with length > 12: The first four bytes are stored inline in the +/// view and the entire string is stored in one of the buffers. See [`ByteView`] +/// to access the fields of these views. +/// +/// Unlike [`GenericByteArray`], there are no constraints on the offsets other +/// than they must point into a valid buffer. However, they can be out of order, +/// non-contiguous and overlapping. +/// +/// For example, in the following diagram, the strings "FishWasInTownTodayYay" and +/// "CrumpleFacedFish" are both longer than 12 bytes and thus are stored in a +/// separate buffer while the string "LavaMonster" is stored inlined in the +/// view. In this case, the same bytes for "Fish" are used to store both strings.
+/// +/// [`ByteView`]: arrow_data::ByteView +/// +/// ```text +/// ┌───┐ +/// ┌──────┬──────┬──────┬──────┐ offset │...│ +/// "FishWasInTownTodayYay" │ 21 │ Fish │ 0 │ 115 │─ ─ 103 │Mr.│ +/// └──────┴──────┴──────┴──────┘ │ ┌ ─ ─ ─ ─ ▶ │Cru│ +/// ┌──────┬──────┬──────┬──────┐ │mpl│ +/// "CrumpleFacedFish" │ 16 │ Crum │ 0 │ 103 │─ ─│─ ─ ─ ┘ │eFa│ +/// └──────┴──────┴──────┴──────┘ │ced│ +/// ┌──────┬────────────────────┐ └ ─ ─ ─ ─ ─ ─ ─ ─ ▶│Fis│ +/// "LavaMonster" │ 11 │ LavaMonster\0 │ │hWa│ +/// └──────┴────────────────────┘ offset │sIn│ +/// 115 │Tow│ +/// │nTo│ +/// │day│ +/// u128 "views" │Yay│ +/// buffer 0 │...│ +/// └───┘ +/// ``` +pub struct GenericByteViewArray { + data_type: DataType, + views: ScalarBuffer, + buffers: Vec, + phantom: PhantomData, + nulls: Option, +} + +impl Clone for GenericByteViewArray { + fn clone(&self) -> Self { + Self { + data_type: T::DATA_TYPE, + views: self.views.clone(), + buffers: self.buffers.clone(), + nulls: self.nulls.clone(), + phantom: Default::default(), + } + } +} + +// PartialEq +impl PartialEq for GenericByteViewArray { + fn eq(&self, other: &Self) -> bool { + other.data_type.eq(&self.data_type) + && other.views.eq(&self.views) + && other.buffers.eq(&self.buffers) + && other.nulls.eq(&self.nulls) + } +} + +impl GenericByteViewArray { + /// Create a new [`GenericByteViewArray`] from the provided parts, panicking on failure + /// + /// # Panics + /// + /// Panics if [`GenericByteViewArray::try_new`] returns an error + pub fn new(views: ScalarBuffer, buffers: Vec, nulls: Option) -> Self { + Self::try_new(views, buffers, nulls).unwrap() + } + + /// Create a new [`GenericByteViewArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// * `views.len() != nulls.len()` + /// * [ByteViewType::validate] fails + pub fn try_new( + views: ScalarBuffer, + buffers: Vec, + nulls: Option, + ) -> Result { + T::validate(&views, &buffers)?; + + if let Some(n) = nulls.as_ref() { + if n.len() != 
views.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for {}ViewArray, expected {} got {}", + T::PREFIX, + views.len(), + n.len(), + ))); + } + } + + Ok(Self { + data_type: T::DATA_TYPE, + views, + buffers, + nulls, + phantom: Default::default(), + }) + } + + /// Create a new [`GenericByteViewArray`] from the provided parts, without validation + /// + /// # Safety + /// + /// Safe if [`Self::try_new`] would not error + pub unsafe fn new_unchecked( + views: ScalarBuffer, + buffers: Vec, + nulls: Option, + ) -> Self { + Self { + data_type: T::DATA_TYPE, + phantom: Default::default(), + views, + buffers, + nulls, + } + } + + /// Create a new [`GenericByteViewArray`] of length `len` where all values are null + pub fn new_null(len: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + views: vec![0; len].into(), + buffers: vec![], + nulls: Some(NullBuffer::new_null(len)), + phantom: Default::default(), + } + } + + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: impl AsRef) -> Scalar { + Scalar::new(Self::from_iter_values(std::iter::once(value))) + } + + /// Creates a [`GenericByteViewArray`] based on an iterator of values without nulls + pub fn from_iter_values(iter: I) -> Self + where + Ptr: AsRef, + I: IntoIterator, + { + let iter = iter.into_iter(); + let mut builder = GenericByteViewBuilder::::with_capacity(iter.size_hint().0); + for v in iter { + builder.append_value(v); + } + builder.finish() + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (ScalarBuffer, Vec, Option) { + (self.views, self.buffers, self.nulls) + } + + /// Returns the views buffer + #[inline] + pub fn views(&self) -> &ScalarBuffer { + &self.views + } + + /// Returns the buffers storing string data + #[inline] + pub fn data_buffers(&self) -> &[Buffer] { + &self.buffers + } + + /// Returns the element at index `i` + /// # Panics + /// Panics if index `i` is out of bounds. 
+ pub fn value(&self, i: usize) -> &T::Native { + assert!( + i < self.len(), + "Trying to access an element at index {} from a {}ViewArray of length {}", + i, + T::PREFIX, + self.len() + ); + + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` without bounds checking + /// + /// # Safety + /// + /// Caller is responsible for ensuring that the index is within the bounds + /// of the array + pub unsafe fn value_unchecked(&self, idx: usize) -> &T::Native { + let v = self.views.get_unchecked(idx); + let len = *v as u32; + let b = if len <= 12 { + Self::inline_value(v, len as usize) + } else { + let view = ByteView::from(*v); + let data = self.buffers.get_unchecked(view.buffer_index as usize); + let offset = view.offset as usize; + data.get_unchecked(offset..offset + len as usize) + }; + T::Native::from_bytes_unchecked(b) + } + + /// Returns the first `len` bytes the inline value of the view. + /// + /// # Safety + /// - The `view` must be a valid element from `Self::views()` that adheres to the view layout. + /// - The `len` must be the length of the inlined value. It should never be larger than 12. 
+ #[inline(always)] + pub unsafe fn inline_value(view: &u128, len: usize) -> &[u8] { + debug_assert!(len <= 12); + std::slice::from_raw_parts((view as *const u128 as *const u8).wrapping_add(4), len) + } + + /// Constructs a new iterator for iterating over the values of this array + pub fn iter(&self) -> ArrayIter<&Self> { + ArrayIter::new(self) + } + + /// Returns an iterator over the bytes of this array, including null values + pub fn bytes_iter(&self) -> impl Iterator { + self.views.iter().map(move |v| { + let len = *v as u32; + if len <= 12 { + unsafe { Self::inline_value(v, len as usize) } + } else { + let view = ByteView::from(*v); + let data = &self.buffers[view.buffer_index as usize]; + let offset = view.offset as usize; + unsafe { data.get_unchecked(offset..offset + len as usize) } + } + }) + } + + /// Returns an iterator over the first `prefix_len` bytes of each array + /// element, including null values. + /// + /// If `prefix_len` is larger than the element's length, the iterator will + /// return an empty slice (`&[]`). + pub fn prefix_bytes_iter(&self, prefix_len: usize) -> impl Iterator { + self.views().into_iter().map(move |v| { + let len = (*v as u32) as usize; + + if len < prefix_len { + return &[] as &[u8]; + } + + if prefix_len <= 4 || len <= 12 { + unsafe { StringViewArray::inline_value(v, prefix_len) } + } else { + let view = ByteView::from(*v); + let data = unsafe { + self.data_buffers() + .get_unchecked(view.buffer_index as usize) + }; + let offset = view.offset as usize; + unsafe { data.get_unchecked(offset..offset + prefix_len) } + } + }) + } + + /// Returns an iterator over the last `suffix_len` bytes of each array + /// element, including null values. + /// + /// Note that for [`StringViewArray`] the last bytes may start in the middle + /// of a UTF-8 codepoint, and thus may not be a valid `&str`. + /// + /// If `suffix_len` is larger than the element's length, the iterator will + /// return an empty slice (`&[]`). 
+ pub fn suffix_bytes_iter(&self, suffix_len: usize) -> impl Iterator { + self.views().into_iter().map(move |v| { + let len = (*v as u32) as usize; + + if len < suffix_len { + return &[] as &[u8]; + } + + if len <= 12 { + unsafe { &StringViewArray::inline_value(v, len)[len - suffix_len..] } + } else { + let view = ByteView::from(*v); + let data = unsafe { + self.data_buffers() + .get_unchecked(view.buffer_index as usize) + }; + let offset = view.offset as usize; + unsafe { data.get_unchecked(offset + len - suffix_len..offset + len) } + } + }) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + views: self.views.slice(offset, length), + buffers: self.buffers.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + phantom: Default::default(), + } + } + + /// Returns a "compacted" version of this array + /// + /// The original array will *not* be modified + /// + /// # Garbage Collection + /// + /// Before GC: + /// ```text + /// ┌──────┐ + /// │......│ + /// │......│ + /// ┌────────────────────┐ ┌ ─ ─ ─ ▶ │Data1 │ Large buffer + /// │ View 1 │─ ─ ─ ─ │......│ with data that + /// ├────────────────────┤ │......│ is not referred + /// │ View 2 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data2 │ to by View 1 or + /// └────────────────────┘ │......│ View 2 + /// │......│ + /// 2 views, refer to │......│ + /// small portions of a └──────┘ + /// large buffer + /// ``` + /// + /// After GC: + /// + /// ```text + /// ┌────────────────────┐ ┌─────┐ After gc, only + /// │ View 1 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data1│ data that is + /// ├────────────────────┤ ┌ ─ ─ ─ ▶ │Data2│ pointed to by + /// │ View 2 │─ ─ ─ ─ └─────┘ the views is + /// └────────────────────┘ left + /// + /// + /// 2 views + /// ``` + /// This method will compact the data buffers by recreating the view array and only include the data + /// that is pointed to by the views. 
+ /// + /// Note that it will copy the array regardless of whether the original array is compact. + /// Use with caution as this can be an expensive operation, only use it when you are sure that the view + /// array is significantly smaller than when it is originally created, e.g., after filtering or slicing. + /// + /// Note: this function does not attempt to canonicalize / deduplicate values. For this + /// feature see [`GenericByteViewBuilder::with_deduplicate_strings`]. + pub fn gc(&self) -> Self { + let mut builder = GenericByteViewBuilder::::with_capacity(self.len()); + + for v in self.iter() { + builder.append_option(v); + } + + builder.finish() + } + + /// Compare two [`GenericByteViewArray`] at index `left_idx` and `right_idx` + /// + /// Comparing two ByteView types are non-trivial. + /// It takes a bit of patience to understand why we don't just compare two &[u8] directly. + /// + /// ByteView types give us the following two advantages, and we need to be careful not to lose them: + /// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view. + /// Meaning that reading one array element requires only one memory access + /// (two memory access required for StringArray, one for offset buffer, the other for value buffer). + /// + /// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray, + /// thanks to the inlined 4 bytes. + /// Consider equality check: + /// If the first four bytes of the two strings are different, we can return false immediately (with just one memory access). + /// + /// If we directly compare two &[u8], we materialize the entire string (i.e., make multiple memory accesses), which might be unnecessary. + /// - Most of the time (eq, ord), we only need to look at the first 4 bytes to know the answer, + /// e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string. 
+ /// + /// # Order check flow + /// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view. + /// (2) if any of the string is larger than 12 bytes, we need to compare the full string. + /// (2.1) if the inlined 4 bytes are different, we can return the result immediately. + /// (2.2) o.w., we need to compare the full string. + /// + /// # Safety + /// The left/right_idx must within range of each array + pub unsafe fn compare_unchecked( + left: &GenericByteViewArray, + left_idx: usize, + right: &GenericByteViewArray, + right_idx: usize, + ) -> std::cmp::Ordering { + let l_view = left.views().get_unchecked(left_idx); + let l_len = *l_view as u32; + + let r_view = right.views().get_unchecked(right_idx); + let r_len = *r_view as u32; + + if l_len <= 12 && r_len <= 12 { + let l_data = unsafe { GenericByteViewArray::::inline_value(l_view, l_len as usize) }; + let r_data = unsafe { GenericByteViewArray::::inline_value(r_view, r_len as usize) }; + return l_data.cmp(r_data); + } + + // one of the string is larger than 12 bytes, + // we then try to compare the inlined data first + let l_inlined_data = unsafe { GenericByteViewArray::::inline_value(l_view, 4) }; + let r_inlined_data = unsafe { GenericByteViewArray::::inline_value(r_view, 4) }; + if r_inlined_data != l_inlined_data { + return l_inlined_data.cmp(r_inlined_data); + } + + // unfortunately, we need to compare the full data + let l_full_data: &[u8] = unsafe { left.value_unchecked(left_idx).as_ref() }; + let r_full_data: &[u8] = unsafe { right.value_unchecked(right_idx).as_ref() }; + + l_full_data.cmp(r_full_data) + } +} + +impl Debug for GenericByteViewArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}ViewArray\n[\n", T::PREFIX)?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl Array for GenericByteViewArray { + fn as_any(&self) -> &dyn Any { + self 
+ } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.views.len() + } + + fn is_empty(&self) -> bool { + self.views.is_empty() + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.buffers.iter().map(|b| b.capacity()).sum::(); + sum += self.views.inner().capacity(); + if let Some(x) = &self.nulls { + sum += x.buffer().capacity() + } + sum + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } +} + +impl<'a, T: ByteViewType + ?Sized> ArrayAccessor for &'a GenericByteViewArray { + type Item = &'a T::Native; + + fn value(&self, index: usize) -> Self::Item { + GenericByteViewArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + GenericByteViewArray::value_unchecked(self, index) + } +} + +impl<'a, T: ByteViewType + ?Sized> IntoIterator for &'a GenericByteViewArray { + type Item = Option<&'a T::Native>; + type IntoIter = ArrayIter; + + fn into_iter(self) -> Self::IntoIter { + ArrayIter::new(self) + } +} + +impl From for GenericByteViewArray { + fn from(value: ArrayData) -> Self { + let views = value.buffers()[0].clone(); + let views = ScalarBuffer::new(views, value.offset(), value.len()); + let buffers = value.buffers()[1..].to_vec(); + Self { + data_type: T::DATA_TYPE, + views, + buffers, + nulls: value.nulls().cloned(), + phantom: Default::default(), + } + } +} + +/// Convert a [`GenericByteArray`] to a [`GenericByteViewArray`] but in a smart way: +/// If the offsets are all less than u32::MAX, then we directly build the view array on top of existing buffer. 
+impl From<&GenericByteArray> for GenericByteViewArray +where + FROM: ByteArrayType, + FROM::Offset: OffsetSizeTrait + ToPrimitive, + V: ByteViewType, +{ + fn from(byte_array: &GenericByteArray) -> Self { + let offsets = byte_array.offsets(); + + let can_reuse_buffer = match offsets.last() { + Some(offset) => offset.as_usize() < u32::MAX as usize, + None => true, + }; + + if can_reuse_buffer { + let len = byte_array.len(); + let mut views_builder = GenericByteViewBuilder::::with_capacity(len); + let str_values_buf = byte_array.values().clone(); + let block = views_builder.append_block(str_values_buf); + for (i, w) in offsets.windows(2).enumerate() { + let offset = w[0].as_usize(); + let end = w[1].as_usize(); + let length = end - offset; + + if byte_array.is_null(i) { + views_builder.append_null(); + } else { + // Safety: the input was a valid array so it valid UTF8 (if string). And + // all offsets were valid + unsafe { + views_builder.append_view_unchecked(block, offset as u32, length as u32) + } + } + } + assert_eq!(views_builder.len(), len); + views_builder.finish() + } else { + // TODO: the first u32::MAX can still be reused + GenericByteViewArray::::from_iter(byte_array.iter()) + } + } +} + +impl From> for ArrayData { + fn from(mut array: GenericByteViewArray) -> Self { + let len = array.len(); + array.buffers.insert(0, array.views.into_inner()); + let builder = ArrayDataBuilder::new(T::DATA_TYPE) + .len(len) + .buffers(array.buffers) + .nulls(array.nulls); + + unsafe { builder.build_unchecked() } + } +} + +impl<'a, Ptr, T> FromIterator<&'a Option> for GenericByteViewArray +where + Ptr: AsRef + 'a, + T: ByteViewType + ?Sized, +{ + fn from_iter>>(iter: I) -> Self { + iter.into_iter() + .map(|o| o.as_ref().map(|p| p.as_ref())) + .collect() + } +} + +impl FromIterator> for GenericByteViewArray +where + Ptr: AsRef, +{ + fn from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let mut builder = GenericByteViewBuilder::::with_capacity(iter.size_hint().0); 
+ builder.extend(iter); + builder.finish() + } +} + +/// A [`GenericByteViewArray`] of `[u8]` +/// +/// # Example +/// ``` +/// use arrow_array::BinaryViewArray; +/// let array = BinaryViewArray::from_iter_values(vec![b"hello" as &[u8], b"world", b"lulu", b"large payload over 12 bytes"]); +/// assert_eq!(array.value(0), b"hello"); +/// assert_eq!(array.value(3), b"large payload over 12 bytes"); +/// ``` +pub type BinaryViewArray = GenericByteViewArray; + +impl BinaryViewArray { + /// Convert the [`BinaryViewArray`] to [`StringViewArray`] + /// If items not utf8 data, validate will fail and error returned. + pub fn to_string_view(self) -> Result { + StringViewType::validate(self.views(), self.data_buffers())?; + unsafe { Ok(self.to_string_view_unchecked()) } + } + + /// Convert the [`BinaryViewArray`] to [`StringViewArray`] + /// # Safety + /// Caller is responsible for ensuring that items in array are utf8 data. + pub unsafe fn to_string_view_unchecked(self) -> StringViewArray { + StringViewArray::new_unchecked(self.views, self.buffers, self.nulls) + } +} + +impl From> for BinaryViewArray { + fn from(v: Vec<&[u8]>) -> Self { + Self::from_iter_values(v) + } +} + +impl From>> for BinaryViewArray { + fn from(v: Vec>) -> Self { + v.into_iter().collect() + } +} + +/// A [`GenericByteViewArray`] that stores utf8 data +/// +/// # Example +/// ``` +/// use arrow_array::StringViewArray; +/// let array = StringViewArray::from_iter_values(vec!["hello", "world", "lulu", "large payload over 12 bytes"]); +/// assert_eq!(array.value(0), "hello"); +/// assert_eq!(array.value(3), "large payload over 12 bytes"); +/// ``` +pub type StringViewArray = GenericByteViewArray; + +impl StringViewArray { + /// Convert the [`StringViewArray`] to [`BinaryViewArray`] + pub fn to_binary_view(self) -> BinaryViewArray { + unsafe { BinaryViewArray::new_unchecked(self.views, self.buffers, self.nulls) } + } + + /// Returns true if all data within this array is ASCII + pub fn is_ascii(&self) -> bool { 
+ // Alternative (but incorrect): directly check the underlying buffers + // (1) Our string view might be sparse, i.e., a subset of the buffers, + // so even if the buffer is not ascii, we can still be ascii. + // (2) It is quite difficult to know the range of each buffer (unlike StringArray) + // This means that this operation is quite expensive, shall we cache the result? + // i.e. track `is_ascii` in the builder. + self.iter().all(|v| match v { + Some(v) => v.is_ascii(), + None => true, + }) + } +} + +impl From> for StringViewArray { + fn from(v: Vec<&str>) -> Self { + Self::from_iter_values(v) + } +} + +impl From>> for StringViewArray { + fn from(v: Vec>) -> Self { + v.into_iter().collect() + } +} + +impl From> for StringViewArray { + fn from(v: Vec) -> Self { + Self::from_iter_values(v) + } +} + +impl From>> for StringViewArray { + fn from(v: Vec>) -> Self { + v.into_iter().collect() + } +} + +#[cfg(test)] +mod tests { + use crate::builder::{BinaryViewBuilder, StringViewBuilder}; + use crate::{Array, BinaryViewArray, StringViewArray}; + use arrow_buffer::{Buffer, ScalarBuffer}; + use arrow_data::ByteView; + + #[test] + fn try_new_string() { + let array = StringViewArray::from_iter_values(vec![ + "hello", + "world", + "lulu", + "large payload over 12 bytes", + ]); + assert_eq!(array.value(0), "hello"); + assert_eq!(array.value(3), "large payload over 12 bytes"); + } + + #[test] + fn try_new_binary() { + let array = BinaryViewArray::from_iter_values(vec![ + b"hello".as_slice(), + b"world".as_slice(), + b"lulu".as_slice(), + b"large payload over 12 bytes".as_slice(), + ]); + assert_eq!(array.value(0), b"hello"); + assert_eq!(array.value(3), b"large payload over 12 bytes"); + } + + #[test] + fn try_new_empty_string() { + // test empty array + let array = { + let mut builder = StringViewBuilder::new(); + builder.finish() + }; + assert!(array.is_empty()); + } + + #[test] + fn try_new_empty_binary() { + // test empty array + let array = { + let mut builder = 
BinaryViewBuilder::new(); + builder.finish() + }; + assert!(array.is_empty()); + } + + #[test] + fn test_append_string() { + // test builder append + let array = { + let mut builder = StringViewBuilder::new(); + builder.append_value("hello"); + builder.append_null(); + builder.append_option(Some("large payload over 12 bytes")); + builder.finish() + }; + assert_eq!(array.value(0), "hello"); + assert!(array.is_null(1)); + assert_eq!(array.value(2), "large payload over 12 bytes"); + } + + #[test] + fn test_append_binary() { + // test builder append + let array = { + let mut builder = BinaryViewBuilder::new(); + builder.append_value(b"hello"); + builder.append_null(); + builder.append_option(Some(b"large payload over 12 bytes")); + builder.finish() + }; + assert_eq!(array.value(0), b"hello"); + assert!(array.is_null(1)); + assert_eq!(array.value(2), b"large payload over 12 bytes"); + } + + #[test] + fn test_in_progress_recreation() { + let array = { + // make a builder with small block size. 
+ let mut builder = StringViewBuilder::new().with_fixed_block_size(14); + builder.append_value("large payload over 12 bytes"); + builder.append_option(Some("another large payload over 12 bytes that double than the first one, so that we can trigger the in_progress in builder re-created")); + builder.finish() + }; + assert_eq!(array.value(0), "large payload over 12 bytes"); + assert_eq!(array.value(1), "another large payload over 12 bytes that double than the first one, so that we can trigger the in_progress in builder re-created"); + assert_eq!(2, array.buffers.len()); + } + + #[test] + #[should_panic(expected = "Invalid buffer index at 0: got index 3 but only has 1 buffers")] + fn new_with_invalid_view_data() { + let v = "large payload over 12 bytes"; + let view = ByteView { + length: 13, + prefix: u32::from_le_bytes(v.as_bytes()[0..4].try_into().unwrap()), + buffer_index: 3, + offset: 1, + }; + let views = ScalarBuffer::from(vec![view.into()]); + let buffers = vec![Buffer::from_slice_ref(v)]; + StringViewArray::new(views, buffers, None); + } + + #[test] + #[should_panic( + expected = "Encountered non-UTF-8 data at index 0: invalid utf-8 sequence of 1 bytes from index 0" + )] + fn new_with_invalid_utf8_data() { + let v: Vec = vec![0xf0, 0x80, 0x80, 0x80]; + let view = ByteView { + length: v.len() as u32, + prefix: u32::from_le_bytes(v[0..4].try_into().unwrap()), + buffer_index: 0, + offset: 0, + }; + let views = ScalarBuffer::from(vec![view.into()]); + let buffers = vec![Buffer::from_slice_ref(v)]; + StringViewArray::new(views, buffers, None); + } + + #[test] + #[should_panic(expected = "View at index 0 contained non-zero padding for string of length 1")] + fn new_with_invalid_zero_padding() { + let mut data = [0; 12]; + data[0] = b'H'; + data[11] = 1; // no zero padding + + let mut view_buffer = [0; 16]; + view_buffer[0..4].copy_from_slice(&1u32.to_le_bytes()); + view_buffer[4..].copy_from_slice(&data); + + let view = 
ByteView::from(u128::from_le_bytes(view_buffer)); + let views = ScalarBuffer::from(vec![view.into()]); + let buffers = vec![]; + StringViewArray::new(views, buffers, None); + } + + #[test] + #[should_panic(expected = "Mismatch between embedded prefix and data")] + fn test_mismatch_between_embedded_prefix_and_data() { + let input_str_1 = "Hello, Rustaceans!"; + let input_str_2 = "Hallo, Rustaceans!"; + let length = input_str_1.len() as u32; + assert!(input_str_1.len() > 12); + + let mut view_buffer = [0; 16]; + view_buffer[0..4].copy_from_slice(&length.to_le_bytes()); + view_buffer[4..8].copy_from_slice(&input_str_1.as_bytes()[0..4]); + view_buffer[8..12].copy_from_slice(&0u32.to_le_bytes()); + view_buffer[12..].copy_from_slice(&0u32.to_le_bytes()); + let view = ByteView::from(u128::from_le_bytes(view_buffer)); + let views = ScalarBuffer::from(vec![view.into()]); + let buffers = vec![Buffer::from_slice_ref(input_str_2.as_bytes())]; + + StringViewArray::new(views, buffers, None); + } + + #[test] + fn test_gc() { + let test_data = [ + Some("longer than 12 bytes"), + Some("short"), + Some("t"), + Some("longer than 12 bytes"), + None, + Some("short"), + ]; + + let array = { + let mut builder = StringViewBuilder::new().with_fixed_block_size(8); // create multiple buffers + test_data.into_iter().for_each(|v| builder.append_option(v)); + builder.finish() + }; + assert!(array.buffers.len() > 1); + + fn check_gc(to_test: &StringViewArray) { + let gc = to_test.gc(); + assert_ne!(to_test.data_buffers().len(), gc.data_buffers().len()); + + to_test.iter().zip(gc.iter()).for_each(|(a, b)| { + assert_eq!(a, b); + }); + assert_eq!(to_test.len(), gc.len()); + } + + check_gc(&array); + check_gc(&array.slice(1, 3)); + check_gc(&array.slice(2, 1)); + check_gc(&array.slice(2, 2)); + check_gc(&array.slice(3, 1)); + } + + #[test] + fn test_eq() { + let test_data = [ + Some("longer than 12 bytes"), + None, + Some("short"), + Some("again, this is longer than 12 bytes"), + ]; + + let array1 
= { + let mut builder = StringViewBuilder::new().with_fixed_block_size(8); + test_data.into_iter().for_each(|v| builder.append_option(v)); + builder.finish() + }; + let array2 = { + // create a new array with the same data but different layout + let mut builder = StringViewBuilder::new().with_fixed_block_size(100); + test_data.into_iter().for_each(|v| builder.append_option(v)); + builder.finish() + }; + assert_eq!(array1, array1.clone()); + assert_eq!(array2, array2.clone()); + assert_ne!(array1, array2); + } +} diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs new file mode 100644 index 000000000000..d6c5dd4c3e13 --- /dev/null +++ b/arrow-array/src/array/dictionary_array.rs @@ -0,0 +1,1383 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder}; +use crate::cast::AsArray; +use crate::iterator::ArrayIter; +use crate::types::*; +use crate::{ + make_array, Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, PrimitiveArray, Scalar, + StringArray, +}; +use arrow_buffer::bit_util::set_bit; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{ArrowNativeType, BooleanBuffer, BooleanBufferBuilder}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; +use std::any::Any; +use std::sync::Arc; + +/// A [`DictionaryArray`] indexed by `i8` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int8DictionaryArray, Int8Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int8DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int8DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `i16` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int16DictionaryArray, Int16Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int16DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int16Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int16DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `i32` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int32DictionaryArray, Int32Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int32DictionaryArray = vec!["a", "a", 
"b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int32Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int32DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `i64` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int64DictionaryArray, Int64Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int64DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int64Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int64DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u8` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt8DictionaryArray, UInt8Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt8DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt8Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt8DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u16` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt16DictionaryArray, UInt16Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt16DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt16Array::from(vec![0, 0, 1, 2])); +/// 
assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt16DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u32` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt32DictionaryArray, UInt32Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt32DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt32Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt32DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u64` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt64DictionaryArray, UInt64Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt64DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt64Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt64DictionaryArray = DictionaryArray; + +/// An array of [dictionary encoded values](https://arrow.apache.org/docs/format/Columnar.html#dictionary-encoded-layout) +/// +/// This is mostly used to represent strings or a limited set of primitive types as integers, +/// for example when doing NLP analysis or representing chromosomes by name. +/// +/// [`DictionaryArray`] are represented using a `keys` array and a +/// `values` array, which may be different lengths. 
The `keys` array +/// stores indexes in the `values` array which holds +/// the corresponding logical value, as shown here: +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌─────────┐ │ ┌─────────────────┐ +/// │ │ A │ │ 0 │ │ A │ values[keys[0]] +/// ├─────────────────┤ ├─────────┤ │ ├─────────────────┤ +/// │ │ D │ │ 2 │ │ B │ values[keys[1]] +/// ├─────────────────┤ ├─────────┤ │ ├─────────────────┤ +/// │ │ B │ │ 2 │ │ B │ values[keys[2]] +/// └─────────────────┘ ├─────────┤ │ ├─────────────────┤ +/// │ │ 1 │ │ D │ values[keys[3]] +/// ├─────────┤ │ ├─────────────────┤ +/// │ │ 1 │ │ D │ values[keys[4]] +/// ├─────────┤ │ ├─────────────────┤ +/// │ │ 0 │ │ A │ values[keys[5]] +/// └─────────┘ │ └─────────────────┘ +/// │ values keys +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Logical array +/// Contents +/// DictionaryArray +/// length = 6 +/// ``` +/// +/// # Example: From Nullable Data +/// +/// ``` +/// # use arrow_array::{DictionaryArray, Int8Array, types::Int8Type}; +/// let test = vec!["a", "a", "b", "c"]; +/// let array : DictionaryArray = test.iter().map(|&x| if x == "b" {None} else {Some(x)}).collect(); +/// assert_eq!(array.keys(), &Int8Array::from(vec![Some(0), Some(0), None, Some(1)])); +/// ``` +/// +/// # Example: From Non-Nullable Data +/// +/// ``` +/// # use arrow_array::{DictionaryArray, Int8Array, types::Int8Type}; +/// let test = vec!["a", "a", "b", "c"]; +/// let array : DictionaryArray = test.into_iter().collect(); +/// assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2])); +/// ``` +/// +/// # Example: From Existing Arrays +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{DictionaryArray, Int8Array, StringArray, types::Int8Type}; +/// // You can form your own DictionaryArray by providing the +/// // values (dictionary) and keys (indexes into the dictionary): +/// let values = StringArray::from_iter_values(["a", "b", "c"]); +/// let keys = Int8Array::from_iter_values([0, 0, 1, 
2]); +/// let array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); +/// let expected: DictionaryArray:: = vec!["a", "a", "b", "c"].into_iter().collect(); +/// assert_eq!(&array, &expected); +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::{Array, StringArray}; +/// # use arrow_array::builder::StringDictionaryBuilder; +/// # use arrow_array::types::Int32Type; +/// let mut builder = StringDictionaryBuilder::::new(); +/// builder.append_value("a"); +/// builder.append_null(); +/// builder.append_value("a"); +/// builder.append_value("b"); +/// let array = builder.finish(); +/// +/// let values: Vec<_> = array.downcast_dict::().unwrap().into_iter().collect(); +/// assert_eq!(&values, &[Some("a"), None, Some("a"), Some("b")]); +/// ``` +pub struct DictionaryArray { + data_type: DataType, + + /// The keys of this dictionary. These are constructed from the + /// buffer and null bitmap of `data`. Also, note that these do + /// not correspond to the true values of this array. Rather, they + /// map to the real values. + keys: PrimitiveArray, + + /// Array of dictionary values (can by any DataType). + values: ArrayRef, + + /// Values are ordered. + is_ordered: bool, +} + +impl Clone for DictionaryArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + keys: self.keys.clone(), + values: self.values.clone(), + is_ordered: self.is_ordered, + } + } +} + +impl DictionaryArray { + /// Attempt to create a new DictionaryArray with a specified keys + /// (indexes into the dictionary) and values (dictionary) + /// array. + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(keys: PrimitiveArray, values: ArrayRef) -> Self { + Self::try_new(keys, values).unwrap() + } + + /// Attempt to create a new DictionaryArray with a specified keys + /// (indexes into the dictionary) and values (dictionary) + /// array. 
+ /// + /// # Errors + /// + /// Returns an error if any `keys[i] >= values.len() || keys[i] < 0` + pub fn try_new(keys: PrimitiveArray, values: ArrayRef) -> Result { + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + let zero = K::Native::usize_as(0); + let values_len = values.len(); + + if let Some((idx, v)) = + keys.values().iter().enumerate().find(|(idx, v)| { + (v.is_lt(zero) || v.as_usize() >= values_len) && keys.is_valid(*idx) + }) + { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid dictionary key {v:?} at index {idx}, expected 0 <= key < {values_len}", + ))); + } + + Ok(Self { + data_type, + keys, + values, + is_ordered: false, + }) + } + + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: Scalar) -> Scalar { + Scalar::new(Self::new( + PrimitiveArray::new(vec![K::Native::usize_as(0)].into(), None), + Arc::new(value.into_inner()), + )) + } + + /// Create a new [`DictionaryArray`] without performing validation + /// + /// # Safety + /// + /// Safe provided [`Self::try_new`] would not return an error + pub unsafe fn new_unchecked(keys: PrimitiveArray, values: ArrayRef) -> Self { + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + Self { + data_type, + keys, + values, + is_ordered: false, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (PrimitiveArray, ArrayRef) { + (self.keys, self.values) + } + + /// Return an array view of the keys of this dictionary as a PrimitiveArray. + pub fn keys(&self) -> &PrimitiveArray { + &self.keys + } + + /// If `value` is present in `values` (aka the dictionary), + /// returns the corresponding key (index into the `values` + /// array). Otherwise returns `None`. + /// + /// Panics if `values` is not a [`StringArray`]. 
+ pub fn lookup_key(&self, value: &str) -> Option { + let rd_buf: &StringArray = self.values.as_any().downcast_ref::().unwrap(); + + (0..rd_buf.len()) + .position(|i| rd_buf.value(i) == value) + .and_then(K::Native::from_usize) + } + + /// Returns a reference to the dictionary values array + pub fn values(&self) -> &ArrayRef { + &self.values + } + + /// Returns a clone of the value type of this list. + pub fn value_type(&self) -> DataType { + self.values.data_type().clone() + } + + /// The length of the dictionary is the length of the keys array. + pub fn len(&self) -> usize { + self.keys.len() + } + + /// Whether this dictionary is empty + pub fn is_empty(&self) -> bool { + self.keys.is_empty() + } + + /// Currently exists for compatibility purposes with Arrow IPC. + pub fn is_ordered(&self) -> bool { + self.is_ordered + } + + /// Return an iterator over the keys (indexes into the dictionary) + pub fn keys_iter(&self) -> impl Iterator> + '_ { + self.keys.iter().map(|key| key.map(|k| k.as_usize())) + } + + /// Return the value of `keys` (the dictionary key) at index `i`, + /// cast to `usize`, `None` if the value at `i` is `NULL`. + pub fn key(&self, i: usize) -> Option { + self.keys.is_valid(i).then(|| self.keys.value(i).as_usize()) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. 
+ pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + keys: self.keys.slice(offset, length), + values: self.values.clone(), + is_ordered: self.is_ordered, + } + } + + /// Downcast this dictionary to a [`TypedDictionaryArray`] + /// + /// ``` + /// use arrow_array::{Array, ArrayAccessor, DictionaryArray, StringArray, types::Int32Type}; + /// + /// let orig = [Some("a"), Some("b"), None]; + /// let dictionary = DictionaryArray::::from_iter(orig); + /// let typed = dictionary.downcast_dict::().unwrap(); + /// assert_eq!(typed.value(0), "a"); + /// assert_eq!(typed.value(1), "b"); + /// assert!(typed.is_null(2)); + /// ``` + /// + pub fn downcast_dict(&self) -> Option> { + let values = self.values.as_any().downcast_ref()?; + Some(TypedDictionaryArray { + dictionary: self, + values, + }) + } + + /// Returns a new dictionary with the same keys as the current instance + /// but with a different set of dictionary values + /// + /// This can be used to perform an operation on the values of a dictionary + /// + /// # Panics + /// + /// Panics if `values` has a length less than the current values + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::builder::PrimitiveDictionaryBuilder; + /// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor}; + /// # use arrow_array::types::{Int32Type, Int8Type}; + /// + /// // Construct a Dict(Int32, Int8) + /// let mut builder = PrimitiveDictionaryBuilder::::with_capacity(2, 200); + /// for i in 0..100 { + /// builder.append(i % 2).unwrap(); + /// } + /// + /// let dictionary = builder.finish(); + /// + /// // Perform a widening cast of dictionary values + /// let typed_dictionary = dictionary.downcast_dict::().unwrap(); + /// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64); + /// + /// // Create a Dict(Int32, + /// let new = dictionary.with_values(Arc::new(values)); + /// + /// // Verify values are as expected + /// let new_typed = 
new.downcast_dict::().unwrap(); + /// for i in 0..100 { + /// assert_eq!(new_typed.value(i), (i % 2) as i64) + /// } + /// ``` + /// + pub fn with_values(&self, values: ArrayRef) -> Self { + assert!(values.len() >= self.values.len()); + let data_type = + DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); + Self { + data_type, + keys: self.keys.clone(), + values, + is_ordered: false, + } + } + + /// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating + /// its keys and values if the underlying data buffer is not shared by others. + pub fn into_primitive_dict_builder(self) -> Result, Self> + where + V: ArrowPrimitiveType, + { + if !self.value_type().is_primitive() { + return Err(self); + } + + let key_array = self.keys().clone(); + let value_array = self.values().as_primitive::().clone(); + + drop(self.keys); + drop(self.values); + + let key_builder = key_array.into_builder(); + let value_builder = value_array.into_builder(); + + match (key_builder, value_builder) { + (Ok(key_builder), Ok(value_builder)) => Ok(unsafe { + PrimitiveDictionaryBuilder::new_from_builders(key_builder, value_builder) + }), + (Err(key_array), Ok(mut value_builder)) => { + Err(Self::try_new(key_array, Arc::new(value_builder.finish())).unwrap()) + } + (Ok(mut key_builder), Err(value_array)) => { + Err(Self::try_new(key_builder.finish(), Arc::new(value_array)).unwrap()) + } + (Err(key_array), Err(value_array)) => { + Err(Self::try_new(key_array, Arc::new(value_array)).unwrap()) + } + } + } + + /// Applies an unary and infallible function to a mutable dictionary array. + /// Mutable dictionary array means that the buffers are not shared with other arrays. + /// As a result, this mutates the buffers directly without allocating new buffers. + /// + /// # Implementation + /// + /// This will apply the function for all dictionary values, including those on null slots. 
+ /// This implies that the operation must be infallible for any value of the corresponding type + /// or this function may panic. + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::{Array, ArrayAccessor, DictionaryArray, StringArray, types::{Int8Type, Int32Type}}; + /// # use arrow_array::{Int8Array, Int32Array}; + /// let values = Int32Array::from(vec![Some(10), Some(20), None]); + /// let keys = Int8Array::from_iter_values([0, 0, 1, 2]); + /// let dictionary = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + /// let c = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap(); + /// let typed = c.downcast_dict::().unwrap(); + /// assert_eq!(typed.value(0), 11); + /// assert_eq!(typed.value(1), 11); + /// assert_eq!(typed.value(2), 21); + /// ``` + pub fn unary_mut(self, op: F) -> Result, DictionaryArray> + where + V: ArrowPrimitiveType, + F: Fn(V::Native) -> V::Native, + { + let mut builder: PrimitiveDictionaryBuilder = self.into_primitive_dict_builder()?; + builder + .values_slice_mut() + .iter_mut() + .for_each(|v| *v = op(*v)); + Ok(builder.finish()) + } + + /// Computes an occupancy mask for this dictionary's values + /// + /// For each value in [`Self::values`] the corresponding bit will be set in the + /// returned mask if it is referenced by a key in this [`DictionaryArray`] + pub fn occupancy(&self) -> BooleanBuffer { + let len = self.values.len(); + let mut builder = BooleanBufferBuilder::new(len); + builder.resize(len); + let slice = builder.as_slice_mut(); + match self.keys.nulls().filter(|n| n.null_count() > 0) { + Some(n) => { + let v = self.keys.values(); + n.valid_indices() + .for_each(|idx| set_bit(slice, v[idx].as_usize())) + } + None => { + let v = self.keys.values(); + v.iter().for_each(|v| set_bit(slice, v.as_usize())) + } + } + builder.finish() + } +} + +/// Constructs a `DictionaryArray` from an array data reference. 
+impl From for DictionaryArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.buffers().len(), + 1, + "DictionaryArray data should contain a single buffer only (keys)." + ); + assert_eq!( + data.child_data().len(), + 1, + "DictionaryArray should contain a single child array (values)." + ); + + if let DataType::Dictionary(key_data_type, _) = data.data_type() { + assert_eq!( + &T::DATA_TYPE, + key_data_type.as_ref(), + "DictionaryArray's data type must match, expected {} got {}", + T::DATA_TYPE, + key_data_type + ); + + let values = make_array(data.child_data()[0].clone()); + let data_type = data.data_type().clone(); + + // create a zero-copy of the keys' data + // SAFETY: + // ArrayData is valid and verified type above + + let keys = PrimitiveArray::::from(unsafe { + data.into_builder() + .data_type(T::DATA_TYPE) + .child_data(vec![]) + .build_unchecked() + }); + + Self { + data_type, + keys, + values, + is_ordered: false, + } + } else { + panic!("DictionaryArray must have Dictionary data type.") + } + } +} + +impl From> for ArrayData { + fn from(array: DictionaryArray) -> Self { + let builder = array + .keys + .into_data() + .into_builder() + .data_type(array.data_type) + .child_data(vec![array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +/// Constructs a `DictionaryArray` from an iterator of optional strings. 
+/// +/// # Example: +/// ``` +/// use arrow_array::{DictionaryArray, PrimitiveArray, StringArray, types::Int8Type}; +/// +/// let test = vec!["a", "a", "b", "c"]; +/// let array: DictionaryArray = test +/// .iter() +/// .map(|&x| if x == "b" { None } else { Some(x) }) +/// .collect(); +/// assert_eq!( +/// "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n null,\n 1,\n] values: StringArray\n[\n \"a\",\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: ArrowDictionaryKeyType> FromIterator> for DictionaryArray { + fn from_iter>>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringDictionaryBuilder::with_capacity(lower, 256, 1024); + builder.extend(it); + builder.finish() + } +} + +/// Constructs a `DictionaryArray` from an iterator of strings. +/// +/// # Example: +/// +/// ``` +/// use arrow_array::{DictionaryArray, PrimitiveArray, StringArray, types::Int8Type}; +/// +/// let test = vec!["a", "a", "b", "c"]; +/// let array: DictionaryArray = test.into_iter().collect(); +/// assert_eq!( +/// "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n 1,\n 2,\n] values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray { + fn from_iter>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringDictionaryBuilder::with_capacity(lower, 256, 1024); + it.for_each(|i| { + builder + .append(i) + .expect("Unable to append a value to a dictionary array."); + }); + + builder.finish() + } +} + +impl Array for DictionaryArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + 
Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.keys.len() + } + + fn is_empty(&self) -> bool { + self.keys.is_empty() + } + + fn offset(&self) -> usize { + self.keys.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.keys.nulls() + } + + fn logical_nulls(&self) -> Option { + match self.values.nulls() { + None => self.nulls().cloned(), + Some(value_nulls) => { + let mut builder = BooleanBufferBuilder::new(self.len()); + match self.keys.nulls() { + Some(n) => builder.append_buffer(n.inner()), + None => builder.append_n(self.len(), true), + } + for (idx, k) in self.keys.values().iter().enumerate() { + let k = k.as_usize(); + // Check range to allow for nulls + if k < value_nulls.len() && value_nulls.is_null(k) { + builder.set_bit(idx, false); + } + } + Some(builder.finish().into()) + } + } + } + + fn is_nullable(&self) -> bool { + !self.is_empty() && (self.nulls().is_some() || self.values.is_nullable()) + } + + fn get_buffer_memory_size(&self) -> usize { + self.keys.get_buffer_memory_size() + self.values.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + + self.keys.get_buffer_memory_size() + + self.values.get_array_memory_size() + } +} + +impl std::fmt::Debug for DictionaryArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!( + f, + "DictionaryArray {{keys: {:?} values: {:?}}}", + self.keys, self.values + ) + } +} + +/// A [`DictionaryArray`] typed on its child values array +/// +/// Implements [`ArrayAccessor`] allowing fast access to its elements +/// +/// ``` +/// use arrow_array::{DictionaryArray, StringArray, types::Int32Type}; +/// +/// let orig = ["a", "b", "a", "b"]; +/// let dictionary = DictionaryArray::::from_iter(orig); +/// +/// // `TypedDictionaryArray` allows you to access the values directly +/// let typed = dictionary.downcast_dict::().unwrap(); +/// +/// for (maybe_val, orig) in typed.into_iter().zip(orig) { +/// 
assert_eq!(maybe_val.unwrap(), orig) +/// } +/// ``` +pub struct TypedDictionaryArray<'a, K: ArrowDictionaryKeyType, V> { + /// The dictionary array + dictionary: &'a DictionaryArray, + /// The values of the dictionary + values: &'a V, +} + +// Manually implement `Clone` to avoid `V: Clone` type constraint +impl<'a, K: ArrowDictionaryKeyType, V> Clone for TypedDictionaryArray<'a, K, V> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, K: ArrowDictionaryKeyType, V> Copy for TypedDictionaryArray<'a, K, V> {} + +impl<'a, K: ArrowDictionaryKeyType, V> std::fmt::Debug for TypedDictionaryArray<'a, K, V> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f, "TypedDictionaryArray({:?})", self.dictionary) + } +} + +impl<'a, K: ArrowDictionaryKeyType, V> TypedDictionaryArray<'a, K, V> { + /// Returns the keys of this [`TypedDictionaryArray`] + pub fn keys(&self) -> &'a PrimitiveArray { + self.dictionary.keys() + } + + /// Returns the values of this [`TypedDictionaryArray`] + pub fn values(&self) -> &'a V { + self.values + } +} + +impl<'a, K: ArrowDictionaryKeyType, V: Sync> Array for TypedDictionaryArray<'a, K, V> { + fn as_any(&self) -> &dyn Any { + self.dictionary + } + + fn to_data(&self) -> ArrayData { + self.dictionary.to_data() + } + + fn into_data(self) -> ArrayData { + self.dictionary.into_data() + } + + fn data_type(&self) -> &DataType { + self.dictionary.data_type() + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.dictionary.slice(offset, length)) + } + + fn len(&self) -> usize { + self.dictionary.len() + } + + fn is_empty(&self) -> bool { + self.dictionary.is_empty() + } + + fn offset(&self) -> usize { + self.dictionary.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.dictionary.nulls() + } + + fn logical_nulls(&self) -> Option { + self.dictionary.logical_nulls() + } + + fn is_nullable(&self) -> bool { + self.dictionary.is_nullable() + } + + fn get_buffer_memory_size(&self) -> 
usize { + self.dictionary.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.dictionary.get_array_memory_size() + } +} + +impl<'a, K, V> IntoIterator for TypedDictionaryArray<'a, K, V> +where + K: ArrowDictionaryKeyType, + Self: ArrayAccessor, +{ + type Item = Option<::Item>; + type IntoIter = ArrayIter; + + fn into_iter(self) -> Self::IntoIter { + ArrayIter::new(self) + } +} + +impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V> +where + K: ArrowDictionaryKeyType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = <&'a V as ArrayAccessor>::Item; + + fn value(&self, index: usize) -> Self::Item { + assert!( + index < self.len(), + "Trying to access an element at index {} from a TypedDictionaryArray of length {}", + index, + self.len() + ); + unsafe { self.value_unchecked(index) } + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + let val = self.dictionary.keys.value_unchecked(index); + let value_idx = val.as_usize(); + + // As dictionary keys are only verified for non-null indexes + // we must check the value is within bounds + match value_idx < self.values.len() { + true => self.values.value_unchecked(value_idx), + false => Default::default(), + } + } +} + +/// A [`DictionaryArray`] with the key type erased +/// +/// This can be used to efficiently implement kernels for all possible dictionary +/// keys without needing to create specialized implementations for each key type +/// +/// For example +/// +/// ``` +/// # use arrow_array::*; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::builder::PrimitiveDictionaryBuilder; +/// # use arrow_array::types::*; +/// # use arrow_schema::ArrowError; +/// # use std::sync::Arc; +/// +/// fn to_string(a: &dyn Array) -> Result { +/// if let Some(d) = a.as_any_dictionary_opt() { +/// // Recursively handle dictionary input +/// let r = to_string(d.values().as_ref())?; +/// return Ok(d.with_values(r)); 
+/// } +/// downcast_primitive_array! { +/// a => Ok(Arc::new(a.iter().map(|x| x.map(|x| format!("{x:?}"))).collect::())), +/// d => Err(ArrowError::InvalidArgumentError(format!("{d:?} not supported"))) +/// } +/// } +/// +/// let result = to_string(&Int32Array::from(vec![1, 2, 3])).unwrap(); +/// let actual = result.as_string::().iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "2", "3"]); +/// +/// let mut dict = PrimitiveDictionaryBuilder::::new(); +/// dict.extend([Some(1), Some(1), Some(2), Some(3), Some(2)]); +/// let dict = dict.finish(); +/// +/// let r = to_string(&dict).unwrap(); +/// let r = r.as_dictionary::().downcast_dict::().unwrap(); +/// assert_eq!(r.keys(), dict.keys()); // Keys are the same +/// +/// let actual = r.into_iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "1", "2", "3", "2"]); +/// ``` +/// +/// See [`AsArray::as_any_dictionary_opt`] and [`AsArray::as_any_dictionary`] +pub trait AnyDictionaryArray: Array { + /// Returns the primitive keys of this dictionary as an [`Array`] + fn keys(&self) -> &dyn Array; + + /// Returns the values of this dictionary + fn values(&self) -> &ArrayRef; + + /// Returns the keys of this dictionary as usize + /// + /// The values for nulls will be arbitrary, but are guaranteed + /// to be in the range `0..self.values.len()` + /// + /// # Panic + /// + /// Panics if `values.len() == 0` + fn normalized_keys(&self) -> Vec; + + /// Create a new [`DictionaryArray`] replacing `values` with the new values + /// + /// See [`DictionaryArray::with_values`] + fn with_values(&self, values: ArrayRef) -> ArrayRef; +} + +impl AnyDictionaryArray for DictionaryArray { + fn keys(&self) -> &dyn Array { + &self.keys + } + + fn values(&self) -> &ArrayRef { + self.values() + } + + fn normalized_keys(&self) -> Vec { + let v_len = self.values().len(); + assert_ne!(v_len, 0); + let iter = self.keys().values().iter(); + iter.map(|x| x.as_usize().min(v_len - 1)).collect() + } + + fn 
with_values(&self, values: ArrayRef) -> ArrayRef { + Arc::new(self.with_values(values)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cast::as_dictionary_array; + use crate::{Int16Array, Int32Array, Int8Array}; + use arrow_buffer::{Buffer, ToByteSlice}; + + #[test] + fn test_dictionary_array() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int8) + .len(8) + .add_buffer(Buffer::from( + [10_i8, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), + )) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + let keys = Buffer::from([2_i16, 3, 4].to_byte_slice()); + + // Construct a dictionary array from the above two + let key_type = DataType::Int16; + let value_type = DataType::Int8; + let dict_data_type = DataType::Dictionary(Box::new(key_type), Box::new(value_type)); + let dict_data = ArrayData::builder(dict_data_type.clone()) + .len(3) + .add_buffer(keys.clone()) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let dict_array = Int16DictionaryArray::from(dict_data); + + let values = dict_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int8, dict_array.value_type()); + assert_eq!(3, dict_array.len()); + + // Null count only makes sense in terms of the component arrays. 
+ assert_eq!(0, dict_array.null_count()); + assert_eq!(0, dict_array.values().null_count()); + assert_eq!(dict_array.keys(), &Int16Array::from(vec![2_i16, 3, 4])); + + // Now test with a non-zero offset + let dict_data = ArrayData::builder(dict_data_type) + .len(2) + .offset(1) + .add_buffer(keys) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let dict_array = Int16DictionaryArray::from(dict_data); + + let values = dict_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int8, dict_array.value_type()); + assert_eq!(2, dict_array.len()); + assert_eq!(dict_array.keys(), &Int16Array::from(vec![3_i16, 4])); + } + + #[test] + fn test_dictionary_array_fmt_debug() { + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(3, 2); + builder.append(12345678).unwrap(); + builder.append_null(); + builder.append(22345678).unwrap(); + let array = builder.finish(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n null,\n 1,\n] values: PrimitiveArray\n[\n 12345678,\n 22345678,\n]}\n", + format!("{array:?}") + ); + + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(20, 2); + for _ in 0..20 { + builder.append(1).unwrap(); + } + let array = builder.finish(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n] values: PrimitiveArray\n[\n 1,\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_dictionary_array_from_iter() { + let test = vec!["a", "a", "b", "c"]; + let array: DictionaryArray = test + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n null,\n 1,\n] values: StringArray\n[\n \"a\",\n \"c\",\n]}\n", + format!("{array:?}") + ); + + let array: DictionaryArray = test.into_iter().collect(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n 1,\n 2,\n] values: 
StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_dictionary_array_reverse_lookup_key() { + let test = vec!["a", "a", "b", "c"]; + let array: DictionaryArray = test.into_iter().collect(); + + assert_eq!(array.lookup_key("c"), Some(2)); + + // Direction of building a dictionary is the iterator direction + let test = vec!["t3", "t3", "t2", "t2", "t1", "t3", "t4", "t1", "t0"]; + let array: DictionaryArray = test.into_iter().collect(); + + assert_eq!(array.lookup_key("t1"), Some(2)); + assert_eq!(array.lookup_key("non-existent"), None); + } + + #[test] + fn test_dictionary_keys_as_primitive_array() { + let test = vec!["a", "b", "c", "a"]; + let array: DictionaryArray = test.into_iter().collect(); + + let keys = array.keys(); + assert_eq!(&DataType::Int8, keys.data_type()); + assert_eq!(0, keys.null_count()); + assert_eq!(&[0, 1, 2, 0], keys.values()); + } + + #[test] + fn test_dictionary_keys_as_primitive_array_with_null() { + let test = vec![Some("a"), None, Some("b"), None, None, Some("a")]; + let array: DictionaryArray = test.into_iter().collect(); + + let keys = array.keys(); + assert_eq!(&DataType::Int32, keys.data_type()); + assert_eq!(3, keys.null_count()); + + assert!(keys.is_valid(0)); + assert!(!keys.is_valid(1)); + assert!(keys.is_valid(2)); + assert!(!keys.is_valid(3)); + assert!(!keys.is_valid(4)); + assert!(keys.is_valid(5)); + + assert_eq!(0, keys.value(0)); + assert_eq!(1, keys.value(2)); + assert_eq!(0, keys.value(5)); + } + + #[test] + fn test_dictionary_all_nulls() { + let test = vec![None, None, None]; + let array: DictionaryArray = test.into_iter().collect(); + array + .into_data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + fn test_dictionary_iter() { + // Construct a value array + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int16Array::from_iter_values([2_i16, 3, 4]); + + // Construct a dictionary array 
from the above two + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + + let mut key_iter = dict_array.keys_iter(); + assert_eq!(2, key_iter.next().unwrap().unwrap()); + assert_eq!(3, key_iter.next().unwrap().unwrap()); + assert_eq!(4, key_iter.next().unwrap().unwrap()); + assert!(key_iter.next().is_none()); + + let mut iter = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(dict_array.keys_iter()); + + assert_eq!(12, iter.next().unwrap().unwrap()); + assert_eq!(13, iter.next().unwrap().unwrap()); + assert_eq!(14, iter.next().unwrap().unwrap()); + assert!(iter.next().is_none()); + } + + #[test] + fn test_dictionary_iter_with_null() { + let test = vec![Some("a"), None, Some("b"), None, None, Some("a")]; + let array: DictionaryArray = test.into_iter().collect(); + + let mut iter = array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(array.keys_iter()); + + assert_eq!("a", iter.next().unwrap().unwrap()); + assert!(iter.next().unwrap().is_none()); + assert_eq!("b", iter.next().unwrap().unwrap()); + assert!(iter.next().unwrap().is_none()); + assert!(iter.next().unwrap().is_none()); + assert_eq!("a", iter.next().unwrap().unwrap()); + assert!(iter.next().is_none()); + } + + #[test] + fn test_dictionary_key() { + let keys = Int8Array::from(vec![Some(2), None, Some(1)]); + let values = StringArray::from(vec!["foo", "bar", "baz", "blarg"]); + + let array = DictionaryArray::new(keys, Arc::new(values)); + assert_eq!(array.key(0), Some(2)); + assert_eq!(array.key(1), None); + assert_eq!(array.key(2), Some(1)); + } + + #[test] + fn test_try_new() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect(); + + let array = DictionaryArray::new(keys, Arc::new(values)); + assert_eq!(array.keys().data_type(), &DataType::Int32); + assert_eq!(array.values().data_type(), &DataType::Utf8); + + 
assert_eq!(array.null_count(), 1); + + assert!(array.keys().is_valid(0)); + assert!(array.keys().is_valid(1)); + assert!(array.keys().is_null(2)); + assert!(array.keys().is_valid(3)); + + assert_eq!(array.keys().value(0), 0); + assert_eq!(array.keys().value(1), 2); + assert_eq!(array.keys().value(3), 1); + + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 2,\n null,\n 1,\n] values: StringArray\n[\n \"foo\",\n \"bar\",\n \"baz\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + #[should_panic(expected = "Invalid dictionary key 3 at index 1, expected 0 <= key < 2")] + fn test_try_new_index_too_large() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + // dictionary only has 2 values, so offset 3 is out of bounds + let keys: Int32Array = [Some(0), Some(3)].into_iter().collect(); + DictionaryArray::new(keys, Arc::new(values)); + } + + #[test] + #[should_panic(expected = "Invalid dictionary key -100 at index 0, expected 0 <= key < 2")] + fn test_try_new_index_too_small() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + let keys: Int32Array = [Some(-100)].into_iter().collect(); + DictionaryArray::new(keys, Arc::new(values)); + } + + #[test] + #[should_panic(expected = "DictionaryArray's data type must match, expected Int64 got Int32")] + fn test_from_array_data_validation() { + let a = DictionaryArray::::from_iter(["32"]); + let _ = DictionaryArray::::from(a.into_data()); + } + + #[test] + fn test_into_primitive_dict_builder() { + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + + let boxed: ArrayRef = Arc::new(dict_array); + let col: DictionaryArray = as_dictionary_array(&boxed).clone(); + + drop(boxed); + + let mut builder = col.into_primitive_dict_builder::().unwrap(); + + let slice = builder.values_slice_mut(); + assert_eq!(slice, &[10, 12, 15]); 
+ + slice[0] = 4; + slice[1] = 2; + slice[2] = 1; + + let values = Int32Array::from_iter_values([4_i32, 2, 1]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let expected = DictionaryArray::new(keys, Arc::new(values)); + + let new_array = builder.finish(); + assert_eq!(expected, new_array); + } + + #[test] + fn test_into_primitive_dict_builder_cloned_array() { + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + + let boxed: ArrayRef = Arc::new(dict_array); + + let col: DictionaryArray = DictionaryArray::::from(boxed.to_data()); + let err = col.into_primitive_dict_builder::(); + + let returned = err.unwrap_err(); + + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let expected = DictionaryArray::new(keys, Arc::new(values)); + assert_eq!(expected, returned); + } + + #[test] + fn test_occupancy() { + let keys = Int32Array::new((100..200).collect(), None); + let values = Int32Array::from(vec![0; 1024]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + for (idx, v) in dict.occupancy().iter().enumerate() { + let expected = (100..200).contains(&idx); + assert_eq!(v, expected, "{idx}"); + } + + let keys = Int32Array::new( + (0..100).collect(), + Some((0..100).map(|x| x % 4 == 0).collect()), + ); + let values = Int32Array::from(vec![0; 1024]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + for (idx, v) in dict.occupancy().iter().enumerate() { + let expected = idx % 4 == 0 && idx < 100; + assert_eq!(v, expected, "{idx}"); + } + } + + #[test] + fn test_iterator_nulls() { + let keys = Int32Array::new( + vec![0, 700, 1, 2].into(), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + let values = Int32Array::from(vec![Some(50), None, Some(2)]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + let 
values: Vec<_> = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect(); + assert_eq!(values, &[Some(50), None, None, Some(2)]) + } + + #[test] + fn test_normalized_keys() { + let values = vec![132, 0, 1].into(); + let nulls = NullBuffer::from(vec![false, true, true]); + let keys = Int32Array::new(values, Some(nulls)); + let dictionary = DictionaryArray::new(keys, Arc::new(Int32Array::new_null(2))); + assert_eq!(&dictionary.normalized_keys(), &[1, 0, 1]) + } +} diff --git a/arrow/src/array/array_fixed_size_binary.rs b/arrow-array/src/array/fixed_size_binary_array.rs similarity index 57% rename from arrow/src/array/array_fixed_size_binary.rs rename to arrow-array/src/array/fixed_size_binary_array.rs index 22eac1435a8d..e393e2b15ae6 100644 --- a/arrow/src/array/array_fixed_size_binary.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -15,27 +15,24 @@ // specific language governing permissions and limitations // under the License. +use crate::array::print_long_array; +use crate::iterator::FixedSizeBinaryIter; +use crate::{Array, ArrayAccessor, ArrayRef, FixedSizeListArray, Scalar}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; -use std::convert::From; -use std::fmt; - -use super::{ - array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, FixedSizeListArray, -}; -use crate::array::{ArrayAccessor, FixedSizeBinaryIter}; -use crate::buffer::Buffer; -use crate::error::{ArrowError, Result}; -use crate::util::bit_util; -use crate::{buffer::MutableBuffer, datatypes::DataType}; - -/// An array where each element is a fixed-size sequence of bytes. 
+use std::sync::Arc; + +/// An array of [fixed size binary arrays](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout) /// /// # Examples /// /// Create an array from an iterable argument of byte slices. /// /// ``` -/// use arrow::array::{Array, FixedSizeBinaryArray}; +/// use arrow_array::{Array, FixedSizeBinaryArray}; /// let input_arg = vec![ vec![1, 2], vec![3, 4], vec![5, 6] ]; /// let arr = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); /// @@ -45,31 +42,109 @@ use crate::{buffer::MutableBuffer, datatypes::DataType}; /// Create an array from an iterable argument of sparse byte slices. /// Sparsity means that the input argument can contain `None` items. /// ``` -/// use arrow::array::{Array, FixedSizeBinaryArray}; +/// use arrow_array::{Array, FixedSizeBinaryArray}; /// let input_arg = vec![ None, Some(vec![7, 8]), Some(vec![9, 10]), None, Some(vec![13, 14]) ]; -/// let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); +/// let arr = FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 2).unwrap(); /// assert_eq!(5, arr.len()) /// /// ``` /// +#[derive(Clone)] pub struct FixedSizeBinaryArray { - data: ArrayData, - value_data: RawPtrBox, - length: i32, + data_type: DataType, // Must be DataType::FixedSizeBinary(value_length) + value_data: Buffer, + nulls: Option, + len: usize, + value_length: i32, } impl FixedSizeBinaryArray { + /// Create a new [`FixedSizeBinaryArray`] with `size` element size, panicking on failure + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(size: i32, values: Buffer, nulls: Option) -> Self { + Self::try_new(size, values, nulls).unwrap() + } + + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: impl AsRef<[u8]>) -> Scalar { + let v = value.as_ref(); + Scalar::new(Self::new(v.len() as _, Buffer::from(v), None)) + } + + /// Create a new [`FixedSizeBinaryArray`] from the 
provided parts, returning an error on failure + /// + /// # Errors + /// + /// * `size < 0` + /// * `values.len() / size != nulls.len()` + pub fn try_new( + size: i32, + values: Buffer, + nulls: Option, + ) -> Result { + let data_type = DataType::FixedSizeBinary(size); + let s = size.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) + })?; + + let len = values.len() / s; + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for FixedSizeBinaryArray, expected {} got {}", + len, + n.len(), + ))); + } + } + + Ok(Self { + data_type, + value_data: values, + value_length: size, + nulls, + len, + }) + } + + /// Create a new [`FixedSizeBinaryArray`] of length `len` where all values are null + /// + /// # Panics + /// + /// Panics if + /// + /// * `size < 0` + /// * `size * len` would overflow `usize` + pub fn new_null(size: i32, len: usize) -> Self { + let capacity = size.to_usize().unwrap().checked_mul(len).unwrap(); + Self { + data_type: DataType::FixedSizeBinary(size), + value_data: MutableBuffer::new(capacity).into(), + nulls: Some(NullBuffer::new_null(len)), + value_length: size, + len, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (i32, Buffer, Option) { + (self.value_length, self.value_data, self.nulls) + } + /// Returns the element at index `i` as a byte slice. /// # Panics /// Panics if index `i` is out of bounds. 
pub fn value(&self, i: usize) -> &[u8] { assert!( - i < self.data.len(), + i < self.len(), "Trying to access an element at index {} from a FixedSizeBinaryArray of length {}", i, self.len() ); - let offset = i + self.data.offset(); + let offset = i + self.offset(); unsafe { let pos = self.value_offset_at(offset); std::slice::from_raw_parts( @@ -83,7 +158,7 @@ impl FixedSizeBinaryArray { /// # Safety /// Caller is responsible for ensuring that the index is within the bounds of the array pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { - let offset = i + self.data.offset(); + let offset = i + self.offset(); let pos = self.value_offset_at(offset); std::slice::from_raw_parts( self.value_data.as_ptr().offset(pos as isize), @@ -96,7 +171,7 @@ impl FixedSizeBinaryArray { /// Note this doesn't do any bound checking, for performance reason. #[inline] pub fn value_offset(&self, i: usize) -> i32 { - self.value_offset_at(self.data.offset() + i) + self.value_offset_at(self.offset() + i) } /// Returns the length for an element. @@ -104,12 +179,39 @@ impl FixedSizeBinaryArray { /// All elements have the same length as the array is a fixed size. #[inline] pub fn value_length(&self) -> i32 { - self.length + self.value_length } - /// Returns a clone of the value data buffer - pub fn value_data(&self) -> Buffer { - self.data.buffers()[0].clone() + /// Returns the values of this array. + /// + /// Unlike [`Self::value_data`] this returns the [`Buffer`] + /// allowing for zero-copy cloning. + #[inline] + pub fn values(&self) -> &Buffer { + &self.value_data + } + + /// Returns the raw value data. + pub fn value_data(&self) -> &[u8] { + self.value_data.as_slice() + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. 
+ pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced FixedSizeBinaryArray cannot exceed the existing length" + ); + + let size = self.value_length as usize; + + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)), + value_length: self.value_length, + value_data: self.value_data.slice_with_length(offset * size, len * size), + len, + } } /// Create an array from an iterable argument of sparse byte slices. @@ -119,7 +221,7 @@ impl FixedSizeBinaryArray { /// # Examples /// /// ``` - /// use arrow::array::FixedSizeBinaryArray; + /// use arrow_array::FixedSizeBinaryArray; /// let input_arg = vec![ /// None, /// Some(vec![7, 8]), @@ -134,7 +236,10 @@ impl FixedSizeBinaryArray { /// # Errors /// /// Returns error if argument has length zero, or sizes of nested slices don't match. - pub fn try_from_sparse_iter(mut iter: T) -> Result + #[deprecated( + note = "This function will fail if the iterator produces only None values; prefer `try_from_sparse_iter_with_size`" + )] + pub fn try_from_sparse_iter(mut iter: T) -> Result where T: Iterator>, U: AsRef<[u8]>, @@ -142,10 +247,13 @@ impl FixedSizeBinaryArray { let mut len = 0; let mut size = None; let mut byte = 0; - let mut null_buf = MutableBuffer::from_len_zeroed(0); - let mut buffer = MutableBuffer::from_len_zeroed(0); + + let iter_size_hint = iter.size_hint().0; + let mut null_buf = MutableBuffer::new(bit_util::ceil(iter_size_hint, 8)); + let mut buffer = MutableBuffer::new(0); + let mut prepend = 0; - iter.try_for_each(|item| -> Result<()> { + iter.try_for_each(|item| -> Result<(), ArrowError> { // extend null bitmask by one byte per each 8 items if byte == 0 { null_buf.push(0u8); @@ -164,7 +272,12 @@ impl FixedSizeBinaryArray { ))); } } else { - size = Some(slice.len()); + let len = slice.len(); + size = Some(len); + // Now that we know how large each element is we can reserve 
+ // sufficient capacity in the underlying mutable buffer for + // the data. + buffer.reserve(iter_size_hint * len); buffer.extend_zeros(slice.len() * prepend); } bit_util::set_bit(null_buf.as_slice_mut(), len); @@ -186,19 +299,94 @@ impl FixedSizeBinaryArray { )); } - let size = size.unwrap_or(0); - let array_data = unsafe { - ArrayData::new_unchecked( - DataType::FixedSizeBinary(size as i32), - len, - None, - Some(null_buf.into()), - 0, - vec![buffer.into()], - vec![], - ) - }; - Ok(FixedSizeBinaryArray::from(array_data)) + let null_buf = BooleanBuffer::new(null_buf.into(), 0, len); + let nulls = Some(NullBuffer::new(null_buf)).filter(|n| n.null_count() > 0); + + let size = size.unwrap_or(0) as i32; + Ok(Self { + data_type: DataType::FixedSizeBinary(size), + value_data: buffer.into(), + nulls, + value_length: size, + len, + }) + } + + /// Create an array from an iterable argument of sparse byte slices. + /// Sparsity means that items returned by the iterator are optional, i.e. input argument can + /// contain `None` items. In cases where the iterator returns only `None` values, this + /// also takes a size parameter to ensure that a valid FixedSizeBinaryArray is still + /// created. + /// + /// # Examples + /// + /// ``` + /// use arrow_array::FixedSizeBinaryArray; + /// let input_arg = vec![ + /// None, + /// Some(vec![7, 8]), + /// Some(vec![9, 10]), + /// None, + /// Some(vec![13, 14]), + /// None, + /// ]; + /// let array = FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 2).unwrap(); + /// ``` + /// + /// # Errors + /// + /// Returns error if the size of any nested slice does not match `size`.
+ pub fn try_from_sparse_iter_with_size(mut iter: T, size: i32) -> Result + where + T: Iterator>, + U: AsRef<[u8]>, + { + let mut len = 0; + let mut byte = 0; + + let iter_size_hint = iter.size_hint().0; + let mut null_buf = MutableBuffer::new(bit_util::ceil(iter_size_hint, 8)); + let mut buffer = MutableBuffer::new(iter_size_hint * (size as usize)); + + iter.try_for_each(|item| -> Result<(), ArrowError> { + // extend null bitmask by one byte per each 8 items + if byte == 0 { + null_buf.push(0u8); + byte = 8; + } + byte -= 1; + + if let Some(slice) = item { + let slice = slice.as_ref(); + if size as usize != slice.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Nested array size mismatch: one is {}, and the other is {}", + size, + slice.len() + ))); + } + + bit_util::set_bit(null_buf.as_slice_mut(), len); + buffer.extend_from_slice(slice); + } else { + buffer.extend_zeros(size as usize); + } + + len += 1; + + Ok(()) + })?; + + let null_buf = BooleanBuffer::new(null_buf.into(), 0, len); + let nulls = Some(NullBuffer::new(null_buf)).filter(|n| n.null_count() > 0); + + Ok(Self { + data_type: DataType::FixedSizeBinary(size), + value_data: buffer.into(), + nulls, + len, + value_length: size, + }) } /// Create an array from an iterable argument of byte slices. @@ -206,7 +394,7 @@ impl FixedSizeBinaryArray { /// # Examples /// /// ``` - /// use arrow::array::FixedSizeBinaryArray; + /// use arrow_array::FixedSizeBinaryArray; /// let input_arg = vec![ /// vec![1, 2], /// vec![3, 4], @@ -218,15 +406,17 @@ impl FixedSizeBinaryArray { /// # Errors /// /// Returns error if argument has length zero, or sizes of nested slices don't match. 
- pub fn try_from_iter(mut iter: T) -> Result + pub fn try_from_iter(mut iter: T) -> Result where T: Iterator, U: AsRef<[u8]>, { let mut len = 0; let mut size = None; - let mut buffer = MutableBuffer::from_len_zeroed(0); - iter.try_for_each(|item| -> Result<()> { + let iter_size_hint = iter.size_hint().0; + let mut buffer = MutableBuffer::new(0); + + iter.try_for_each(|item| -> Result<(), ArrowError> { let slice = item.as_ref(); if let Some(size) = size { if size != slice.len() { @@ -237,8 +427,11 @@ impl FixedSizeBinaryArray { ))); } } else { - size = Some(slice.len()); + let len = slice.len(); + size = Some(len); + buffer.reserve(iter_size_hint * len); } + buffer.extend_from_slice(slice); len += 1; @@ -252,17 +445,19 @@ impl FixedSizeBinaryArray { )); } - let size = size.unwrap_or(0); - let array_data = ArrayData::builder(DataType::FixedSizeBinary(size as i32)) - .len(len) - .add_buffer(buffer.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Ok(FixedSizeBinaryArray::from(array_data)) + let size = size.unwrap_or(0).try_into().unwrap(); + Ok(Self { + data_type: DataType::FixedSizeBinary(size), + value_data: buffer.into(), + nulls: None, + value_length: size, + len, + }) } #[inline] fn value_offset_at(&self, i: usize) -> i32 { - self.length * i as i32 + self.value_length * i as i32 } /// constructs a new iterator @@ -278,35 +473,48 @@ impl From for FixedSizeBinaryArray { 1, "FixedSizeBinaryArray data should contain 1 buffer only (values)" ); - let value_data = data.buffers()[0].as_ptr(); - let length = match data.data_type() { + let value_length = match data.data_type() { DataType::FixedSizeBinary(len) => *len, _ => panic!("Expected data type to be FixedSizeBinary"), }; + + let size = value_length as usize; + let value_data = + data.buffers()[0].slice_with_length(data.offset() * size, data.len() * size); + Self { - data, - value_data: unsafe { RawPtrBox::new(value_data) }, - length, + data_type: data.data_type().clone(), + nulls: 
data.nulls().cloned(), + len: data.len(), + value_data, + value_length, } } } impl From for ArrayData { fn from(array: FixedSizeBinaryArray) -> Self { - array.data + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.len) + .buffers(vec![array.value_data]) + .nulls(array.nulls); + + unsafe { builder.build_unchecked() } } } /// Creates a `FixedSizeBinaryArray` from `FixedSizeList` array impl From for FixedSizeBinaryArray { fn from(v: FixedSizeListArray) -> Self { + let value_len = v.value_length(); + let v = v.into_data(); assert_eq!( - v.data_ref().child_data().len(), + v.child_data().len(), 1, "FixedSizeBinaryArray can only be created from list array of u8 values \ (i.e. FixedSizeList>)." ); - let child_data = &v.data_ref().child_data()[0]; + let child_data = &v.child_data()[0]; assert_eq!( child_data.child_data().len(), @@ -325,11 +533,11 @@ impl From for FixedSizeBinaryArray { "The child array cannot contain null values." ); - let builder = ArrayData::builder(DataType::FixedSizeBinary(v.value_length())) + let builder = ArrayData::builder(DataType::FixedSizeBinary(value_len)) .len(v.len()) .offset(v.offset()) .add_buffer(child_data.buffers()[0].slice(child_data.offset())) - .null_bit_buffer(v.data_ref().null_buffer().cloned()); + .nulls(v.nulls().cloned()); let data = unsafe { builder.build_unchecked() }; Self::from(data) @@ -338,6 +546,7 @@ impl From for FixedSizeBinaryArray { impl From>> for FixedSizeBinaryArray { fn from(v: Vec>) -> Self { + #[allow(deprecated)] Self::try_from_sparse_iter(v.into_iter()).unwrap() } } @@ -348,11 +557,17 @@ impl From> for FixedSizeBinaryArray { } } -impl fmt::Debug for FixedSizeBinaryArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl From> for FixedSizeBinaryArray { + fn from(v: Vec<&[u8; N]>) -> Self { + Self::try_from_iter(v.into_iter()).unwrap() + } +} + +impl std::fmt::Debug for FixedSizeBinaryArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, 
"FixedSizeBinaryArray<{}>\n[\n", self.value_length())?; print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) + std::fmt::Debug::fmt(&array.value(index), f) })?; write!(f, "]") } @@ -363,13 +578,49 @@ impl Array for FixedSizeBinaryArray { self } - fn data(&self) -> &ArrayData { - &self.data + fn to_data(&self) -> ArrayData { + self.clone().into() } fn into_data(self) -> ArrayData { self.into() } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.value_data.capacity(); + if let Some(n) = &self.nulls { + sum += n.buffer().capacity(); + } + sum + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } } impl<'a> ArrayAccessor for &'a FixedSizeBinaryArray { @@ -395,12 +646,8 @@ impl<'a> IntoIterator for &'a FixedSizeBinaryArray { #[cfg(test)] mod tests { - use std::sync::Arc; - - use crate::{ - datatypes::{Field, Schema}, - record_batch::RecordBatch, - }; + use crate::RecordBatch; + use arrow_schema::{Field, Schema}; use super::*; @@ -452,9 +699,9 @@ mod tests { fixed_size_binary_array.value(1) ); assert_eq!(2, fixed_size_binary_array.len()); - assert_eq!(5, fixed_size_binary_array.value_offset(0)); + assert_eq!(0, fixed_size_binary_array.value_offset(0)); assert_eq!(5, fixed_size_binary_array.value_length()); - assert_eq!(10, fixed_size_binary_array.value_offset(1)); + assert_eq!(5, fixed_size_binary_array.value_offset(1)); } #[test] @@ -463,19 +710,19 @@ mod tests { let values_data = ArrayData::builder(DataType::UInt8) .len(12) .offset(2) - .add_buffer(Buffer::from_slice_ref(&values)) + 
.add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); // [null, [10, 11, 12, 13]] let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::UInt8, false)), + Arc::new(Field::new("item", DataType::UInt8, false)), 4, )) .len(2) .offset(1) .add_child_data(values_data) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b101]))) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b101]))) .build_unchecked() }; let list_array = FixedSizeListArray::from(array_data); @@ -499,13 +746,13 @@ mod tests { let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt32) .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Binary, false)), + Arc::new(Field::new("item", DataType::Binary, false)), 4, )) .len(3) @@ -522,14 +769,14 @@ mod tests { let values = [0_u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt8) .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b101010101010]))) + .add_buffer(Buffer::from_slice_ref(values)) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b101010101010]))) .build() .unwrap(); let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::UInt8, false)), + Arc::new(Field::new("item", DataType::UInt8, false)), 4, )) .len(3) @@ -552,7 +799,7 @@ mod tests { let arr = FixedSizeBinaryArray::from(array_data); assert_eq!( "FixedSizeBinaryArray<5>\n[\n [104, 101, 108, 108, 111],\n [116, 104, 101, 114, 101],\n [97, 114, 114, 111, 119],\n]", - format!("{:?}", arr) + format!("{arr:?}") ); } @@ -569,8 +816,8 @@ mod tests { fn test_all_none_fixed_size_binary_array_from_sparse_iter() { let none_option: Option<[u8; 32]> = None; let 
input_arg = vec![none_option, none_option, none_option]; - let arr = - FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + #[allow(deprecated)] + let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); assert_eq!(0, arr.value_length()); assert_eq!(3, arr.len()) } @@ -584,9 +831,24 @@ mod tests { None, Some(vec![13, 14]), ]; + #[allow(deprecated)] + let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.iter().cloned()).unwrap(); + assert_eq!(2, arr.value_length()); + assert_eq!(5, arr.len()); + let arr = - FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 2).unwrap(); assert_eq!(2, arr.value_length()); + assert_eq!(5, arr.len()); + } + + #[test] + fn test_fixed_size_binary_array_from_sparse_iter_with_size_all_none() { + let input_arg = vec![None, None, None, None, None] as Vec>>; + + let arr = FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 16) + .unwrap(); + assert_eq!(16, arr.value_length()); assert_eq!(5, arr.len()) } @@ -651,25 +913,23 @@ mod tests { #[test] fn fixed_size_binary_array_all_null() { let data = vec![None] as Vec>; - let array = FixedSizeBinaryArray::try_from_sparse_iter(data.into_iter()).unwrap(); + let array = + FixedSizeBinaryArray::try_from_sparse_iter_with_size(data.into_iter(), 0).unwrap(); array - .data() + .into_data() .validate_full() .expect("All null array has valid array data"); } #[test] // Test for https://github.com/apache/arrow-rs/issues/1390 - #[should_panic( - expected = "column types must match schema types, expected FixedSizeBinary(2) but found FixedSizeBinary(0) at column index 0" - )] fn fixed_size_binary_array_all_null_in_batch_with_schema() { - let schema = - Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); + let schema = Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); let none_option: 
Option<[u8; 2]> = None; - let item = FixedSizeBinaryArray::try_from_sparse_iter( + let item = FixedSizeBinaryArray::try_from_sparse_iter_with_size( vec![none_option, none_option, none_option].into_iter(), + 2, ) .unwrap(); @@ -687,4 +947,31 @@ mod tests { array.value(4); } + + #[test] + fn test_constructors() { + let buffer = Buffer::from_vec(vec![0_u8; 10]); + let a = FixedSizeBinaryArray::new(2, buffer.clone(), None); + assert_eq!(a.len(), 5); + + let nulls = NullBuffer::new_null(5); + FixedSizeBinaryArray::new(2, buffer.clone(), Some(nulls)); + + let a = FixedSizeBinaryArray::new(3, buffer.clone(), None); + assert_eq!(a.len(), 3); + + let nulls = NullBuffer::new_null(3); + FixedSizeBinaryArray::new(3, buffer.clone(), Some(nulls)); + + let err = FixedSizeBinaryArray::try_new(-1, buffer.clone(), None).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Size cannot be negative, got -1" + ); + + let nulls = NullBuffer::new_null(3); + let err = FixedSizeBinaryArray::try_new(2, buffer, Some(nulls)).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeBinaryArray, expected 5 got 3"); + } } diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs new file mode 100644 index 000000000000..0d57d9a690aa --- /dev/null +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -0,0 +1,693 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::print_long_array; +use crate::builder::{FixedSizeListBuilder, PrimitiveBuilder}; +use crate::iterator::FixedSizeListIter; +use crate::{make_array, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::ArrowNativeType; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// An array of [fixed length lists], similar to JSON arrays +/// (e.g. `["A", "B"]`). +/// +/// Lists are represented using a `values` child +/// array where each list has a fixed size of `value_length`. +/// +/// Use [`FixedSizeListBuilder`] to construct a [`FixedSizeListArray`]. +/// +/// # Representation +/// +/// A [`FixedSizeListArray`] can represent a list of values of any other +/// supported Arrow type. Each element of the `FixedSizeListArray` itself is +/// a list which may contain NULL and non-null values, +/// or may itself be NULL. +/// +/// For example, this `FixedSizeListArray` stores lists of strings: +/// +/// ```text +/// ┌─────────────┐ +/// │ [A,B] │ +/// ├─────────────┤ +/// │ NULL │ +/// ├─────────────┤ +/// │ [C,NULL] │ +/// └─────────────┘ +/// ``` +/// +/// The `values` of this `FixedSizeListArray`s are stored in a child +/// [`StringArray`] where logical null values take up `values_length` slots in the array +/// as shown in the following diagram. 
The logical values +/// are shown on the left, and the actual `FixedSizeListArray` encoding on the right +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─┐ +/// ┌─────────────┐ │ ┌───┐ ┌───┐ ┌──────┐ │ +/// │ [A,B] │ │ 1 │ │ │ 1 │ │ A │ │ 0 +/// ├─────────────┤ │ ├───┤ ├───┤ ├──────┤ │ +/// │ NULL │ │ 0 │ │ │ 1 │ │ B │ │ 1 +/// ├─────────────┤ │ ├───┤ ├───┤ ├──────┤ │ +/// │ [C,NULL] │ │ 1 │ │ │ 0 │ │ ???? │ │ 2 +/// └─────────────┘ │ └───┘ ├───┤ ├──────┤ │ +/// | │ 0 │ │ ???? │ │ 3 +/// Logical Values │ Validity ├───┤ ├──────┤ │ +/// (nulls) │ │ 1 │ │ C │ │ 4 +/// │ ├───┤ ├──────┤ │ +/// │ │ 0 │ │ ???? │ │ 5 +/// │ └───┘ └──────┘ │ +/// │ Values │ +/// │ FixedSizeListArray (Array) │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─┘ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ``` +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{Array, FixedSizeListArray, Int32Array}; +/// # use arrow_data::ArrayData; +/// # use arrow_schema::{DataType, Field}; +/// # use arrow_buffer::Buffer; +/// // Construct a value array +/// let value_data = ArrayData::builder(DataType::Int32) +/// .len(9) +/// .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) +/// .build() +/// .unwrap(); +/// let list_data_type = DataType::FixedSizeList( +/// Arc::new(Field::new("item", DataType::Int32, false)), +/// 3, +/// ); +/// let list_data = ArrayData::builder(list_data_type.clone()) +/// .len(3) +/// .add_child_data(value_data.clone()) +/// .build() +/// .unwrap(); +/// let list_array = FixedSizeListArray::from(list_data); +/// let list0 = list_array.value(0); +/// let list1 = list_array.value(1); +/// let list2 = list_array.value(2); +/// +/// assert_eq!( &[0, 1, 2], list0.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[3, 4, 5], list1.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[6, 7, 8], list2.as_any().downcast_ref::().unwrap().values()); +/// ``` +/// +/// [`StringArray`]: 
crate::array::StringArray +/// [fixed size arrays](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-list-layout) +#[derive(Clone)] +pub struct FixedSizeListArray { + data_type: DataType, // Must be DataType::FixedSizeList(value_length) + values: ArrayRef, + nulls: Option, + value_length: i32, + len: usize, +} + +impl FixedSizeListArray { + /// Create a new [`FixedSizeListArray`] with `size` element size, panicking on failure + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(field: FieldRef, size: i32, values: ArrayRef, nulls: Option) -> Self { + Self::try_new(field, size, values, nulls).unwrap() + } + + /// Create a new [`FixedSizeListArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// * `size < 0` + /// * `values.len() / size != nulls.len()` + /// * `values.data_type() != field.data_type()` + /// * `!field.is_nullable() && !nulls.expand(size).contains(values.logical_nulls())` + pub fn try_new( + field: FieldRef, + size: i32, + values: ArrayRef, + nulls: Option, + ) -> Result { + let s = size.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) + })?; + + let len = match s { + 0 => nulls.as_ref().map(|x| x.len()).unwrap_or_default(), + _ => { + let len = values.len() / s.max(1); + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for FixedSizeListArray, expected {} got {}", + len, + n.len(), + ))); + } + } + len + } + }; + + if field.data_type() != values.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "FixedSizeListArray expected data type {} got {} for {:?}", + field.data_type(), + values.data_type(), + field.name() + ))); + } + + if let Some(a) = values.logical_nulls() { + let nulls_valid = field.is_nullable() + || nulls + .as_ref() + .map(|n| n.expand(size as _).contains(&a)) + 
.unwrap_or_default() + || (nulls.is_none() && a.null_count() == 0); + + if !nulls_valid { + return Err(ArrowError::InvalidArgumentError(format!( + "Found unmasked nulls for non-nullable FixedSizeListArray field {:?}", + field.name() + ))); + } + } + + let data_type = DataType::FixedSizeList(field, size); + Ok(Self { + data_type, + values, + value_length: size, + nulls, + len, + }) + } + + /// Create a new [`FixedSizeListArray`] of length `len` where all values are null + /// + /// # Panics + /// + /// Panics if + /// + /// * `size < 0` + /// * `size * len` would overflow `usize` + pub fn new_null(field: FieldRef, size: i32, len: usize) -> Self { + let capacity = size.to_usize().unwrap().checked_mul(len).unwrap(); + Self { + values: make_array(ArrayData::new_null(field.data_type(), capacity)), + data_type: DataType::FixedSizeList(field, size), + nulls: Some(NullBuffer::new_null(len)), + value_length: size, + len, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (FieldRef, i32, ArrayRef, Option) { + let f = match self.data_type { + DataType::FixedSizeList(f, _) => f, + _ => unreachable!(), + }; + (f, self.value_length, self.values, self.nulls) + } + + /// Returns a reference to the values of this list. + pub fn values(&self) -> &ArrayRef { + &self.values + } + + /// Returns a clone of the value type of this list. + pub fn value_type(&self) -> DataType { + self.values.data_type().clone() + } + + /// Returns ith value of this list array. + pub fn value(&self, i: usize) -> ArrayRef { + self.values + .slice(self.value_offset_at(i), self.value_length() as usize) + } + + /// Returns the offset for value at index `i`. + /// + /// Note this doesn't do any bound checking, for performance reason. + #[inline] + pub fn value_offset(&self, i: usize) -> i32 { + self.value_offset_at(i) as i32 + } + + /// Returns the length for an element. + /// + /// All elements have the same length as the array is a fixed size. 
+ #[inline] + pub const fn value_length(&self) -> i32 { + self.value_length + } + + #[inline] + const fn value_offset_at(&self, i: usize) -> usize { + i * self.value_length as usize + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced FixedSizeListArray cannot exceed the existing length" + ); + let size = self.value_length as usize; + + Self { + data_type: self.data_type.clone(), + values: self.values.slice(offset * size, len * size), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)), + value_length: self.value_length, + len, + } + } + + /// Creates a [`FixedSizeListArray`] from an iterator of primitive values + /// # Example + /// ``` + /// # use arrow_array::FixedSizeListArray; + /// # use arrow_array::types::Int32Type; + /// + /// let data = vec![ + /// Some(vec![Some(0), Some(1), Some(2)]), + /// None, + /// Some(vec![Some(3), None, Some(5)]), + /// Some(vec![Some(6), Some(7), Some(45)]), + /// ]; + /// let list_array = FixedSizeListArray::from_iter_primitive::(data, 3); + /// println!("{:?}", list_array); + /// ``` + pub fn from_iter_primitive(iter: I, length: i32) -> Self + where + T: ArrowPrimitiveType, + P: IntoIterator::Native>>, + I: IntoIterator>, + { + let l = length as usize; + let iter = iter.into_iter(); + let size_hint = iter.size_hint().0; + let mut builder = FixedSizeListBuilder::with_capacity( + PrimitiveBuilder::::with_capacity(size_hint * l), + length, + size_hint, + ); + + for i in iter { + match i { + Some(p) => { + for t in p { + builder.values().append_option(t); + } + builder.append(true); + } + None => { + builder.values().append_nulls(l); + builder.append(false) + } + } + } + builder.finish() + } + + /// constructs a new iterator + pub fn iter(&self) -> FixedSizeListIter<'_> { + FixedSizeListIter::new(self) + } +} + +impl From for 
FixedSizeListArray { + fn from(data: ArrayData) -> Self { + let value_length = match data.data_type() { + DataType::FixedSizeList(_, len) => *len, + _ => { + panic!("FixedSizeListArray data should contain a FixedSizeList data type") + } + }; + + let size = value_length as usize; + let values = + make_array(data.child_data()[0].slice(data.offset() * size, data.len() * size)); + Self { + data_type: data.data_type().clone(), + values, + nulls: data.nulls().cloned(), + value_length, + len: data.len(), + } + } +} + +impl From for ArrayData { + fn from(array: FixedSizeListArray) -> Self { + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.len) + .nulls(array.nulls) + .child_data(vec![array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl Array for FixedSizeListArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.values.get_buffer_memory_size(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = std::mem::size_of::() + self.values.get_array_memory_size(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } +} + +impl ArrayAccessor for FixedSizeListArray { + type Item = ArrayRef; + + fn value(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + 
FixedSizeListArray::value(self, index) + } +} + +impl std::fmt::Debug for FixedSizeListArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "FixedSizeListArray<{}>\n[\n", self.value_length())?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl<'a> ArrayAccessor for &'a FixedSizeListArray { + type Item = ArrayRef; + + fn value(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } +} + +#[cfg(test)] +mod tests { + use arrow_buffer::{bit_util, BooleanBuffer, Buffer}; + use arrow_schema::Field; + + use crate::cast::AsArray; + use crate::types::Int32Type; + use crate::{new_empty_array, Int32Array}; + + use super::*; + + #[test] + fn test_fixed_size_list_array() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8])) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + assert_eq!(value_data, list_array.values().to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + assert_eq!(0, list_array.value(0).as_primitive::().value(0)); + for i in 0..3 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + + // Now test with a non-zero offset + let list_data = ArrayData::builder(list_data_type) + .len(2) + 
.offset(1) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + assert_eq!(value_data.slice(3, 6), list_array.values().to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(2, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(3, list_array.value(0).as_primitive::().value(0)); + assert_eq!(3, list_array.value_offset(1)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + #[should_panic(expected = "assertion failed: (offset + length) <= self.len()")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_fixed_size_list_array_unequal_children() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build_unchecked() + }; + drop(FixedSizeListArray::from(list_data)); + } + + #[test] + fn test_fixed_size_list_array_slice() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Set null buts for the nested array: + // [[0, 1], null, null, [6, 7], [8, 9]] + // 01011001 00000001 + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + + // Construct a fixed size list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); + let list_data = 
ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data.clone()) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + assert_eq!(value_data, list_array.values().to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(5, list_array.len()); + assert_eq!(2, list_array.null_count()); + assert_eq!(6, list_array.value_offset(3)); + assert_eq!(2, list_array.value_length()); + + let sliced_array = list_array.slice(1, 4); + assert_eq!(4, sliced_array.len()); + assert_eq!(2, sliced_array.null_count()); + + for i in 0..sliced_array.len() { + if bit_util::get_bit(&null_bits, 1 + i) { + assert!(sliced_array.is_valid(i)); + } else { + assert!(sliced_array.is_null(i)); + } + } + + // Check offset and length for each non-null value. + let sliced_list_array = sliced_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(2, sliced_list_array.value_length()); + assert_eq!(4, sliced_list_array.value_offset(2)); + assert_eq!(6, sliced_list_array.value_offset(3)); + } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn test_fixed_size_list_array_index_out_of_bound() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Set null buts for the nested array: + // [[0, 1], null, null, [6, 7], [8, 9]] + // 01011001 00000001 + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + + // Construct a fixed size list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); + let list_data = ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data) + 
.null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + list_array.value(10); + } + + #[test] + fn test_fixed_size_list_constructors() { + let values = Arc::new(Int32Array::from_iter([ + Some(1), + Some(2), + None, + None, + Some(3), + Some(4), + ])); + + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let list = FixedSizeListArray::new(field.clone(), 2, values.clone(), None); + assert_eq!(list.len(), 3); + + let nulls = NullBuffer::new_null(3); + let list = FixedSizeListArray::new(field.clone(), 2, values.clone(), Some(nulls)); + assert_eq!(list.len(), 3); + + let list = FixedSizeListArray::new(field.clone(), 4, values.clone(), None); + assert_eq!(list.len(), 1); + + let err = FixedSizeListArray::try_new(field.clone(), -1, values.clone(), None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Size cannot be negative, got -1" + ); + + let list = FixedSizeListArray::new(field.clone(), 0, values.clone(), None); + assert_eq!(list.len(), 0); + + let nulls = NullBuffer::new_null(2); + let err = FixedSizeListArray::try_new(field, 2, values.clone(), Some(nulls)).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeListArray, expected 3 got 2"); + + let field = Arc::new(Field::new("item", DataType::Int32, false)); + let err = FixedSizeListArray::try_new(field.clone(), 2, values.clone(), None).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Found unmasked nulls for non-nullable FixedSizeListArray field \"item\""); + + // Valid as nulls in child masked by parent + let nulls = NullBuffer::new(BooleanBuffer::new(Buffer::from([0b0000101]), 0, 3)); + FixedSizeListArray::new(field, 2, values.clone(), Some(nulls)); + + let field = Arc::new(Field::new("item", DataType::Int64, true)); + let err = FixedSizeListArray::try_new(field, 2, values, None).unwrap_err(); + 
assert_eq!(err.to_string(), "Invalid argument error: FixedSizeListArray expected data type Int64 got Int32 for \"item\""); + } + + #[test] + fn empty_fixed_size_list() { + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let nulls = NullBuffer::new_null(2); + let values = new_empty_array(&DataType::Int32); + let list = FixedSizeListArray::new(field.clone(), 0, values, Some(nulls)); + assert_eq!(list.len(), 2); + } +} diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs new file mode 100644 index 000000000000..dc1b9f07da16 --- /dev/null +++ b/arrow-array/src/array/list_array.rs @@ -0,0 +1,1184 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::array::{get_offsets, make_array, print_long_array}; +use crate::builder::{GenericListBuilder, PrimitiveBuilder}; +use crate::{ + iterator::GenericListArrayIter, new_empty_array, Array, ArrayAccessor, ArrayRef, + ArrowPrimitiveType, FixedSizeListArray, +}; +use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, FieldRef}; +use num::Integer; +use std::any::Any; +use std::sync::Arc; + +/// A type that can be used within a variable-size array to encode offset information +/// +/// See [`ListArray`], [`LargeListArray`], [`BinaryArray`], [`LargeBinaryArray`], +/// [`StringArray`] and [`LargeStringArray`] +/// +/// [`BinaryArray`]: crate::array::BinaryArray +/// [`LargeBinaryArray`]: crate::array::LargeBinaryArray +/// [`StringArray`]: crate::array::StringArray +/// [`LargeStringArray`]: crate::array::LargeStringArray +pub trait OffsetSizeTrait: ArrowNativeType + std::ops::AddAssign + Integer { + /// True for 64 bit offset size and false for 32 bit offset size + const IS_LARGE: bool; + /// Prefix for the offset size + const PREFIX: &'static str; +} + +impl OffsetSizeTrait for i32 { + const IS_LARGE: bool = false; + const PREFIX: &'static str = ""; +} + +impl OffsetSizeTrait for i64 { + const IS_LARGE: bool = true; + const PREFIX: &'static str = "Large"; +} + +/// An array of [variable length lists], similar to JSON arrays +/// (e.g. `["A", "B", "C"]`). +/// +/// Lists are represented using `offsets` into a `values` child +/// array. Offsets are stored in two adjacent entries of an +/// [`OffsetBuffer`]. +/// +/// Arrow defines [`ListArray`] with `i32` offsets and +/// [`LargeListArray`] with `i64` offsets. +/// +/// Use [`GenericListBuilder`] to construct a [`GenericListArray`]. +/// +/// # Representation +/// +/// A [`ListArray`] can represent a list of values of any other +/// supported Arrow type. 
Each element of the `ListArray` itself is +/// a list which may be empty, may contain NULL and non-null values, +/// or may itself be NULL. +/// +/// For example, the `ListArray` shown in the following diagram stores +/// lists of strings. Note that `[]` represents an empty (length +/// 0), but non NULL list. +/// +/// ```text +/// ┌─────────────┐ +/// │ [A,B,C] │ +/// ├─────────────┤ +/// │ [] │ +/// ├─────────────┤ +/// │ NULL │ +/// ├─────────────┤ +/// │ [D] │ +/// ├─────────────┤ +/// │ [NULL, F] │ +/// └─────────────┘ +/// ``` +/// +/// The `values` are stored in a child [`StringArray`] and the offsets +/// are stored in an [`OffsetBuffer`] as shown in the following +/// diagram. The logical values and offsets are shown on the left, and +/// the actual `ListArray` encoding on the right. +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌ ─ ─ ─ ─ ─ ─ ┐ │ +/// ┌─────────────┐ ┌───────┐ │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ +/// │ [A,B,C] │ │ (0,3) │ │ 1 │ │ 0 │ │ │ 1 │ │ A │ │ 0 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [] │ │ (3,3) │ │ 1 │ │ 3 │ │ │ 1 │ │ B │ │ 1 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ NULL │ │ (3,4) │ │ 0 │ │ 3 │ │ │ 1 │ │ C │ │ 2 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [D] │ │ (4,5) │ │ 1 │ │ 4 │ │ │ ? │ │ ? │ │ 3 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [NULL, F] │ │ (5,7) │ │ 1 │ │ 5 │ │ │ 1 │ │ D │ │ 4 │ +/// └─────────────┘ └───────┘ │ └───┘ ├───┤ ├───┤ ├───┤ +/// │ 7 │ │ │ 0 │ │ ? 
│ │ 5 │ +/// │ Validity └───┘ ├───┤ ├───┤ +/// Logical Logical (nulls) Offsets │ │ 1 │ │ F │ │ 6 │ +/// Values Offsets │ └───┘ └───┘ +/// │ Values │ │ +/// (offsets[i], │ ListArray (Array) +/// offsets[i+1]) └ ─ ─ ─ ─ ─ ─ ┘ │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// +/// ``` +/// +/// [`StringArray`]: crate::array::StringArray +/// [variable length lists]: https://arrow.apache.org/docs/format/Columnar.html#variable-size-list-layout +pub struct GenericListArray { + data_type: DataType, + nulls: Option, + values: ArrayRef, + value_offsets: OffsetBuffer, +} + +impl Clone for GenericListArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.clone(), + values: self.values.clone(), + value_offsets: self.value_offsets.clone(), + } + } +} + +impl GenericListArray { + /// The data type constructor of list array. + /// The input is the schema of the child array and + /// the output is the [`DataType`], List or LargeList. + pub const DATA_TYPE_CONSTRUCTOR: fn(FieldRef) -> DataType = if OffsetSize::IS_LARGE { + DataType::LargeList + } else { + DataType::List + }; + + /// Create a new [`GenericListArray`] from the provided parts + /// + /// # Errors + /// + /// Errors if + /// + /// * `offsets.len() - 1 != nulls.len()` + /// * `offsets.last() > values.len()` + /// * `!field.is_nullable() && values.is_nullable()` + /// * `field.data_type() != values.data_type()` + pub fn try_new( + field: FieldRef, + offsets: OffsetBuffer, + values: ArrayRef, + nulls: Option, + ) -> Result { + let len = offsets.len() - 1; // Offsets guaranteed to not be empty + let end_offset = offsets.last().unwrap().as_usize(); + // don't need to check other values of `offsets` because they are checked + // during construction of `OffsetBuffer` + if end_offset > values.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Max offset of {end_offset} exceeds length of values {}", + values.len() + ))); + } + + if let Some(n) = 
nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for {}ListArray, expected {len} got {}", + OffsetSize::PREFIX, + n.len(), + ))); + } + } + if !field.is_nullable() && values.is_nullable() { + return Err(ArrowError::InvalidArgumentError(format!( + "Non-nullable field of {}ListArray {:?} cannot contain nulls", + OffsetSize::PREFIX, + field.name() + ))); + } + + if field.data_type() != values.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "{}ListArray expected data type {} got {} for {:?}", + OffsetSize::PREFIX, + field.data_type(), + values.data_type(), + field.name() + ))); + } + + Ok(Self { + data_type: Self::DATA_TYPE_CONSTRUCTOR(field), + nulls, + values, + value_offsets: offsets, + }) + } + + /// Create a new [`GenericListArray`] from the provided parts + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new( + field: FieldRef, + offsets: OffsetBuffer, + values: ArrayRef, + nulls: Option, + ) -> Self { + Self::try_new(field, offsets, values, nulls).unwrap() + } + + /// Create a new [`GenericListArray`] of length `len` where all values are null + pub fn new_null(field: FieldRef, len: usize) -> Self { + let values = new_empty_array(field.data_type()); + Self { + data_type: Self::DATA_TYPE_CONSTRUCTOR(field), + nulls: Some(NullBuffer::new_null(len)), + value_offsets: OffsetBuffer::new_zeroed(len), + values, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts( + self, + ) -> ( + FieldRef, + OffsetBuffer, + ArrayRef, + Option, + ) { + let f = match self.data_type { + DataType::List(f) | DataType::LargeList(f) => f, + _ => unreachable!(), + }; + (f, self.value_offsets, self.values, self.nulls) + } + + /// Returns a reference to the offsets of this list + /// + /// Unlike [`Self::value_offsets`] this returns the [`OffsetBuffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn 
offsets(&self) -> &OffsetBuffer { + &self.value_offsets + } + + /// Returns a reference to the values of this list + #[inline] + pub fn values(&self) -> &ArrayRef { + &self.values + } + + /// Returns a clone of the value type of this list. + pub fn value_type(&self) -> DataType { + self.values.data_type().clone() + } + + /// Returns ith value of this list array. + /// # Safety + /// Caller must ensure that the index is within the array bounds + pub unsafe fn value_unchecked(&self, i: usize) -> ArrayRef { + let end = self.value_offsets().get_unchecked(i + 1).as_usize(); + let start = self.value_offsets().get_unchecked(i).as_usize(); + self.values.slice(start, end - start) + } + + /// Returns ith value of this list array. + pub fn value(&self, i: usize) -> ArrayRef { + let end = self.value_offsets()[i + 1].as_usize(); + let start = self.value_offsets()[i].as_usize(); + self.values.slice(start, end - start) + } + + /// Returns the offset values in the offsets buffer + #[inline] + pub fn value_offsets(&self) -> &[OffsetSize] { + &self.value_offsets + } + + /// Returns the length for value at index `i`. + #[inline] + pub fn value_length(&self, i: usize) -> OffsetSize { + let offsets = self.value_offsets(); + offsets[i + 1] - offsets[i] + } + + /// constructs a new iterator + pub fn iter<'a>(&'a self) -> GenericListArrayIter<'a, OffsetSize> { + GenericListArrayIter::<'a, OffsetSize>::new(self) + } + + #[inline] + fn get_type(data_type: &DataType) -> Option<&DataType> { + match (OffsetSize::IS_LARGE, data_type) { + (true, DataType::LargeList(child)) | (false, DataType::List(child)) => { + Some(child.data_type()) + } + _ => None, + } + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. 
+ pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + values: self.values.clone(), + value_offsets: self.value_offsets.slice(offset, length), + } + } + + /// Creates a [`GenericListArray`] from an iterator of primitive values + /// # Example + /// ``` + /// # use arrow_array::ListArray; + /// # use arrow_array::types::Int32Type; + /// + /// let data = vec![ + /// Some(vec![Some(0), Some(1), Some(2)]), + /// None, + /// Some(vec![Some(3), None, Some(5)]), + /// Some(vec![Some(6), Some(7)]), + /// ]; + /// let list_array = ListArray::from_iter_primitive::(data); + /// println!("{:?}", list_array); + /// ``` + pub fn from_iter_primitive(iter: I) -> Self + where + T: ArrowPrimitiveType, + P: IntoIterator::Native>>, + I: IntoIterator>, + { + let iter = iter.into_iter(); + let size_hint = iter.size_hint().0; + let mut builder = + GenericListBuilder::with_capacity(PrimitiveBuilder::::new(), size_hint); + + for i in iter { + match i { + Some(p) => { + for t in p { + builder.values().append_option(t); + } + builder.append(true); + } + None => builder.append(false), + } + } + builder.finish() + } +} + +impl From for GenericListArray { + fn from(data: ArrayData) -> Self { + Self::try_new_from_array_data(data) + .expect("Expected infallible creation of GenericListArray from ArrayDataRef failed") + } +} + +impl From> for ArrayData { + fn from(array: GenericListArray) -> Self { + let len = array.len(); + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .nulls(array.nulls) + .buffers(vec![array.value_offsets.into_inner().into_inner()]) + .child_data(vec![array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl From for GenericListArray { + fn from(value: FixedSizeListArray) -> Self { + let (field, size) = match value.data_type() { + DataType::FixedSizeList(f, size) => (f, *size as usize), + _ => unreachable!(), + }; + + 
let offsets = OffsetBuffer::from_lengths(std::iter::repeat(size).take(value.len())); + + Self { + data_type: Self::DATA_TYPE_CONSTRUCTOR(field.clone()), + nulls: value.nulls().cloned(), + values: value.values().clone(), + value_offsets: offsets, + } + } +} + +impl GenericListArray { + fn try_new_from_array_data(data: ArrayData) -> Result { + if data.buffers().len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "ListArray data should contain a single buffer only (value offsets), had {}", + data.buffers().len() + ))); + } + + if data.child_data().len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "ListArray should contain a single child array (values array), had {}", + data.child_data().len() + ))); + } + + let values = data.child_data()[0].clone(); + + if let Some(child_data_type) = Self::get_type(data.data_type()) { + if values.data_type() != child_data_type { + return Err(ArrowError::InvalidArgumentError(format!( + "[Large]ListArray's child datatype {:?} does not \ + correspond to the List's datatype {:?}", + values.data_type(), + child_data_type + ))); + } + } else { + return Err(ArrowError::InvalidArgumentError(format!( + "[Large]ListArray's datatype must be [Large]ListArray(). 
It is {:?}", + data.data_type() + ))); + } + + let values = make_array(values); + // SAFETY: + // ArrayData is valid, and verified type above + let value_offsets = unsafe { get_offsets(&data) }; + + Ok(Self { + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), + values, + value_offsets, + }) + } +} + +impl Array for GenericListArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.value_offsets.len() - 1 + } + + fn is_empty(&self) -> bool { + self.value_offsets.len() <= 1 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.values.get_buffer_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = std::mem::size_of::() + self.values.get_array_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } +} + +impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor for &'a GenericListArray { + type Item = ArrayRef; + + fn value(&self, index: usize) -> Self::Item { + GenericListArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + GenericListArray::value(self, index) + } +} + +impl std::fmt::Debug for GenericListArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let prefix = OffsetSize::PREFIX; + + write!(f, "{prefix}ListArray\n[\n")?; + print_long_array(self, f, |array, index, 
f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +/// A [`GenericListArray`] of variable size lists, storing offsets as `i32`. +/// +// See [`ListBuilder`](crate::builder::ListBuilder) for how to construct a [`ListArray`] +pub type ListArray = GenericListArray; + +/// A [`GenericListArray`] of variable size lists, storing offsets as `i64`. +/// +// See [`LargeListBuilder`](crate::builder::LargeListBuilder) for how to construct a [`LargeListArray`] +pub type LargeListArray = GenericListArray; + +#[cfg(test)] +mod tests { + use super::*; + use crate::builder::{FixedSizeListBuilder, Int32Builder, ListBuilder}; + use crate::cast::AsArray; + use crate::types::Int32Type; + use crate::{Int32Array, Int64Array}; + use arrow_buffer::{bit_util, Buffer, ScalarBuffer}; + use arrow_schema::Field; + + fn create_from_buffers() -> ListArray { + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 8])); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + ListArray::new(field, offsets, Arc::new(values), None) + } + + #[test] + fn test_from_iter_primitive() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(3), Some(4), Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let list_array = ListArray::from_iter_primitive::(data); + + let another = create_from_buffers(); + assert_eq!(list_array, another) + } + + #[test] + fn test_empty_list_array() { + // Construct an empty value array + let value_data = ArrayData::builder(DataType::Int32) + .len(0) + .add_buffer(Buffer::from([])) + .build() + .unwrap(); + + // Construct an empty offset buffer + let value_offsets = Buffer::from([]); + + // Construct a list array from the above two + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data = ArrayData::builder(list_data_type) + .len(0) + 
.add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + + let list_array = ListArray::from(list_data); + assert_eq!(list_array.len(), 0) + } + + #[test] + fn test_list_array() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_buffer(value_offsets.clone()) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = ListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(6, list_array.value_offsets()[2]); + assert_eq!(2, list_array.value_length(2)); + assert_eq!(0, list_array.value(0).as_primitive::().value(0)); + assert_eq!( + 0, + unsafe { list_array.value_unchecked(0) } + .as_primitive::() + .value(0) + ); + for i in 0..3 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + + // Now test with a non-zero offset (skip first element) + // [[3, 4, 5], [6, 7]] + let list_data = ArrayData::builder(list_data_type) + .len(2) + .offset(1) + .add_buffer(value_offsets) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = ListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(2, list_array.len()); + assert_eq!(0, list_array.null_count()); + 
assert_eq!(6, list_array.value_offsets()[1]); + assert_eq!(2, list_array.value_length(1)); + assert_eq!(3, list_array.value(0).as_primitive::().value(0)); + assert_eq!( + 3, + unsafe { list_array.value_unchecked(0) } + .as_primitive::() + .value(0) + ); + } + + #[test] + fn test_large_list_array() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = DataType::new_large_list(DataType::Int32, false); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_buffer(value_offsets.clone()) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = LargeListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(6, list_array.value_offsets()[2]); + assert_eq!(2, list_array.value_length(2)); + assert_eq!(0, list_array.value(0).as_primitive::().value(0)); + assert_eq!( + 0, + unsafe { list_array.value_unchecked(0) } + .as_primitive::() + .value(0) + ); + for i in 0..3 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + + // Now test with a non-zero offset + // [[3, 4, 5], [6, 7]] + let list_data = ArrayData::builder(list_data_type) + .len(2) + .offset(1) + .add_buffer(value_offsets) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = LargeListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(2, 
list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(6, list_array.value_offsets()[1]); + assert_eq!(2, list_array.value_length(1)); + assert_eq!(3, list_array.value(0).as_primitive::().value(0)); + assert_eq!( + 3, + unsafe { list_array.value_unchecked(0) } + .as_primitive::() + .value(0) + ); + } + + #[test] + fn test_list_array_slice() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]] + let value_offsets = Buffer::from_slice_ref([0, 2, 2, 2, 4, 6, 6, 9, 9, 10]); + // 01011001 00000001 + let mut null_bits: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + bit_util::set_bit(&mut null_bits, 6); + bit_util::set_bit(&mut null_bits, 8); + + // Construct a list array from the above two + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data = ArrayData::builder(list_data_type) + .len(9) + .add_buffer(value_offsets) + .add_child_data(value_data.clone()) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = ListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(9, list_array.len()); + assert_eq!(4, list_array.null_count()); + assert_eq!(2, list_array.value_offsets()[3]); + assert_eq!(2, list_array.value_length(3)); + + let sliced_array = list_array.slice(1, 6); + assert_eq!(6, sliced_array.len()); + assert_eq!(3, sliced_array.null_count()); + + for i in 0..sliced_array.len() { + if bit_util::get_bit(&null_bits, 1 + i) { + assert!(sliced_array.is_valid(i)); + } else { + 
assert!(sliced_array.is_null(i)); + } + } + + // Check offset and length for each non-null value. + let sliced_list_array = sliced_array.as_any().downcast_ref::().unwrap(); + assert_eq!(2, sliced_list_array.value_offsets()[2]); + assert_eq!(2, sliced_list_array.value_length(2)); + assert_eq!(4, sliced_list_array.value_offsets()[3]); + assert_eq!(2, sliced_list_array.value_length(3)); + assert_eq!(6, sliced_list_array.value_offsets()[5]); + assert_eq!(3, sliced_list_array.value_length(5)); + } + + #[test] + fn test_large_list_array_slice() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]] + let value_offsets = Buffer::from_slice_ref([0i64, 2, 2, 2, 4, 6, 6, 9, 9, 10]); + // 01011001 00000001 + let mut null_bits: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + bit_util::set_bit(&mut null_bits, 6); + bit_util::set_bit(&mut null_bits, 8); + + // Construct a list array from the above two + let list_data_type = DataType::new_large_list(DataType::Int32, false); + let list_data = ArrayData::builder(list_data_type) + .len(9) + .add_buffer(value_offsets) + .add_child_data(value_data.clone()) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = LargeListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(9, list_array.len()); + assert_eq!(4, list_array.null_count()); + assert_eq!(2, list_array.value_offsets()[3]); + assert_eq!(2, list_array.value_length(3)); + + let sliced_array = list_array.slice(1, 6); + assert_eq!(6, sliced_array.len()); + 
assert_eq!(3, sliced_array.null_count()); + + for i in 0..sliced_array.len() { + if bit_util::get_bit(&null_bits, 1 + i) { + assert!(sliced_array.is_valid(i)); + } else { + assert!(sliced_array.is_null(i)); + } + } + + // Check offset and length for each non-null value. + let sliced_list_array = sliced_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(2, sliced_list_array.value_offsets()[2]); + assert_eq!(2, sliced_list_array.value_length(2)); + assert_eq!(4, sliced_list_array.value_offsets()[3]); + assert_eq!(2, sliced_list_array.value_length(3)); + assert_eq!(6, sliced_list_array.value_offsets()[5]); + assert_eq!(3, sliced_list_array.value_length(5)); + } + + #[test] + #[should_panic(expected = "index out of bounds: the len is 10 but the index is 11")] + fn test_list_array_index_out_of_bound() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]] + let value_offsets = Buffer::from_slice_ref([0i64, 2, 2, 2, 4, 6, 6, 9, 9, 10]); + // 01011001 00000001 + let mut null_bits: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + bit_util::set_bit(&mut null_bits, 6); + bit_util::set_bit(&mut null_bits, 8); + + // Construct a list array from the above two + let list_data_type = DataType::new_large_list(DataType::Int32, false); + let list_data = ArrayData::builder(list_data_type) + .len(9) + .add_buffer(value_offsets) + .add_child_data(value_data) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = LargeListArray::from(list_data); + assert_eq!(9, list_array.len()); + + list_array.value(10); + } + #[test] + #[should_panic(expected = "ListArray data should contain a 
single buffer only (value offsets)")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_list_array_invalid_buffer_len() { + let value_data = unsafe { + ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build_unchecked() + }; + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build_unchecked() + }; + drop(ListArray::from(list_data)); + } + + #[test] + #[should_panic(expected = "ListArray should contain a single child array (values array)")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_list_array_invalid_child_array_len() { + let value_offsets = Buffer::from_slice_ref([0, 2, 5, 7]); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .build_unchecked() + }; + drop(ListArray::from(list_data)); + } + + #[test] + #[should_panic(expected = "[Large]ListArray's datatype must be [Large]ListArray(). 
It is List")] + fn test_from_array_data_validation() { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.values().append_value(1); + builder.append(true); + let array = builder.finish(); + let _ = LargeListArray::from(array.into_data()); + } + + #[test] + fn test_list_array_offsets_need_not_start_at_zero() { + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + let value_offsets = Buffer::from_slice_ref([2, 2, 5, 7]); + + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + + let list_array = ListArray::from(list_data); + assert_eq!(list_array.value_length(0), 0); + assert_eq!(list_array.value_length(1), 3); + assert_eq!(list_array.value_length(2), 2); + } + + #[test] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_primitive_array_alignment() { + let buf = Buffer::from_slice_ref([0_u64]); + let buf2 = buf.slice(1); + let array_data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(buf2) + .build_unchecked() + }; + drop(Int32Array::from(array_data)); + } + + #[test] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_list_array_alignment() { + let buf = Buffer::from_slice_ref([0_u64]); + let buf2 = buf.slice(1); + + let values: [i32; 8] = [0; 8]; + let value_data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(Buffer::from_slice_ref(values)) + 
.build_unchecked() + }; + + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .add_buffer(buf2) + .add_child_data(value_data) + .build_unchecked() + }; + drop(ListArray::from(list_data)); + } + + #[test] + fn list_array_equality() { + // test scaffold + fn do_comparison( + lhs_data: Vec>>>, + rhs_data: Vec>>>, + should_equal: bool, + ) { + let lhs = ListArray::from_iter_primitive::(lhs_data.clone()); + let rhs = ListArray::from_iter_primitive::(rhs_data.clone()); + assert_eq!(lhs == rhs, should_equal); + + let lhs = LargeListArray::from_iter_primitive::(lhs_data); + let rhs = LargeListArray::from_iter_primitive::(rhs_data); + assert_eq!(lhs == rhs, should_equal); + } + + do_comparison( + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ], + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ], + true, + ); + + do_comparison( + vec![ + None, + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ], + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ], + false, + ); + + do_comparison( + vec![ + None, + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ], + vec![ + None, + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(0), Some(0)]), + ], + false, + ); + + do_comparison( + vec![None, None, Some(vec![Some(1)])], + vec![None, None, Some(vec![Some(2)])], + false, + ); + } + + #[test] + fn test_empty_offsets() { + let f = Arc::new(Field::new("element", DataType::Int32, true)); + let string = ListArray::from( + ArrayData::builder(DataType::List(f.clone())) + .buffers(vec![Buffer::from(&[])]) + .add_child_data(ArrayData::new_empty(&DataType::Int32)) + .build() + .unwrap(), + 
); + assert_eq!(string.value_offsets(), &[0]); + let string = LargeListArray::from( + ArrayData::builder(DataType::LargeList(f)) + .buffers(vec![Buffer::from(&[])]) + .add_child_data(ArrayData::new_empty(&DataType::Int32)) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + } + + #[test] + fn test_try_new() { + let offsets = OffsetBuffer::new(vec![0, 1, 4, 5].into()); + let values = Int32Array::new(vec![1, 2, 3, 4, 5].into(), None); + let values = Arc::new(values) as ArrayRef; + + let field = Arc::new(Field::new("element", DataType::Int32, false)); + ListArray::new(field.clone(), offsets.clone(), values.clone(), None); + + let nulls = NullBuffer::new_null(3); + ListArray::new(field.clone(), offsets, values.clone(), Some(nulls)); + + let nulls = NullBuffer::new_null(3); + let offsets = OffsetBuffer::new(vec![0, 1, 2, 4, 5].into()); + let err = LargeListArray::try_new(field, offsets.clone(), values.clone(), Some(nulls)) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for LargeListArray, expected 4 got 3" + ); + + let field = Arc::new(Field::new("element", DataType::Int64, false)); + let err = LargeListArray::try_new(field.clone(), offsets.clone(), values.clone(), None) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: LargeListArray expected data type Int64 got Int32 for \"element\"" + ); + + let nulls = NullBuffer::new_null(7); + let values = Int64Array::new(vec![0; 7].into(), Some(nulls)); + let values = Arc::new(values); + + let err = + LargeListArray::try_new(field, offsets.clone(), values.clone(), None).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Non-nullable field of LargeListArray \"element\" cannot contain nulls" + ); + + let field = Arc::new(Field::new("element", DataType::Int64, true)); + LargeListArray::new(field.clone(), offsets.clone(), values, None); + + let values = 
Int64Array::new(vec![0; 2].into(), None); + let err = LargeListArray::try_new(field, offsets, Arc::new(values), None).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Max offset of 5 exceeds length of values 2" + ); + } + + #[test] + fn test_from_fixed_size_list() { + let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 3); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[0, 0, 0]); + builder.append(false); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + let list: ListArray = builder.finish().into(); + + let values: Vec<_> = list + .iter() + .map(|x| x.map(|x| x.as_primitive::().values().to_vec())) + .collect(); + assert_eq!(values, vec![Some(vec![1, 2, 3]), None, Some(vec![4, 5, 6])]) + } +} diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs new file mode 100644 index 000000000000..bddf202bdede --- /dev/null +++ b/arrow-array/src/array/map_array.rs @@ -0,0 +1,801 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::array::{get_offsets, print_long_array}; +use crate::iterator::MapArrayIter; +use crate::{make_array, Array, ArrayAccessor, ArrayRef, ListArray, StringArray, StructArray}; +use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer, OffsetBuffer, ToByteSlice}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// An array of key-value maps +/// +/// Keys should always be non-null, but values can be null. +/// +/// [`MapArray`] is physically a [`ListArray`] of key values pairs stored as an `entries` +/// [`StructArray`] with 2 child fields. +/// +/// See [`MapBuilder`](crate::builder::MapBuilder) for how to construct a [`MapArray`] +#[derive(Clone)] +pub struct MapArray { + data_type: DataType, + nulls: Option, + /// The [`StructArray`] that is the direct child of this array + entries: StructArray, + /// The start and end offsets of each entry + value_offsets: OffsetBuffer, +} + +impl MapArray { + /// Create a new [`MapArray`] from the provided parts + /// + /// See [`MapBuilder`](crate::builder::MapBuilder) for a higher-level interface + /// to construct a [`MapArray`] + /// + /// # Errors + /// + /// Errors if + /// + /// * `offsets.len() - 1 != nulls.len()` + /// * `offsets.last() > entries.len()` + /// * `field.is_nullable()` + /// * `entries.null_count() != 0` + /// * `entries.columns().len() != 2` + /// * `field.data_type() != entries.data_type()` + pub fn try_new( + field: FieldRef, + offsets: OffsetBuffer, + entries: StructArray, + nulls: Option, + ordered: bool, + ) -> Result { + let len = offsets.len() - 1; // Offsets guaranteed to not be empty + let end_offset = offsets.last().unwrap().as_usize(); + // don't need to check other values of `offsets` because they are checked + // during construction of `OffsetBuffer` + if end_offset > entries.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Max offset of {end_offset} exceeds length 
of entries {}", + entries.len() + ))); + } + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for MapArray, expected {len} got {}", + n.len(), + ))); + } + } + if field.is_nullable() || entries.null_count() != 0 { + return Err(ArrowError::InvalidArgumentError( + "MapArray entries cannot contain nulls".to_string(), + )); + } + + if field.data_type() != entries.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray expected data type {} got {} for {:?}", + field.data_type(), + entries.data_type(), + field.name() + ))); + } + + if entries.columns().len() != 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray entries must contain two children, got {}", + entries.columns().len() + ))); + } + + Ok(Self { + data_type: DataType::Map(field, ordered), + nulls, + entries, + value_offsets: offsets, + }) + } + + /// Create a new [`MapArray`] from the provided parts + /// + /// See [`MapBuilder`](crate::builder::MapBuilder) for a higher-level interface + /// to construct a [`MapArray`] + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new( + field: FieldRef, + offsets: OffsetBuffer, + entries: StructArray, + nulls: Option, + ordered: bool, + ) -> Self { + Self::try_new(field, offsets, entries, nulls, ordered).unwrap() + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts( + self, + ) -> ( + FieldRef, + OffsetBuffer, + StructArray, + Option, + bool, + ) { + let (f, ordered) = match self.data_type { + DataType::Map(f, ordered) => (f, ordered), + _ => unreachable!(), + }; + (f, self.value_offsets, self.entries, self.nulls, ordered) + } + + /// Returns a reference to the offsets of this map + /// + /// Unlike [`Self::value_offsets`] this returns the [`OffsetBuffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn offsets(&self) -> &OffsetBuffer { + &self.value_offsets 
+ } + + /// Returns a reference to the keys of this map + pub fn keys(&self) -> &ArrayRef { + self.entries.column(0) + } + + /// Returns a reference to the values of this map + pub fn values(&self) -> &ArrayRef { + self.entries.column(1) + } + + /// Returns a reference to the [`StructArray`] entries of this map + pub fn entries(&self) -> &StructArray { + &self.entries + } + + /// Returns the data type of the map's keys. + pub fn key_type(&self) -> &DataType { + self.keys().data_type() + } + + /// Returns the data type of the map's values. + pub fn value_type(&self) -> &DataType { + self.values().data_type() + } + + /// Returns ith value of this map array. + /// + /// # Safety + /// Caller must ensure that the index is within the array bounds + pub unsafe fn value_unchecked(&self, i: usize) -> StructArray { + let end = *self.value_offsets().get_unchecked(i + 1); + let start = *self.value_offsets().get_unchecked(i); + self.entries + .slice(start.to_usize().unwrap(), (end - start).to_usize().unwrap()) + } + + /// Returns ith value of this map array. + /// + /// This is a [`StructArray`] containing two fields + pub fn value(&self, i: usize) -> StructArray { + let end = self.value_offsets()[i + 1] as usize; + let start = self.value_offsets()[i] as usize; + self.entries.slice(start, end - start) + } + + /// Returns the offset values in the offsets buffer + #[inline] + pub fn value_offsets(&self) -> &[i32] { + &self.value_offsets + } + + /// Returns the length for value at index `i`. + #[inline] + pub fn value_length(&self, i: usize) -> i32 { + let offsets = self.value_offsets(); + offsets[i + 1] - offsets[i] + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. 
+ pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + entries: self.entries.clone(), + value_offsets: self.value_offsets.slice(offset, length), + } + } + + /// constructs a new iterator + pub fn iter(&self) -> MapArrayIter<'_> { + MapArrayIter::new(self) + } +} + +impl From for MapArray { + fn from(data: ArrayData) -> Self { + Self::try_new_from_array_data(data) + .expect("Expected infallible creation of MapArray from ArrayData failed") + } +} + +impl From for ArrayData { + fn from(array: MapArray) -> Self { + let len = array.len(); + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .nulls(array.nulls) + .buffers(vec![array.value_offsets.into_inner().into_inner()]) + .child_data(vec![array.entries.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl MapArray { + fn try_new_from_array_data(data: ArrayData) -> Result { + if !matches!(data.data_type(), DataType::Map(_, _)) { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray expected ArrayData with DataType::Map got {}", + data.data_type() + ))); + } + + if data.buffers().len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray data should contain a single buffer only (value offsets), had {}", + data.len() + ))); + } + + if data.child_data().len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a single child array (values array), had {}", + data.child_data().len() + ))); + } + + let entries = data.child_data()[0].clone(); + + if let DataType::Struct(fields) = entries.data_type() { + if fields.len() != 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a struct array with 2 fields, have {} fields", + fields.len() + ))); + } + } else { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a struct array child, found {:?}", + 
entries.data_type() + ))); + } + let entries = entries.into(); + + // SAFETY: + // ArrayData is valid, and verified type above + let value_offsets = unsafe { get_offsets(&data) }; + + Ok(Self { + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), + entries, + value_offsets, + }) + } + + /// Creates map array from provided keys, values and entry_offsets. + pub fn new_from_strings<'a>( + keys: impl Iterator, + values: &dyn Array, + entry_offsets: &[u32], + ) -> Result { + let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice()); + let keys_data = StringArray::from_iter_values(keys); + + let keys_field = Arc::new(Field::new("keys", DataType::Utf8, false)); + let values_field = Arc::new(Field::new( + "values", + values.data_type().clone(), + values.null_count() > 0, + )); + + let entry_struct = StructArray::from(vec![ + (keys_field, Arc::new(keys_data) as ArrayRef), + (values_field, make_array(values.to_data())), + ]); + + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(entry_offsets.len() - 1) + .add_buffer(entry_offsets_buffer) + .add_child_data(entry_struct.into_data()) + .build()?; + + Ok(MapArray::from(map_data)) + } +} + +impl Array for MapArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into_data() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.value_offsets.len() - 1 + } + + fn is_empty(&self) -> bool { + self.value_offsets.len() <= 1 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = 
self.entries.get_buffer_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = std::mem::size_of::() + self.entries.get_array_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } +} + +impl<'a> ArrayAccessor for &'a MapArray { + type Item = StructArray; + + fn value(&self, index: usize) -> Self::Item { + MapArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + MapArray::value(self, index) + } +} + +impl std::fmt::Debug for MapArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "MapArray\n[\n")?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl From for ListArray { + fn from(value: MapArray) -> Self { + let field = match value.data_type() { + DataType::Map(field, _) => field, + _ => unreachable!("This should be a map type."), + }; + let data_type = DataType::List(field.clone()); + let builder = value.into_data().into_builder().data_type(data_type); + let array_data = unsafe { builder.build_unchecked() }; + + ListArray::from(array_data) + } +} + +#[cfg(test)] +mod tests { + use crate::cast::AsArray; + use crate::types::UInt32Type; + use crate::{Int32Array, UInt32Array}; + use arrow_schema::Fields; + + use super::*; + + fn create_from_buffers() -> MapArray { + // Construct key and values + let keys_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from([0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build() + .unwrap(); + let values_data = ArrayData::builder(DataType::UInt32) + .len(8) + .add_buffer(Buffer::from( + [0u32, 10, 20, 30, 40, 50, 60, 70].to_byte_slice(), + )) + .build() + .unwrap(); + + // Construct 
a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from([0, 3, 6, 8].to_byte_slice()); + + let keys = Arc::new(Field::new("keys", DataType::Int32, false)); + let values = Arc::new(Field::new("values", DataType::UInt32, false)); + let entry_struct = StructArray::from(vec![ + (keys, make_array(keys_data)), + (values, make_array(values_data)), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.into_data()) + .build() + .unwrap(); + MapArray::from(map_data) + } + + #[test] + fn test_map_array() { + // Construct key and values + let key_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from([0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build() + .unwrap(); + let value_data = ArrayData::builder(DataType::UInt32) + .len(8) + .add_buffer(Buffer::from( + [0u32, 10, 20, 0, 40, 0, 60, 70].to_byte_slice(), + )) + .null_bit_buffer(Some(Buffer::from(&[0b11010110]))) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from([0, 3, 6, 8].to_byte_slice()); + + let keys_field = Arc::new(Field::new("keys", DataType::Int32, false)); + let values_field = Arc::new(Field::new("values", DataType::UInt32, true)); + let entry_struct = StructArray::from(vec![ + (keys_field.clone(), make_array(key_data)), + (values_field.clone(), make_array(value_data.clone())), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + 
.add_child_data(entry_struct.into_data()) + .build() + .unwrap(); + let map_array = MapArray::from(map_data); + + assert_eq!(value_data, map_array.values().to_data()); + assert_eq!(&DataType::UInt32, map_array.value_type()); + assert_eq!(3, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[2]); + assert_eq!(2, map_array.value_length(2)); + + let key_array = Arc::new(Int32Array::from(vec![0, 1, 2])) as ArrayRef; + let value_array = + Arc::new(UInt32Array::from(vec![None, Some(10u32), Some(20)])) as ArrayRef; + let struct_array = StructArray::from(vec![ + (keys_field.clone(), key_array), + (values_field.clone(), value_array), + ]); + assert_eq!( + struct_array, + StructArray::from(map_array.value(0).into_data()) + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + for i in 0..3 { + assert!(map_array.is_valid(i)); + assert!(!map_array.is_null(i)); + } + + // Now test with a non-zero offset + let map_array = map_array.slice(1, 2); + + assert_eq!(value_data, map_array.values().to_data()); + assert_eq!(&DataType::UInt32, map_array.value_type()); + assert_eq!(2, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[1]); + assert_eq!(2, map_array.value_length(1)); + + let key_array = Arc::new(Int32Array::from(vec![3, 4, 5])) as ArrayRef; + let value_array = Arc::new(UInt32Array::from(vec![None, Some(40), None])) as ArrayRef; + let struct_array = + StructArray::from(vec![(keys_field, key_array), (values_field, value_array)]); + assert_eq!( + &struct_array, + map_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + } + + #[test] + #[ignore = "Test fails because slice of > is still buggy"] + fn test_map_array_slice() { + let map_array = create_from_buffers(); + + let 
sliced_array = map_array.slice(1, 2); + assert_eq!(2, sliced_array.len()); + assert_eq!(1, sliced_array.offset()); + let sliced_array_data = sliced_array.to_data(); + for array_data in sliced_array_data.child_data() { + assert_eq!(array_data.offset(), 1); + } + + // Check offset and length for each non-null value. + let sliced_map_array = sliced_array.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_map_array.value_offsets()[0]); + assert_eq!(3, sliced_map_array.value_length(0)); + assert_eq!(6, sliced_map_array.value_offsets()[1]); + assert_eq!(2, sliced_map_array.value_length(1)); + + // Construct key and values + let keys_data = ArrayData::builder(DataType::Int32) + .len(5) + .add_buffer(Buffer::from([3, 4, 5, 6, 7].to_byte_slice())) + .build() + .unwrap(); + let values_data = ArrayData::builder(DataType::UInt32) + .len(5) + .add_buffer(Buffer::from([30u32, 40, 50, 60, 70].to_byte_slice())) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from([0, 3, 5].to_byte_slice()); + + let keys = Arc::new(Field::new("keys", DataType::Int32, false)); + let values = Arc::new(Field::new("values", DataType::UInt32, false)); + let entry_struct = StructArray::from(vec![ + (keys, make_array(keys_data)), + (values, make_array(values_data)), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let expected_map_data = ArrayData::builder(map_data_type) + .len(2) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.into_data()) + .build() + .unwrap(); + let expected_map_array = MapArray::from(expected_map_data); + + assert_eq!(&expected_map_array, sliced_map_array) + } + + #[test] + #[should_panic(expected = "index out of bounds: the len is ")] + fn test_map_array_index_out_of_bound() { + let map_array = create_from_buffers(); + + 
map_array.value(map_array.len()); + } + + #[test] + #[should_panic(expected = "MapArray expected ArrayData with DataType::Map got Dictionary")] + fn test_from_array_data_validation() { + // A DictionaryArray has similar buffer layout to a MapArray + // but the meaning of the values differs + let struct_t = DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, true), + Field::new("values", DataType::UInt32, true), + ])); + let dict_t = DataType::Dictionary(Box::new(DataType::Int32), Box::new(struct_t)); + let _ = MapArray::from(ArrayData::new_empty(&dict_t)); + } + + #[test] + fn test_new_from_strings() { + let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; + let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]); + + // Construct a buffer for value offsets, for the nested array: + // [[a, b, c], [d, e, f], [g, h]] + let entry_offsets = [0, 3, 6, 8]; + + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); + + assert_eq!( + &values_data, + map_array.values().as_primitive::() + ); + assert_eq!(&DataType::UInt32, map_array.value_type()); + assert_eq!(3, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[2]); + assert_eq!(2, map_array.value_length(2)); + + let key_array = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; + let value_array = Arc::new(UInt32Array::from(vec![0u32, 10, 20])) as ArrayRef; + let keys_field = Arc::new(Field::new("keys", DataType::Utf8, false)); + let values_field = Arc::new(Field::new("values", DataType::UInt32, false)); + let struct_array = + StructArray::from(vec![(keys_field, key_array), (values_field, value_array)]); + assert_eq!( + struct_array, + StructArray::from(map_array.value(0).into_data()) + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + for i in 0..3 { + 
assert!(map_array.is_valid(i)); + assert!(!map_array.is_null(i)); + } + } + + #[test] + fn test_try_new() { + let offsets = OffsetBuffer::new(vec![0, 1, 4, 5].into()); + let fields = Fields::from(vec![ + Field::new("key", DataType::Int32, false), + Field::new("values", DataType::Int32, false), + ]); + let columns = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + ]; + + let entries = StructArray::new(fields.clone(), columns, None); + let field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + + MapArray::new(field.clone(), offsets.clone(), entries.clone(), None, false); + + let nulls = NullBuffer::new_null(3); + MapArray::new(field.clone(), offsets, entries.clone(), Some(nulls), false); + + let nulls = NullBuffer::new_null(3); + let offsets = OffsetBuffer::new(vec![0, 1, 2, 4, 5].into()); + let err = MapArray::try_new( + field.clone(), + offsets.clone(), + entries.clone(), + Some(nulls), + false, + ) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for MapArray, expected 4 got 3" + ); + + let err = MapArray::try_new(field, offsets.clone(), entries.slice(0, 2), None, false) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Max offset of 5 exceeds length of entries 2" + ); + + let field = Arc::new(Field::new("element", DataType::Int64, false)); + let err = MapArray::try_new(field, offsets.clone(), entries, None, false) + .unwrap_err() + .to_string(); + + assert!( + err.starts_with("Invalid argument error: MapArray expected data type Int64 got Struct"), + "{err}" + ); + + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let columns = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + 
Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + ]; + + let s = StructArray::new(fields.clone(), columns, None); + let field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + let err = MapArray::try_new(field, offsets, s, None, false).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: MapArray entries must contain two children, got 3" + ); + } +} diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs new file mode 100644 index 000000000000..50c5699bac32 --- /dev/null +++ b/arrow-array/src/array/mod.rs @@ -0,0 +1,1143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
The concrete array definitions + +mod binary_array; + +use crate::types::*; +use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow_data::ArrayData; +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; +use std::any::Any; +use std::sync::Arc; + +pub use binary_array::*; + +mod boolean_array; +pub use boolean_array::*; + +mod byte_array; +pub use byte_array::*; + +mod dictionary_array; +pub use dictionary_array::*; + +mod fixed_size_binary_array; +pub use fixed_size_binary_array::*; + +mod fixed_size_list_array; +pub use fixed_size_list_array::*; + +mod list_array; +pub use list_array::*; + +mod map_array; +pub use map_array::*; + +mod null_array; +pub use null_array::*; + +mod primitive_array; +pub use primitive_array::*; + +mod string_array; +pub use string_array::*; + +mod struct_array; +pub use struct_array::*; + +mod union_array; +pub use union_array::*; + +mod run_array; + +pub use run_array::*; + +mod byte_view_array; + +pub use byte_view_array::*; + +/// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html) +pub trait Array: std::fmt::Debug + Send + Sync { + /// Returns the array as [`Any`] so that it can be + /// downcasted to a specific implementation. 
+ /// + /// # Example: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{Schema, Field, DataType, ArrowError}; + /// + /// let id = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// let batch = RecordBatch::try_new( + /// Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + /// vec![Arc::new(id)] + /// ).unwrap(); + /// + /// let int32array = batch + /// .column(0) + /// .as_any() + /// .downcast_ref::() + /// .expect("Failed to downcast"); + /// ``` + fn as_any(&self) -> &dyn Any; + + /// Returns the underlying data of this array + fn to_data(&self) -> ArrayData; + + /// Returns the underlying data of this array + /// + /// Unlike [`Array::to_data`] this consumes self, allowing it avoid unnecessary clones + fn into_data(self) -> ArrayData; + + /// Returns a reference to the [`DataType`] of this array. + /// + /// # Example: + /// + /// ``` + /// use arrow_schema::DataType; + /// use arrow_array::{Array, Int32Array}; + /// + /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// + /// assert_eq!(*array.data_type(), DataType::Int32); + /// ``` + fn data_type(&self) -> &DataType; + + /// Returns a zero-copy slice of this array with the indicated offset and length. + /// + /// # Example: + /// + /// ``` + /// use arrow_array::{Array, Int32Array}; + /// + /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// // Make slice over the values [2, 3, 4] + /// let array_slice = array.slice(1, 3); + /// + /// assert_eq!(&array_slice, &Int32Array::from(vec![2, 3, 4])); + /// ``` + fn slice(&self, offset: usize, length: usize) -> ArrayRef; + + /// Returns the length (i.e., number of elements) of this array. + /// + /// # Example: + /// + /// ``` + /// use arrow_array::{Array, Int32Array}; + /// + /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// + /// assert_eq!(array.len(), 5); + /// ``` + fn len(&self) -> usize; + + /// Returns whether this array is empty. 
+ /// + /// # Example: + /// + /// ``` + /// use arrow_array::{Array, Int32Array}; + /// + /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// + /// assert_eq!(array.is_empty(), false); + /// ``` + fn is_empty(&self) -> bool; + + /// Returns the offset into the underlying data used by this array(-slice). + /// Note that the underlying data can be shared by many arrays. + /// This defaults to `0`. + /// + /// # Example: + /// + /// ``` + /// use arrow_array::{Array, BooleanArray}; + /// + /// let array = BooleanArray::from(vec![false, false, true, true]); + /// let array_slice = array.slice(1, 3); + /// + /// assert_eq!(array.offset(), 0); + /// assert_eq!(array_slice.offset(), 1); + /// ``` + fn offset(&self) -> usize; + + /// Returns the null buffer of this array if any. + /// + /// The null buffer contains the "physical" nulls of an array, that is how + /// the nulls are represented in the underlying arrow format. + /// + /// The physical representation is efficient, but is sometimes non intuitive + /// for certain array types such as those with nullable child arrays like + /// [`DictionaryArray::values`] or [`RunArray::values`], or without a + /// null buffer, such as [`NullArray`]. + /// + /// To determine if each element of such an array is "logically" null, + /// use the slower [`Array::logical_nulls`] to obtain a computed mask. + fn nulls(&self) -> Option<&NullBuffer>; + + /// Returns a potentially computed [`NullBuffer`] that represents the logical + /// null values of this array, if any. + /// + /// Logical nulls represent the values that are null in the array, + /// regardless of the underlying physical arrow representation. + /// + /// For most array types, this is equivalent to the "physical" nulls + /// returned by [`Array::nulls`]. 
It is different for the following cases, because which + /// elements are null is not encoded in a single null buffer: + /// + /// * [`DictionaryArray`] where [`DictionaryArray::values`] contains nulls + /// * [`RunArray`] where [`RunArray::values`] contains nulls + /// * [`NullArray`] where all indices are nulls + /// + /// In these cases a logical [`NullBuffer`] will be computed, encoding the + /// logical nullability of these arrays, beyond what is encoded in + /// [`Array::nulls`] + fn logical_nulls(&self) -> Option { + self.nulls().cloned() + } + + /// Returns whether the element at `index` is null according to [`Array::nulls`] + /// + /// Note: For performance reasons, this method returns nullability solely as determined by the + /// null buffer. This difference can lead to surprising results, for example, [`NullArray::is_null`] always + /// returns `false` as the array lacks a null buffer. Similarly [`DictionaryArray`] and [`RunArray`] may + /// encode nullability in their children. See [`Self::logical_nulls`] for more information. + /// + /// # Example: + /// + /// ``` + /// use arrow_array::{Array, Int32Array, NullArray}; + /// + /// let array = Int32Array::from(vec![Some(1), None]); + /// assert_eq!(array.is_null(0), false); + /// assert_eq!(array.is_null(1), true); + /// + /// // NullArrays do not have a null buffer, and therefore always + /// // return false for is_null. + /// let array = NullArray::new(1); + /// assert_eq!(array.is_null(0), false); + /// ``` + fn is_null(&self, index: usize) -> bool { + self.nulls().map(|n| n.is_null(index)).unwrap_or_default() + } + + /// Returns whether the element at `index` is *not* null, the + /// opposite of [`Self::is_null`]. 
+ /// + /// # Example: + /// + /// ``` + /// use arrow_array::{Array, Int32Array}; + /// + /// let array = Int32Array::from(vec![Some(1), None]); + /// + /// assert_eq!(array.is_valid(0), true); + /// assert_eq!(array.is_valid(1), false); + /// ``` + fn is_valid(&self, index: usize) -> bool { + !self.is_null(index) + } + + /// Returns the total number of physical null values in this array. + /// + /// Note: this method returns the physical null count, i.e. that encoded in [`Array::nulls`], + /// see [`Array::logical_nulls`] for logical nullability + /// + /// # Example: + /// + /// ``` + /// use arrow_array::{Array, Int32Array}; + /// + /// // Construct an array with values [1, NULL, NULL] + /// let array = Int32Array::from(vec![Some(1), None, None]); + /// + /// assert_eq!(array.null_count(), 2); + /// ``` + fn null_count(&self) -> usize { + self.nulls().map(|n| n.null_count()).unwrap_or_default() + } + + /// Returns `false` if the array is guaranteed to not contain any logical nulls + /// + /// In general this will be equivalent to `Array::null_count() != 0` but may differ in the + /// presence of logical nullability, see [`Array::logical_nulls`]. + /// + /// Implementations will return `true` unless they can cheaply prove no logical nulls + /// are present. For example a [`DictionaryArray`] with nullable values will still return true, + /// even if the nulls present in [`DictionaryArray::values`] are not referenced by any key, + /// and therefore would not appear in [`Array::logical_nulls`]. + fn is_nullable(&self) -> bool { + self.null_count() != 0 + } + + /// Returns the total number of bytes of memory pointed to by this array. + /// The buffers store bytes in the Arrow memory format, and include the data as well as the validity map. + /// Note that this does not always correspond to the exact memory usage of an array, + /// since multiple arrays can share the same buffers or slices thereof. 
+ fn get_buffer_memory_size(&self) -> usize; + + /// Returns the total number of bytes of memory occupied physically by this array. + /// This value will always be greater than returned by `get_buffer_memory_size()` and + /// includes the overhead of the data structures that contain the pointers to the various buffers. + fn get_array_memory_size(&self) -> usize; +} + +/// A reference-counted reference to a generic `Array` +pub type ArrayRef = Arc; + +/// Ergonomics: Allow use of an ArrayRef as an `&dyn Array` +impl Array for ArrayRef { + fn as_any(&self) -> &dyn Any { + self.as_ref().as_any() + } + + fn to_data(&self) -> ArrayData { + self.as_ref().to_data() + } + + fn into_data(self) -> ArrayData { + self.to_data() + } + + fn data_type(&self) -> &DataType { + self.as_ref().data_type() + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + self.as_ref().slice(offset, length) + } + + fn len(&self) -> usize { + self.as_ref().len() + } + + fn is_empty(&self) -> bool { + self.as_ref().is_empty() + } + + fn offset(&self) -> usize { + self.as_ref().offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.as_ref().nulls() + } + + fn logical_nulls(&self) -> Option { + self.as_ref().logical_nulls() + } + + fn is_null(&self, index: usize) -> bool { + self.as_ref().is_null(index) + } + + fn is_valid(&self, index: usize) -> bool { + self.as_ref().is_valid(index) + } + + fn null_count(&self) -> usize { + self.as_ref().null_count() + } + + fn is_nullable(&self) -> bool { + self.as_ref().is_nullable() + } + + fn get_buffer_memory_size(&self) -> usize { + self.as_ref().get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.as_ref().get_array_memory_size() + } +} + +impl<'a, T: Array> Array for &'a T { + fn as_any(&self) -> &dyn Any { + T::as_any(self) + } + + fn to_data(&self) -> ArrayData { + T::to_data(self) + } + + fn into_data(self) -> ArrayData { + self.to_data() + } + + fn data_type(&self) -> &DataType { + T::data_type(self) 
+ } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + T::slice(self, offset, length) + } + + fn len(&self) -> usize { + T::len(self) + } + + fn is_empty(&self) -> bool { + T::is_empty(self) + } + + fn offset(&self) -> usize { + T::offset(self) + } + + fn nulls(&self) -> Option<&NullBuffer> { + T::nulls(self) + } + + fn logical_nulls(&self) -> Option { + T::logical_nulls(self) + } + + fn is_null(&self, index: usize) -> bool { + T::is_null(self, index) + } + + fn is_valid(&self, index: usize) -> bool { + T::is_valid(self, index) + } + + fn null_count(&self) -> usize { + T::null_count(self) + } + + fn is_nullable(&self) -> bool { + T::is_nullable(self) + } + + fn get_buffer_memory_size(&self) -> usize { + T::get_buffer_memory_size(self) + } + + fn get_array_memory_size(&self) -> usize { + T::get_array_memory_size(self) + } +} + +/// A generic trait for accessing the values of an [`Array`] +/// +/// This trait helps write specialized implementations of algorithms for +/// different array types. Specialized implementations allow the compiler +/// to optimize the code for the specific array type, which can lead to +/// significant performance improvements. 
+/// +/// # Example +/// For example, to write three different implementations of a string length function +/// for [`StringArray`], [`LargeStringArray`], and [`StringViewArray`], you can write +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayAccessor, ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; +/// # use arrow_buffer::ArrowNativeType; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::iterator::ArrayIter; +/// # use arrow_array::types::{Int32Type, Int64Type}; +/// # use arrow_schema::{ArrowError, DataType}; +/// /// This function takes a dynamically typed `ArrayRef` and calls +/// /// calls one of three specialized implementations +/// fn character_length(arg: ArrayRef) -> Result { +/// match arg.data_type() { +/// DataType::Utf8 => { +/// // downcast the ArrayRef to a StringArray and call the specialized implementation +/// let string_array = arg.as_string::(); +/// character_length_general::(string_array) +/// } +/// DataType::LargeUtf8 => { +/// character_length_general::(arg.as_string::()) +/// } +/// DataType::Utf8View => { +/// character_length_general::(arg.as_string_view()) +/// } +/// _ => Err(ArrowError::InvalidArgumentError("Unsupported data type".to_string())), +/// } +/// } +/// +/// /// A generic implementation of the character_length function +/// /// This function uses the `ArrayAccessor` trait to access the values of the array +/// /// so the compiler can generated specialized implementations for different array types +/// /// +/// /// Returns a new array with the length of each string in the input array +/// /// * Int32Array for Utf8 and Utf8View arrays (lengths are 32-bit integers) +/// /// * Int64Array for LargeUtf8 arrays (lengths are 64-bit integers) +/// /// +/// /// This is generic on the type of the primitive array (different string arrays have +/// /// different lengths) and the type of the array accessor (different string arrays +/// /// have different ways to access the values) +/// 
fn character_length_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor>( +/// array: V, +/// ) -> Result +/// where +/// T::Native: OffsetSizeTrait, +/// { +/// let iter = ArrayIter::new(array); +/// // Create a Int32Array / Int64Array with the length of each string +/// let result = iter +/// .map(|string| { +/// string.map(|string: &str| { +/// T::Native::from_usize(string.chars().count()) +/// .expect("should not fail as string.chars will always return integer") +/// }) +/// }) +/// .collect::>(); +/// +/// /// Return the result as a new ArrayRef (dynamically typed) +/// Ok(Arc::new(result) as ArrayRef) +/// } +/// ``` +/// +/// # Validity +/// +/// An [`ArrayAccessor`] must always return a well-defined value for an index +/// that is within the bounds `0..Array::len`, including for null indexes where +/// [`Array::is_null`] is true. +/// +/// The value at null indexes is unspecified, and implementations must not rely +/// on a specific value such as [`Default::default`] being returned, however, it +/// must not be undefined +pub trait ArrayAccessor: Array { + /// The Arrow type of the element being accessed. 
+ type Item: Send + Sync; + + /// Returns the element at position `index` + /// # Panics + /// Panics if `index` is outside the bounds of the array + fn value(&self, index: usize) -> Self::Item; + + /// Returns the element at position `index` without the bounds check of `value` + /// # Safety + /// Caller is responsible for ensuring that `index` is within the bounds of the array + unsafe fn value_unchecked(&self, index: usize) -> Self::Item; +} + +impl PartialEq for dyn Array + '_ { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for dyn Array + '_ { + fn eq(&self, other: &T) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for NullArray { + fn eq(&self, other: &NullArray) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for PrimitiveArray { + fn eq(&self, other: &PrimitiveArray) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for DictionaryArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for BooleanArray { + fn eq(&self, other: &BooleanArray) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for GenericStringArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for GenericBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for FixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for GenericListArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for MapArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for FixedSizeListArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for StructArray { + fn eq(&self, other: &Self) -> bool { + 
self.to_data().eq(&other.to_data()) + } +} + +/// Constructs an array using the input `data`. +/// Returns a reference-counted `Array` instance. +pub fn make_array(data: ArrayData) -> ArrayRef { + match data.data_type() { + DataType::Boolean => Arc::new(BooleanArray::from(data)) as ArrayRef, + DataType::Int8 => Arc::new(Int8Array::from(data)) as ArrayRef, + DataType::Int16 => Arc::new(Int16Array::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(Int32Array::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(Int64Array::from(data)) as ArrayRef, + DataType::UInt8 => Arc::new(UInt8Array::from(data)) as ArrayRef, + DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef, + DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef, + DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef, + DataType::Float16 => Arc::new(Float16Array::from(data)) as ArrayRef, + DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef, + DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef, + DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef, + DataType::Date64 => Arc::new(Date64Array::from(data)) as ArrayRef, + DataType::Time32(TimeUnit::Second) => Arc::new(Time32SecondArray::from(data)) as ArrayRef, + DataType::Time32(TimeUnit::Millisecond) => { + Arc::new(Time32MillisecondArray::from(data)) as ArrayRef + } + DataType::Time64(TimeUnit::Microsecond) => { + Arc::new(Time64MicrosecondArray::from(data)) as ArrayRef + } + DataType::Time64(TimeUnit::Nanosecond) => { + Arc::new(Time64NanosecondArray::from(data)) as ArrayRef + } + DataType::Timestamp(TimeUnit::Second, _) => { + Arc::new(TimestampSecondArray::from(data)) as ArrayRef + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + Arc::new(TimestampMillisecondArray::from(data)) as ArrayRef + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Arc::new(TimestampMicrosecondArray::from(data)) as ArrayRef + } + DataType::Timestamp(TimeUnit::Nanosecond, _) 
=> { + Arc::new(TimestampNanosecondArray::from(data)) as ArrayRef + } + DataType::Interval(IntervalUnit::YearMonth) => { + Arc::new(IntervalYearMonthArray::from(data)) as ArrayRef + } + DataType::Interval(IntervalUnit::DayTime) => { + Arc::new(IntervalDayTimeArray::from(data)) as ArrayRef + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Arc::new(IntervalMonthDayNanoArray::from(data)) as ArrayRef + } + DataType::Duration(TimeUnit::Second) => { + Arc::new(DurationSecondArray::from(data)) as ArrayRef + } + DataType::Duration(TimeUnit::Millisecond) => { + Arc::new(DurationMillisecondArray::from(data)) as ArrayRef + } + DataType::Duration(TimeUnit::Microsecond) => { + Arc::new(DurationMicrosecondArray::from(data)) as ArrayRef + } + DataType::Duration(TimeUnit::Nanosecond) => { + Arc::new(DurationNanosecondArray::from(data)) as ArrayRef + } + DataType::Binary => Arc::new(BinaryArray::from(data)) as ArrayRef, + DataType::LargeBinary => Arc::new(LargeBinaryArray::from(data)) as ArrayRef, + DataType::FixedSizeBinary(_) => Arc::new(FixedSizeBinaryArray::from(data)) as ArrayRef, + DataType::BinaryView => Arc::new(BinaryViewArray::from(data)) as ArrayRef, + DataType::Utf8 => Arc::new(StringArray::from(data)) as ArrayRef, + DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data)) as ArrayRef, + DataType::Utf8View => Arc::new(StringViewArray::from(data)) as ArrayRef, + DataType::List(_) => Arc::new(ListArray::from(data)) as ArrayRef, + DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef, + DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef, + DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef, + DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef, + DataType::FixedSizeList(_, _) => Arc::new(FixedSizeListArray::from(data)) as ArrayRef, + DataType::Dictionary(ref key_type, _) => match key_type.as_ref() { + DataType::Int8 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + 
DataType::Int16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt8 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt64 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + dt => panic!("Unexpected dictionary key type {dt:?}"), + }, + DataType::RunEndEncoded(ref run_ends_type, _) => match run_ends_type.data_type() { + DataType::Int16 => Arc::new(RunArray::::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(RunArray::::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(RunArray::::from(data)) as ArrayRef, + dt => panic!("Unexpected data type for run_ends array {dt:?}"), + }, + DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, + DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, + DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, + dt => panic!("Unexpected data type {dt:?}"), + } +} + +/// Creates a new empty array +/// +/// ``` +/// use std::sync::Arc; +/// use arrow_schema::DataType; +/// use arrow_array::{ArrayRef, Int32Array, new_empty_array}; +/// +/// let empty_array = new_empty_array(&DataType::Int32); +/// let array: ArrayRef = Arc::new(Int32Array::from(vec![] as Vec)); +/// +/// assert_eq!(&array, &empty_array); +/// ``` +pub fn new_empty_array(data_type: &DataType) -> ArrayRef { + let data = ArrayData::new_empty(data_type); + make_array(data) +} + +/// Creates a new array of `data_type` of length `length` filled +/// entirely of `NULL` values +/// +/// ``` +/// use std::sync::Arc; +/// use arrow_schema::DataType; +/// use arrow_array::{ArrayRef, Int32Array, new_null_array}; +/// +/// let null_array = new_null_array(&DataType::Int32, 
3); +/// let array: ArrayRef = Arc::new(Int32Array::from(vec![None, None, None])); +/// +/// assert_eq!(&array, &null_array); +/// ``` +pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { + make_array(ArrayData::new_null(data_type, length)) +} + +/// Helper function that gets offset from an [`ArrayData`] +/// +/// # Safety +/// +/// - ArrayData must contain a valid [`OffsetBuffer`] as its first buffer +unsafe fn get_offsets(data: &ArrayData) -> OffsetBuffer { + match data.is_empty() && data.buffers()[0].is_empty() { + true => OffsetBuffer::new_empty(), + false => { + let buffer = + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len() + 1); + // Safety: + // ArrayData is valid + unsafe { OffsetBuffer::new_unchecked(buffer) } + } + } +} + +/// Helper function for printing potentially long arrays. +fn print_long_array(array: &A, f: &mut std::fmt::Formatter, print_item: F) -> std::fmt::Result +where + A: Array, + F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result, +{ + let head = std::cmp::min(10, array.len()); + + for i in 0..head { + if array.is_null(i) { + writeln!(f, " null,")?; + } else { + write!(f, " ")?; + print_item(array, i, f)?; + writeln!(f, ",")?; + } + } + if array.len() > 10 { + if array.len() > 20 { + writeln!(f, " ...{} elements...,", array.len() - 20)?; + } + + let tail = std::cmp::max(head, array.len() - 10); + + for i in tail..array.len() { + if array.is_null(i) { + writeln!(f, " null,")?; + } else { + write!(f, " ")?; + print_item(array, i, f)?; + writeln!(f, ",")?; + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cast::{as_union_array, downcast_array}; + use crate::downcast_run_array; + use arrow_buffer::MutableBuffer; + use arrow_schema::{Field, Fields, UnionFields, UnionMode}; + + #[test] + fn test_empty_primitive() { + let array = new_empty_array(&DataType::Int32); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 0); + let expected: 
&[i32] = &[]; + assert_eq!(a.values(), expected); + } + + #[test] + fn test_empty_variable_sized() { + let array = new_empty_array(&DataType::Utf8); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 0); + assert_eq!(a.value_offsets()[0], 0i32); + } + + #[test] + fn test_empty_list_primitive() { + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let array = new_empty_array(&data_type); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 0); + assert_eq!(a.value_offsets()[0], 0i32); + } + + #[test] + fn test_null_boolean() { + let array = new_null_array(&DataType::Boolean, 9); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 9); + for i in 0..9 { + assert!(a.is_null(i)); + } + } + + #[test] + fn test_null_primitive() { + let array = new_null_array(&DataType::Int32, 9); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 9); + for i in 0..9 { + assert!(a.is_null(i)); + } + } + + #[test] + fn test_null_struct() { + // It is possible to create a null struct containing a non-nullable child + // see https://github.com/apache/arrow-rs/pull/3244 for details + let struct_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); + let array = new_null_array(&struct_type, 9); + + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 9); + assert_eq!(a.column(0).len(), 9); + for i in 0..9 { + assert!(a.is_null(i)); + } + + // Make sure we can slice the resulting array. 
+ a.slice(0, 5); + } + + #[test] + fn test_null_variable_sized() { + let array = new_null_array(&DataType::Utf8, 9); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 9); + assert_eq!(a.value_offsets()[9], 0i32); + for i in 0..9 { + assert!(a.is_null(i)); + } + } + + #[test] + fn test_null_list_primitive() { + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let array = new_null_array(&data_type, 9); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 9); + assert_eq!(a.value_offsets()[9], 0i32); + for i in 0..9 { + assert!(a.is_null(i)); + } + } + + #[test] + fn test_null_map() { + let data_type = DataType::Map( + Arc::new(Field::new( + "entry", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ); + let array = new_null_array(&data_type, 9); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 9); + assert_eq!(a.value_offsets()[9], 0i32); + for i in 0..9 { + assert!(a.is_null(i)); + } + } + + #[test] + fn test_null_dictionary() { + let values = + vec![None, None, None, None, None, None, None, None, None] as Vec>; + + let array: DictionaryArray = values.into_iter().collect(); + let array = Arc::new(array) as ArrayRef; + + let null_array = new_null_array(array.data_type(), 9); + assert_eq!(&array, &null_array); + assert_eq!( + array.to_data().buffers()[0].len(), + null_array.to_data().buffers()[0].len() + ); + } + + #[test] + fn test_null_union() { + for mode in [UnionMode::Sparse, UnionMode::Dense] { + let data_type = DataType::Union( + UnionFields::new( + vec![2, 1], + vec![ + Field::new("foo", DataType::Int32, true), + Field::new("bar", DataType::Int64, true), + ], + ), + mode, + ); + let array = new_null_array(&data_type, 4); + + let array = as_union_array(array.as_ref()); + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + + for i 
in 0..4 { + let a = array.value(i); + assert_eq!(a.len(), 1); + assert_eq!(a.null_count(), 1); + assert!(a.is_null(0)) + } + + array.to_data().validate_full().unwrap(); + } + } + + #[test] + #[allow(unused_parens)] + fn test_null_runs() { + for r in [DataType::Int16, DataType::Int32, DataType::Int64] { + let data_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", r, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + + let array = new_null_array(&data_type, 4); + let array = array.as_ref(); + + downcast_run_array! { + array => { + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + assert_eq!(array.values().len(), 1); + assert_eq!(array.values().null_count(), 1); + assert_eq!(array.run_ends().len(), 4); + assert_eq!(array.run_ends().values(), &[4]); + + let idx = array.get_physical_indices(&[0, 1, 2, 3]).unwrap(); + assert_eq!(idx, &[0,0,0,0]); + } + d => unreachable!("{d}") + } + } + } + + #[test] + fn test_null_fixed_size_binary() { + for size in [1, 2, 7] { + let array = new_null_array(&DataType::FixedSizeBinary(size), 6); + let array = array + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), 6); + assert_eq!(array.null_count(), 6); + array.iter().for_each(|x| assert!(x.is_none())); + } + } + + #[test] + fn test_memory_size_null() { + let null_arr = NullArray::new(32); + + assert_eq!(0, null_arr.get_buffer_memory_size()); + assert_eq!( + std::mem::size_of::(), + null_arr.get_array_memory_size() + ); + } + + #[test] + fn test_memory_size_primitive() { + let arr = PrimitiveArray::::from_iter_values(0..128); + let empty = PrimitiveArray::::from(ArrayData::new_empty(arr.data_type())); + + // subtract empty array to avoid magic numbers for the size of additional fields + assert_eq!( + arr.get_array_memory_size() - empty.get_array_memory_size(), + 128 * std::mem::size_of::() + ); + } + + #[test] + fn test_memory_size_primitive_sliced() { + let arr = 
PrimitiveArray::::from_iter_values(0..128); + let slice1 = arr.slice(0, 64); + let slice2 = arr.slice(64, 64); + + // both slices report the full buffer memory usage, even though the buffers are shared + assert_eq!(slice1.get_array_memory_size(), arr.get_array_memory_size()); + assert_eq!(slice2.get_array_memory_size(), arr.get_array_memory_size()); + } + + #[test] + fn test_memory_size_primitive_nullable() { + let arr: PrimitiveArray = (0..128) + .map(|i| if i % 20 == 0 { Some(i) } else { None }) + .collect(); + let empty_with_bitmap = PrimitiveArray::::from( + ArrayData::builder(arr.data_type().clone()) + .add_buffer(MutableBuffer::new(0).into()) + .null_bit_buffer(Some(MutableBuffer::new_null(0).into())) + .build() + .unwrap(), + ); + + // expected size is the size of the PrimitiveArray struct, + // which includes the optional validity buffer + // plus one buffer on the heap + assert_eq!( + std::mem::size_of::>(), + empty_with_bitmap.get_array_memory_size() + ); + + // subtract empty array to avoid magic numbers for the size of additional fields + // the size of the validity bitmap is rounded up to 64 bytes + assert_eq!( + arr.get_array_memory_size() - empty_with_bitmap.get_array_memory_size(), + 128 * std::mem::size_of::() + 64 + ); + } + + #[test] + fn test_memory_size_dictionary() { + let values = PrimitiveArray::::from_iter_values(0..16); + let keys = PrimitiveArray::::from_iter_values( + (0..256).map(|i| (i % values.len()) as i16), + ); + + let dict_data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + let dict_data = keys + .into_data() + .into_builder() + .data_type(dict_data_type) + .child_data(vec![values.into_data()]) + .build() + .unwrap(); + + let empty_data = ArrayData::new_empty(&DataType::Dictionary( + Box::new(DataType::Int16), + Box::new(DataType::Int64), + )); + + let arr = DictionaryArray::::from(dict_data); + let empty = DictionaryArray::::from(empty_data); + + let 
expected_keys_size = 256 * std::mem::size_of::(); + assert_eq!( + arr.keys().get_array_memory_size() - empty.keys().get_array_memory_size(), + expected_keys_size + ); + + let expected_values_size = 16 * std::mem::size_of::(); + assert_eq!( + arr.values().get_array_memory_size() - empty.values().get_array_memory_size(), + expected_values_size + ); + + let expected_size = expected_keys_size + expected_values_size; + assert_eq!( + arr.get_array_memory_size() - empty.get_array_memory_size(), + expected_size + ); + } + + /// Test function that takes an &dyn Array + fn compute_my_thing(arr: &dyn Array) -> bool { + !arr.is_empty() + } + + #[test] + fn test_array_ref_as_array() { + let arr: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + // works well! + assert!(compute_my_thing(&arr)); + + // Should also work when wrapped as an ArrayRef + let arr: ArrayRef = Arc::new(arr); + assert!(compute_my_thing(&arr)); + assert!(compute_my_thing(arr.as_ref())); + } + + #[test] + fn test_downcast_array() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + let boxed: ArrayRef = Arc::new(array); + let array: Int32Array = downcast_array(&boxed); + + let expected: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert_eq!(array, expected); + } +} diff --git a/arrow-array/src/array/null_array.rs b/arrow-array/src/array/null_array.rs new file mode 100644 index 000000000000..88cc2d911f82 --- /dev/null +++ b/arrow-array/src/array/null_array.rs @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains the `NullArray` type. + +use crate::builder::NullBuilder; +use crate::{Array, ArrayRef}; +use arrow_buffer::buffer::NullBuffer; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::DataType; +use std::any::Any; +use std::sync::Arc; + +/// An array of [null values](https://arrow.apache.org/docs/format/Columnar.html#null-layout) +/// +/// A `NullArray` is a simplified array where all values are null. +/// +/// # Example: Create an array +/// +/// ``` +/// use arrow_array::{Array, NullArray}; +/// +/// let array = NullArray::new(10); +/// +/// assert!(array.is_nullable()); +/// assert_eq!(array.len(), 10); +/// assert_eq!(array.null_count(), 0); +/// assert_eq!(array.logical_nulls().unwrap().null_count(), 10); +/// ``` +#[derive(Clone)] +pub struct NullArray { + len: usize, +} + +impl NullArray { + /// Create a new [`NullArray`] of the specified length + /// + /// *Note*: Use [`crate::array::new_null_array`] if you need an array of some + /// other [`DataType`]. + /// + pub fn new(length: usize) -> Self { + Self { len: length } + } + + /// Returns a zero-copy slice of this array with the indicated offset and length; panics if `offset + len` exceeds the length of this array. + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced NullArray cannot exceed the existing length" + ); + + Self { len } + } + + /// Returns a new null array builder + /// + /// Note that the `capacity` parameter to this function is _deprecated_. It + /// now does nothing, and will be removed in a future version. 
+ pub fn builder(_capacity: usize) -> NullBuilder { + NullBuilder::new() + } +} + +impl Array for NullArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &DataType::Null + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + None + } + + fn logical_nulls(&self) -> Option { + (self.len != 0).then(|| NullBuffer::new_null(self.len)) + } + + fn is_nullable(&self) -> bool { + !self.is_empty() + } + + fn get_buffer_memory_size(&self) -> usize { + 0 + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + } +} + +impl From for NullArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.data_type(), + &DataType::Null, + "NullArray data type should be Null" + ); + assert_eq!( + data.buffers().len(), + 0, + "NullArray data should contain 0 buffers" + ); + assert!( + data.nulls().is_none(), + "NullArray data should not contain a null buffer, as no buffers are required" + ); + Self { len: data.len() } + } +} + +impl From for ArrayData { + fn from(array: NullArray) -> Self { + let builder = ArrayDataBuilder::new(DataType::Null).len(array.len); + unsafe { builder.build_unchecked() } + } +} + +impl std::fmt::Debug for NullArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "NullArray({})", self.len()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_null_array() { + let null_arr = NullArray::new(32); + + assert_eq!(null_arr.len(), 32); + assert_eq!(null_arr.null_count(), 0); + assert_eq!(null_arr.logical_nulls().unwrap().null_count(), 32); + assert!(null_arr.is_valid(0)); + 
assert!(null_arr.is_nullable()); + } + + #[test] + fn test_null_array_slice() { + let array1 = NullArray::new(32); + + let array2 = array1.slice(8, 16); + assert_eq!(array2.len(), 16); + assert_eq!(array2.null_count(), 0); + assert_eq!(array2.logical_nulls().unwrap().null_count(), 16); + assert!(array2.is_valid(0)); + assert!(array2.is_nullable()); + } + + #[test] + fn test_debug_null_array() { + let array = NullArray::new(1024 * 1024); + assert_eq!(format!("{array:?}"), "NullArray(1048576)"); + } +} diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs new file mode 100644 index 000000000000..567fa00e7385 --- /dev/null +++ b/arrow-array/src/array/primitive_array.rs @@ -0,0 +1,2734 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::array::print_long_array; +use crate::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder}; +use crate::iterator::PrimitiveIter; +use crate::temporal_conversions::{ + as_date, as_datetime, as_datetime_with_timezone, as_duration, as_time, +}; +use crate::timezone::Tz; +use crate::trusted_len::trusted_len_unzip; +use crate::types::*; +use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; +use arrow_buffer::{i256, ArrowNativeType, Buffer, NullBuffer, ScalarBuffer}; +use arrow_data::bit_iterator::try_for_each_valid_idx; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType}; +use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime}; +use half::f16; +use std::any::Any; +use std::sync::Arc; + +/// A [`PrimitiveArray`] of `i8` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int8Array; +/// // Create from Vec> +/// let arr = Int8Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Int8Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int8Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int8Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `i16` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int16Array; +/// // Create from Vec> +/// let arr = Int16Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Int16Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int16Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int16Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `i32` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int32Array; +/// // Create from Vec> +/// let arr = Int32Array::from(vec![Some(1), None, Some(2)]); +/// // Create 
from Vec +/// let arr = Int32Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int32Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `i64` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int64Array; +/// // Create from Vec> +/// let arr = Int64Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Int64Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int64Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u8` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt8Array; +/// // Create from Vec> +/// let arr = UInt8Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = UInt8Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt8Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt8Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u16` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt16Array; +/// // Create from Vec> +/// let arr = UInt16Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = UInt16Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt16Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt16Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u32` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt32Array; +/// // Create from Vec> +/// let arr = UInt32Array::from(vec![Some(1), None, 
Some(2)]); +/// // Create from Vec +/// let arr = UInt32Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt32Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u64` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt64Array; +/// // Create from Vec> +/// let arr = UInt64Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = UInt64Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt64Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `f16` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Float16Array; +/// use half::f16; +/// // Create from Vec> +/// let arr = Float16Array::from(vec![Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))]); +/// // Create from Vec +/// let arr = Float16Array::from(vec![f16::from_f64(1.0), f16::from_f64(2.0), f16::from_f64(3.0)]); +/// // Create iter/collect +/// let arr: Float16Array = std::iter::repeat(f16::from_f64(1.0)).take(10).collect(); +/// ``` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::Float16Array; +/// use half::f16; +/// let arr : Float16Array = [Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))].into_iter().collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Float16Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `f32` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Float32Array; +/// // Create from Vec> +/// let arr = Float32Array::from(vec![Some(1.0), None, Some(2.0)]); +/// // Create from Vec +/// let arr = Float32Array::from(vec![1.0, 2.0, 3.0]); +/// // Create iter/collect +/// let 
arr: Float32Array = std::iter::repeat(42.0).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Float32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `f64` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Float64Array; +/// // Create from Vec> +/// let arr = Float64Array::from(vec![Some(1.0), None, Some(2.0)]); +/// // Create from Vec +/// let arr = Float64Array::from(vec![1.0, 2.0, 3.0]); +/// // Create iter/collect +/// let arr: Float64Array = std::iter::repeat(42.0).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Float64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of seconds since UNIX epoch stored as `i64` +/// +/// This type is similar to the [`chrono::DateTime`] type and can hold +/// values such as `1970-05-09 14:25:11 +01:00` +/// +/// See also [`Timestamp`](arrow_schema::DataType::Timestamp). +/// +/// # Example: UTC timestamps post epoch +/// ``` +/// # use arrow_array::TimestampSecondArray; +/// use arrow_array::timezone::Tz; +/// // Corresponds to single element array with entry 1970-05-09T14:25:11+0:00 +/// let arr = TimestampSecondArray::from(vec![11111111]); +/// // OR +/// let arr = TimestampSecondArray::from(vec![Some(11111111)]); +/// let utc_tz: Tz = "+00:00".parse().unwrap(); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, utc_tz).map(|v| v.to_string()).unwrap(), "1970-05-09 14:25:11 +00:00") +/// ``` +/// +/// # Example: UTC timestamps pre epoch +/// ``` +/// # use arrow_array::TimestampSecondArray; +/// use arrow_array::timezone::Tz; +/// // Corresponds to single element array with entry 1969-08-25T09:34:49+0:00 +/// let arr = TimestampSecondArray::from(vec![-11111111]); +/// // OR +/// let arr = TimestampSecondArray::from(vec![Some(-11111111)]); +/// let utc_tz: Tz = "+00:00".parse().unwrap(); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, utc_tz).map(|v| 
v.to_string()).unwrap(), "1969-08-25 09:34:49 +00:00") +/// ``` +/// +/// # Example: With timezone specified +/// ``` +/// # use arrow_array::TimestampSecondArray; +/// use arrow_array::timezone::Tz; +/// // Corresponds to single element array with entry 1970-05-10T00:25:11+10:00 +/// let arr = TimestampSecondArray::from(vec![11111111]).with_timezone("+10:00".to_string()); +/// // OR +/// let arr = TimestampSecondArray::from(vec![Some(11111111)]).with_timezone("+10:00".to_string()); +/// let sydney_tz: Tz = "+10:00".parse().unwrap(); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, sydney_tz).map(|v| v.to_string()).unwrap(), "1970-05-10 00:25:11 +10:00") +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type TimestampSecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of milliseconds since UNIX epoch stored as `i64` +/// +/// See examples for [`TimestampSecondArray`] +pub type TimestampMillisecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of microseconds since UNIX epoch stored as `i64` +/// +/// See examples for [`TimestampSecondArray`] +pub type TimestampMicrosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of nanoseconds since UNIX epoch stored as `i64` +/// +/// See examples for [`TimestampSecondArray`] +pub type TimestampNanosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of days since UNIX epoch stored as `i32` +/// +/// This type is similar to the [`chrono::NaiveDate`] type and can hold +/// values such as `2018-11-13` +pub type Date32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of milliseconds since UNIX epoch stored as `i64` +/// +/// This type is similar to the [`chrono::NaiveDate`] type and can hold +/// values such as `2018-11-13` +pub type Date64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of seconds since midnight stored as `i32` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00` +pub type Time32SecondArray = 
PrimitiveArray; + +/// A [`PrimitiveArray`] of milliseconds since midnight stored as `i32` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00.123` +pub type Time32MillisecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of microseconds since midnight stored as `i64` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00.123456` +pub type Time64MicrosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of nanoseconds since midnight stored as `i64` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00.123456789` +pub type Time64NanosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of “calendar” intervals in whole months +/// +/// See [`IntervalYearMonthType`] for details on representation and caveats. +/// +/// # Example +/// ``` +/// # use arrow_array::IntervalYearMonthArray; +/// let array = IntervalYearMonthArray::from(vec![ +/// 2, // 2 months +/// 25, // 2 years and 1 month +/// -1 // -1 months +/// ]); +/// ``` +pub type IntervalYearMonthArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of “calendar” intervals in days and milliseconds +/// +/// See [`IntervalDayTime`] for details on representation and caveats. +/// +/// # Example +/// ``` +/// # use arrow_array::IntervalDayTimeArray; +/// use arrow_array::types::IntervalDayTime; +/// let array = IntervalDayTimeArray::from(vec![ +/// IntervalDayTime::new(1, 1000), // 1 day, 1000 milliseconds +/// IntervalDayTime::new(33, 0), // 33 days, 0 milliseconds +/// IntervalDayTime::new(0, 12 * 60 * 60 * 1000), // 0 days, 12 hours +/// ]); +/// ``` +pub type IntervalDayTimeArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of “calendar” intervals in months, days, and nanoseconds. +/// +/// See [`IntervalMonthDayNano`] for details on representation and caveats. 
+/// +/// # Example +/// ``` +/// # use arrow_array::IntervalMonthDayNanoArray; +/// use arrow_array::types::IntervalMonthDayNano; +/// let array = IntervalMonthDayNanoArray::from(vec![ +/// IntervalMonthDayNano::new(1, 2, 1000), // 1 month, 2 days, 1 nanosecond +/// IntervalMonthDayNano::new(12, 1, 0), // 12 months, 1 days, 0 nanoseconds +/// IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), // 0 days, 12 milliseconds +/// ]); +/// ``` +pub type IntervalMonthDayNanoArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in seconds +pub type DurationSecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in milliseconds +pub type DurationMillisecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in microseconds +pub type DurationMicrosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in nanoseconds +pub type DurationNanosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of 128-bit fixed point decimals +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Decimal128Array; +/// // Create from Vec> +/// let arr = Decimal128Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Decimal128Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Decimal128Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Decimal128Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of 256-bit fixed point decimals +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Decimal256Array; +/// use arrow_buffer::i256; +/// // Create from Vec> +/// let arr = Decimal256Array::from(vec![Some(i256::from(1)), None, Some(i256::from(2))]); +/// // Create from Vec +/// let arr = Decimal256Array::from(vec![i256::from(1), i256::from(2), i256::from(3)]); +/// // Create iter/collect +/// let arr: Decimal256Array = 
std::iter::repeat(i256::from(42)).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Decimal256Array = PrimitiveArray; + +pub use crate::types::ArrowPrimitiveType; + +/// An array of primitive values, of type [`ArrowPrimitiveType`] +/// +/// # Example: From a Vec +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = vec![1, 2, 3, 4].into(); +/// assert_eq!(4, arr.len()); +/// assert_eq!(0, arr.null_count()); +/// assert_eq!(arr.values(), &[1, 2, 3, 4]) +/// ``` +/// +/// # Example: From an optional Vec +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = vec![Some(1), None, Some(3), None].into(); +/// assert_eq!(4, arr.len()); +/// assert_eq!(2, arr.null_count()); +/// // Note: values for null indexes are arbitrary +/// assert_eq!(arr.values(), &[1, 0, 3, 0]) +/// ``` +/// +/// # Example: From an iterator of values +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = (0..10).map(|x| x + 1).collect(); +/// assert_eq!(10, arr.len()); +/// assert_eq!(0, arr.null_count()); +/// for i in 0..10i32 { +/// assert_eq!(i + 1, arr.value(i as usize)); +/// } +/// ``` +/// +/// # Example: From an iterator of option +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = (0..10).map(|x| (x % 2 == 0).then_some(x)).collect(); +/// assert_eq!(10, arr.len()); +/// assert_eq!(5, arr.null_count()); +/// // Note: values for null indexes are arbitrary +/// assert_eq!(arr.values(), &[0, 0, 2, 0, 4, 0, 6, 0, 8, 0]) +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::Array; +/// # use arrow_array::builder::PrimitiveBuilder; +/// # use arrow_array::types::Int32Type; +/// let mut builder = PrimitiveBuilder::::new(); +/// builder.append_value(1); +/// builder.append_null(); +/// 
builder.append_value(2); +/// let array = builder.finish(); +/// // Note: values for null indexes are arbitrary +/// assert_eq!(array.values(), &[1, 0, 2]); +/// assert!(array.is_null(1)); +/// ``` +/// +/// # Example: Get a `PrimitiveArray` from an [`ArrayRef`] +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{Array, cast::AsArray, ArrayRef, Float32Array, PrimitiveArray}; +/// # use arrow_array::types::{Float32Type}; +/// # use arrow_schema::DataType; +/// # let array: ArrayRef = Arc::new(Float32Array::from(vec![1.2, 2.3])); +/// // will panic if the array is not a Float32Array +/// assert_eq!(&DataType::Float32, array.data_type()); +/// let f32_array: Float32Array = array.as_primitive().clone(); +/// assert_eq!(f32_array, Float32Array::from(vec![1.2, 2.3])); +/// ``` +pub struct PrimitiveArray { + data_type: DataType, + /// Values data + values: ScalarBuffer, + nulls: Option, +} + +impl Clone for PrimitiveArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + values: self.values.clone(), + nulls: self.nulls.clone(), + } + } +} + +impl PrimitiveArray { + /// Create a new [`PrimitiveArray`] from the provided values and nulls + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + /// + /// # Example + /// + /// Creating a [`PrimitiveArray`] directly from a [`ScalarBuffer`] and [`NullBuffer`] using + /// this constructor is the most performant approach, avoiding any additional allocations + /// + /// ``` + /// # use arrow_array::Int32Array; + /// # use arrow_array::types::Int32Type; + /// # use arrow_buffer::NullBuffer; + /// // [1, 2, 3, 4] + /// let array = Int32Array::new(vec![1, 2, 3, 4].into(), None); + /// // [1, null, 3, 4] + /// let nulls = NullBuffer::from(vec![true, false, true, true]); + /// let array = Int32Array::new(vec![1, 2, 3, 4].into(), Some(nulls)); + /// ``` + pub fn new(values: ScalarBuffer, nulls: Option) -> Self { + Self::try_new(values, nulls).unwrap() + } + + /// Create a new 
[`PrimitiveArray`] of the given length where all values are null + pub fn new_null(length: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + values: vec![T::Native::usize_as(0); length].into(), + nulls: Some(NullBuffer::new_null(length)), + } + } + + /// Create a new [`PrimitiveArray`] from the provided values and nulls + /// + /// # Errors + /// + /// Errors if: + /// - `values.len() != nulls.len()` + pub fn try_new( + values: ScalarBuffer, + nulls: Option, + ) -> Result { + if let Some(n) = nulls.as_ref() { + if n.len() != values.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for PrimitiveArray, expected {} got {}", + values.len(), + n.len(), + ))); + } + } + + Ok(Self { + data_type: T::DATA_TYPE, + values, + nulls, + }) + } + + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: T::Native) -> Scalar { + Scalar::new(Self { + data_type: T::DATA_TYPE, + values: vec![value].into(), + nulls: None, + }) + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (DataType, ScalarBuffer, Option) { + (self.data_type, self.values, self.nulls) + } + + /// Overrides the [`DataType`] of this [`PrimitiveArray`] + /// + /// Prefer using [`Self::with_timezone`] or [`Self::with_precision_and_scale`] where + /// the primitive type is suitably constrained, as these cannot panic + /// + /// # Panics + /// + /// Panics if ![Self::is_compatible] + pub fn with_data_type(self, data_type: DataType) -> Self { + Self::assert_compatible(&data_type); + Self { data_type, ..self } + } + + /// Asserts that `data_type` is compatible with `Self` + fn assert_compatible(data_type: &DataType) { + assert!( + Self::is_compatible(data_type), + "PrimitiveArray expected data type {} got {}", + T::DATA_TYPE, + data_type + ); + } + + /// Returns the length of this array. + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// Returns whether this array is empty. 
+ pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + /// Returns the values of this array + #[inline] + pub fn values(&self) -> &ScalarBuffer { + &self.values + } + + /// Returns a new primitive array builder + pub fn builder(capacity: usize) -> PrimitiveBuilder { + PrimitiveBuilder::::with_capacity(capacity) + } + + /// Returns if this [`PrimitiveArray`] is compatible with the provided [`DataType`] + /// + /// This is equivalent to `data_type == T::DATA_TYPE`, however ignores timestamp + /// timezones and decimal precision and scale + pub fn is_compatible(data_type: &DataType) -> bool { + match T::DATA_TYPE { + DataType::Timestamp(t1, _) => { + matches!(data_type, DataType::Timestamp(t2, _) if &t1 == t2) + } + DataType::Decimal128(_, _) => matches!(data_type, DataType::Decimal128(_, _)), + DataType::Decimal256(_, _) => matches!(data_type, DataType::Decimal256(_, _)), + _ => T::DATA_TYPE.eq(data_type), + } + } + + /// Returns the primitive value at index `i`. + /// + /// # Safety + /// + /// caller must ensure that the passed in offset is less than the array len() + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> T::Native { + *self.values.get_unchecked(i) + } + + /// Returns the primitive value at index `i`. 
+ /// # Panics + /// Panics if index `i` is out of bounds + #[inline] + pub fn value(&self, i: usize) -> T::Native { + assert!( + i < self.len(), + "Trying to access an element at index {} from a PrimitiveArray of length {}", + i, + self.len() + ); + unsafe { self.value_unchecked(i) } + } + + /// Creates a PrimitiveArray based on an iterator of values without nulls + pub fn from_iter_values>(iter: I) -> Self { + let val_buf: Buffer = iter.into_iter().collect(); + let len = val_buf.len() / std::mem::size_of::(); + Self { + data_type: T::DATA_TYPE, + values: ScalarBuffer::new(val_buf, 0, len), + nulls: None, + } + } + + /// Creates a PrimitiveArray based on an iterator of values with provided nulls + pub fn from_iter_values_with_nulls>( + iter: I, + nulls: Option, + ) -> Self { + let val_buf: Buffer = iter.into_iter().collect(); + let len = val_buf.len() / std::mem::size_of::(); + Self { + data_type: T::DATA_TYPE, + values: ScalarBuffer::new(val_buf, 0, len), + nulls, + } + } + + /// Creates a PrimitiveArray based on a constant value with `count` elements + pub fn from_value(value: T::Native, count: usize) -> Self { + unsafe { + let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value)); + Self::new(val_buf.into(), None) + } + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } + + /// Returns a zero-copy slice of this 
array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + values: self.values.slice(offset, length), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + } + } + + /// Reinterprets this array's contents as a different data type without copying + /// + /// This can be used to efficiently convert between primitive arrays with the + /// same underlying representation + /// + /// Note: this will not modify the underlying values, and therefore may change + /// the semantic values of the array, e.g. 100 milliseconds in a [`TimestampNanosecondArray`] + /// will become 100 seconds in a [`TimestampSecondArray`]. + /// + /// For casts that preserve the semantic value, check out the + /// [compute kernels](https://docs.rs/arrow/latest/arrow/compute/kernels/cast/index.html). + /// + /// ``` + /// # use arrow_array::{Int64Array, TimestampNanosecondArray}; + /// let a = Int64Array::from_iter_values([1, 2, 3, 4]); + /// let b: TimestampNanosecondArray = a.reinterpret_cast(); + /// ``` + pub fn reinterpret_cast(&self) -> PrimitiveArray + where + K: ArrowPrimitiveType, + { + let d = self.to_data().into_builder().data_type(K::DATA_TYPE); + + // SAFETY: + // Native type is the same + PrimitiveArray::from(unsafe { d.build_unchecked() }) + } + + /// Applies a unary infallible function to a primitive array, producing a + /// new array of potentially different type. + /// + /// This is the fastest way to perform an operation on a primitive array + /// when the benefits of a vectorized operation outweigh the cost of + /// branching nulls and non-nulls. + /// + /// See also + /// * [`Self::unary_mut`] for in place modification. + /// * [`Self::try_unary`] for fallible operations. 
+ /// * [`arrow::compute::binary`] for binary operations + /// + /// [`arrow::compute::binary`]: https://docs.rs/arrow/latest/arrow/compute/fn.binary.html + /// # Null Handling + /// + /// Applies the function for all values, including those on null slots. This + /// will often allow the compiler to generate faster vectorized code, but + /// requires that the operation must be infallible (not error/panic) for any + /// value of the corresponding type or this function may panic. + /// + /// # Example + /// ```rust + /// # use arrow_array::{Int32Array, Float32Array, types::Int32Type}; + /// # fn main() { + /// let array = Int32Array::from(vec![Some(5), Some(7), None]); + /// // Create a new array with the value of applying sqrt + /// let c = array.unary(|x| f32::sqrt(x as f32)); + /// assert_eq!(c, Float32Array::from(vec![Some(2.236068), Some(2.6457512), None])); + /// # } + /// ``` + pub fn unary(&self, op: F) -> PrimitiveArray + where + O: ArrowPrimitiveType, + F: Fn(T::Native) -> O::Native, + { + let nulls = self.nulls().cloned(); + let values = self.values().iter().map(|v| op(*v)); + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size because arrays are sized. + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + PrimitiveArray::new(buffer.into(), nulls) + } + + /// Applies a unary and infallible function to the array in place if possible. + /// + /// # Buffer Reuse + /// + /// If the underlying buffers are not shared with other arrays, mutates the + /// underlying buffer in place, without allocating. + /// + /// If the underlying buffer is shared, returns Err(self) + /// + /// # Null Handling + /// + /// See [`Self::unary`] for more information on null handling. 
+ /// + /// # Example + /// + /// ```rust + /// # use arrow_array::{Int32Array, types::Int32Type}; + /// let array = Int32Array::from(vec![Some(5), Some(7), None]); + /// // Apply x*2+1 to the data in place, no allocations + /// let c = array.unary_mut(|x| x * 2 + 1).unwrap(); + /// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); + /// ``` + /// + /// # Example: modify [`ArrayRef`] in place, if not shared + /// + /// It is also possible to modify an [`ArrayRef`] if there are no other + /// references to the underlying buffer. + /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow_array::{Array, cast::AsArray, ArrayRef, Int32Array, PrimitiveArray, types::Int32Type}; + /// # let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(5), Some(7), None])); + /// // Convert to Int32Array (panic's if array.data_type is not Int32) + /// let a = array.as_primitive::().clone(); + /// // Try to apply x*2+1 to the data in place, fails because array is still shared + /// a.unary_mut(|x| x * 2 + 1).unwrap_err(); + /// // Try again, this time dropping the last remaining reference + /// let a = array.as_primitive::().clone(); + /// drop(array); + /// // Now we can apply the operation in place + /// let c = a.unary_mut(|x| x * 2 + 1).unwrap(); + /// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); + /// ``` + + pub fn unary_mut(self, op: F) -> Result, PrimitiveArray> + where + F: Fn(T::Native) -> T::Native, + { + let mut builder = self.into_builder()?; + builder + .values_slice_mut() + .iter_mut() + .for_each(|v| *v = op(*v)); + Ok(builder.finish()) + } + + /// Applies a unary fallible function to all valid values in a primitive + /// array, producing a new array of potentially different type. + /// + /// Applies `op` to only rows that are valid, which is often significantly + /// slower than [`Self::unary`], which should be preferred if `op` is + /// fallible. 
+ /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn try_unary(&self, op: F) -> Result, E> + where + O: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, + { + let len = self.len(); + + let nulls = self.nulls().cloned(); + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + let f = |idx| { + unsafe { *slice.get_unchecked_mut(idx) = op(self.value_unchecked(idx))? }; + Ok::<_, E>(()) + }; + + match &nulls { + Some(nulls) => nulls.try_for_each_valid_idx(f)?, + None => (0..len).try_for_each(f)?, + } + + let values = buffer.finish().into(); + Ok(PrimitiveArray::new(values, nulls)) + } + + /// Applies a unary fallible function to all valid values in a mutable + /// primitive array. + /// + /// # Null Handling + /// + /// See [`Self::try_unary`] for more information on null handling. + /// + /// # Buffer Reuse + /// + /// See [`Self::unary_mut`] for more information on buffer reuse. + /// + /// This returns an `Err` when the input array is shared buffer with other + /// array. In the case, returned `Err` wraps input array. If the function + /// encounters an error during applying on values. In the case, this returns an `Err` within + /// an `Ok` which wraps the actual error. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn try_unary_mut( + self, + op: F, + ) -> Result, E>, PrimitiveArray> + where + F: Fn(T::Native) -> Result, + { + let len = self.len(); + let null_count = self.null_count(); + let mut builder = self.into_builder()?; + + let (slice, null_buffer) = builder.slices_mut(); + + let r = try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? 
}; + Ok::<_, E>(()) + }); + + if let Err(err) = r { + return Ok(Err(err)); + } + + Ok(Ok(builder.finish())) + } + + /// Applies a unary and nullable function to all valid values in a primitive array + /// + /// Applies `op` to only rows that are valid, which is often significantly + /// slower than [`Self::unary`], which should be preferred if `op` is + /// fallible. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn unary_opt(&self, op: F) -> PrimitiveArray + where + O: ArrowPrimitiveType, + F: Fn(T::Native) -> Option, + { + let len = self.len(); + let (nulls, null_count, offset) = match self.nulls() { + Some(n) => (Some(n.validity()), n.null_count(), n.offset()), + None => (None, 0, 0), + }; + + let mut null_builder = BooleanBufferBuilder::new(len); + match nulls { + Some(b) => null_builder.append_packed_range(offset..offset + len, b), + None => null_builder.append_n(len, true), + } + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + let mut out_null_count = null_count; + + let _ = try_for_each_valid_idx(len, offset, null_count, nulls, |idx| { + match op(unsafe { self.value_unchecked(idx) }) { + Some(v) => unsafe { *slice.get_unchecked_mut(idx) = v }, + None => { + out_null_count += 1; + null_builder.set_bit(idx, false); + } + } + Ok::<_, ()>(()) + }); + + let nulls = null_builder.finish(); + let values = buffer.finish().into(); + let nulls = unsafe { NullBuffer::new_unchecked(nulls, out_null_count) }; + PrimitiveArray::new(values, Some(nulls)) + } + + /// Applies a unary infallible function to each value in an array, producing a + /// new primitive array. + /// + /// # Null Handling + /// + /// See [`Self::unary`] for more information on null handling. 
+ /// + /// # Example: create an [`Int16Array`] from an [`ArrayAccessor`] with item type `&[u8]` + /// ``` + /// use arrow_array::{Array, FixedSizeBinaryArray, Int16Array}; + /// let input_arg = vec![ vec![1, 0], vec![2, 0], vec![3, 0] ]; + /// let arr = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); + /// let c = Int16Array::from_unary(&arr, |x| i16::from_le_bytes(x[..2].try_into().unwrap())); + /// assert_eq!(c, Int16Array::from(vec![Some(1i16), Some(2i16), Some(3i16)])); + /// ``` + pub fn from_unary(left: U, mut op: F) -> Self + where + F: FnMut(U::Item) -> T::Native, + { + let nulls = left.logical_nulls(); + let buffer = unsafe { + // SAFETY: i in range 0..left.len() + let iter = (0..left.len()).map(|i| op(left.value_unchecked(i))); + // SAFETY: upper bound is trusted because `iter` is over a range + Buffer::from_trusted_len_iter(iter) + }; + + PrimitiveArray::new(buffer.into(), nulls) + } + + /// Returns a `PrimitiveBuilder` for this array, suitable for mutating values + /// in place. + /// + /// # Buffer Reuse + /// + /// If the underlying data buffer has no other outstanding references, the + /// buffer is used without copying. 
+ /// + /// If the underlying data buffer does have outstanding references, returns + /// `Err(self)` + pub fn into_builder(self) -> Result, Self> { + let len = self.len(); + let data = self.into_data(); + let null_bit_buffer = data.nulls().map(|b| b.inner().sliced()); + + let element_len = std::mem::size_of::(); + let buffer = + data.buffers()[0].slice_with_length(data.offset() * element_len, len * element_len); + + drop(data); + + let try_mutable_null_buffer = match null_bit_buffer { + None => Ok(None), + Some(null_buffer) => { + // Null buffer exists, tries to make it mutable + null_buffer.into_mutable().map(Some) + } + }; + + let try_mutable_buffers = match try_mutable_null_buffer { + Ok(mutable_null_buffer) => { + // Got mutable null buffer, tries to get mutable value buffer + let try_mutable_buffer = buffer.into_mutable(); + + // try_mutable_buffer.map(...).map_err(...) doesn't work as the compiler complains + // mutable_null_buffer is moved into map closure. + match try_mutable_buffer { + Ok(mutable_buffer) => Ok(PrimitiveBuilder::::new_from_buffer( + mutable_buffer, + mutable_null_buffer, + )), + Err(buffer) => Err((buffer, mutable_null_buffer.map(|b| b.into()))), + } + } + Err(mutable_null_buffer) => { + // Unable to get mutable null buffer + Err((buffer, Some(mutable_null_buffer))) + } + }; + + match try_mutable_buffers { + Ok(builder) => Ok(builder), + Err((buffer, null_bit_buffer)) => { + let builder = ArrayData::builder(T::DATA_TYPE) + .len(len) + .add_buffer(buffer) + .null_bit_buffer(null_bit_buffer); + + let array_data = unsafe { builder.build_unchecked() }; + let array = PrimitiveArray::::from(array_data); + + Err(array) + } + } + } +} + +impl From> for ArrayData { + fn from(array: PrimitiveArray) -> Self { + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.values.len()) + .nulls(array.nulls) + .buffers(vec![array.values.into_inner()]); + + unsafe { builder.build_unchecked() } + } +} + +impl Array for PrimitiveArray { + fn 
as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn is_empty(&self) -> bool { + self.values.is_empty() + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.values.inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } +} + +impl<'a, T: ArrowPrimitiveType> ArrayAccessor for &'a PrimitiveArray { + type Item = T::Native; + + fn value(&self, index: usize) -> Self::Item { + PrimitiveArray::value(self, index) + } + + #[inline] + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + PrimitiveArray::value_unchecked(self, index) + } +} + +impl PrimitiveArray +where + i64: From, +{ + /// Returns value as a chrono `NaiveDateTime`, handling time resolution + /// + /// If a data type cannot be converted to `NaiveDateTime`, a `None` is returned. + /// A valid value is expected, thus the user should first check for validity. 
+ pub fn value_as_datetime(&self, i: usize) -> Option { + as_datetime::(i64::from(self.value(i))) + } + + /// Returns value as a chrono `NaiveDateTime`, handling time resolution with the provided tz + /// + /// Functionally the same as [`Self::value_as_datetime`], except that the passed `tz` + /// is applied to the returned value + pub fn value_as_datetime_with_tz(&self, i: usize, tz: Tz) -> Option> { + as_datetime_with_timezone::(i64::from(self.value(i)), tz) + } + + /// Returns value as a chrono `NaiveDate` by using [`Self::value_as_datetime`] + /// + /// If a data type cannot be converted to `NaiveDate`, a `None` is returned + pub fn value_as_date(&self, i: usize) -> Option { + self.value_as_datetime(i).map(|datetime| datetime.date()) + } + + /// Returns a value as a chrono `NaiveTime` + /// + /// `Date32` and `Date64` return UTC midnight as they do not have time resolution + pub fn value_as_time(&self, i: usize) -> Option { + as_time::(i64::from(self.value(i))) + } + + /// Returns a value as a chrono `Duration` + /// + /// If a data type cannot be converted to `Duration`, a `None` is returned + pub fn value_as_duration(&self, i: usize) -> Option { + as_duration::(i64::from(self.value(i))) + } +} + +impl std::fmt::Debug for PrimitiveArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let data_type = self.data_type(); + + write!(f, "PrimitiveArray<{data_type:?}>\n[\n")?; + print_long_array(self, f, |array, index, f| match data_type { + DataType::Date32 | DataType::Date64 => { + let v = self.value(index).to_i64().unwrap(); + match as_date::(v) { + Some(date) => write!(f, "{date:?}"), + None => { + write!( + f, + "Cast error: Failed to convert {v} to temporal for {data_type:?}" + ) + } + } + } + DataType::Time32(_) | DataType::Time64(_) => { + let v = self.value(index).to_i64().unwrap(); + match as_time::(v) { + Some(time) => write!(f, "{time:?}"), + None => { + write!( + f, + "Cast error: Failed to convert {v} to temporal for {data_type:?}" + ) + } 
+ } + } + DataType::Timestamp(_, tz_string_opt) => { + let v = self.value(index).to_i64().unwrap(); + match tz_string_opt { + // for Timestamp with TimeZone + Some(tz_string) => { + match tz_string.parse::() { + // if the time zone is valid, construct a DateTime and format it as rfc3339 + Ok(tz) => match as_datetime_with_timezone::(v, tz) { + Some(datetime) => write!(f, "{}", datetime.to_rfc3339()), + None => write!(f, "null"), + }, + // if the time zone is invalid, shows NaiveDateTime with an error message + Err(_) => match as_datetime::(v) { + Some(datetime) => { + write!(f, "{datetime:?} (Unknown Time Zone '{tz_string}')") + } + None => write!(f, "null"), + }, + } + } + // for Timestamp without TimeZone + None => match as_datetime::(v) { + Some(datetime) => write!(f, "{datetime:?}"), + None => write!(f, "null"), + }, + } + } + _ => std::fmt::Debug::fmt(&array.value(index), f), + })?; + write!(f, "]") + } +} + +impl<'a, T: ArrowPrimitiveType> IntoIterator for &'a PrimitiveArray { + type Item = Option<::Native>; + type IntoIter = PrimitiveIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + PrimitiveIter::<'a, T>::new(self) + } +} + +impl<'a, T: ArrowPrimitiveType> PrimitiveArray { + /// constructs a new iterator + pub fn iter(&'a self) -> PrimitiveIter<'a, T> { + PrimitiveIter::<'a, T>::new(self) + } +} + +/// An optional primitive value +/// +/// This struct is used as an adapter when creating `PrimitiveArray` from an iterator. +/// `FromIterator` for `PrimitiveArray` takes an iterator where the elements can be `into` +/// this struct. So once implementing `From` or `Into` trait for a type, an iterator of +/// the type can be collected to `PrimitiveArray`. +#[derive(Debug)] +pub struct NativeAdapter { + /// Corresponding Rust native type if available + pub native: Option, +} + +macro_rules! 
def_from_for_primitive { + ( $ty:ident, $tt:tt) => { + impl From<$tt> for NativeAdapter<$ty> { + fn from(value: $tt) -> Self { + NativeAdapter { + native: Some(value), + } + } + } + }; +} + +def_from_for_primitive!(Int8Type, i8); +def_from_for_primitive!(Int16Type, i16); +def_from_for_primitive!(Int32Type, i32); +def_from_for_primitive!(Int64Type, i64); +def_from_for_primitive!(UInt8Type, u8); +def_from_for_primitive!(UInt16Type, u16); +def_from_for_primitive!(UInt32Type, u32); +def_from_for_primitive!(UInt64Type, u64); +def_from_for_primitive!(Float16Type, f16); +def_from_for_primitive!(Float32Type, f32); +def_from_for_primitive!(Float64Type, f64); +def_from_for_primitive!(Decimal128Type, i128); +def_from_for_primitive!(Decimal256Type, i256); + +impl From::Native>> for NativeAdapter { + fn from(value: Option<::Native>) -> Self { + NativeAdapter { native: value } + } +} + +impl From<&Option<::Native>> for NativeAdapter { + fn from(value: &Option<::Native>) -> Self { + NativeAdapter { native: *value } + } +} + +impl>> FromIterator for PrimitiveArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut null_builder = BooleanBufferBuilder::new(lower); + + let buffer: Buffer = iter + .map(|item| { + if let Some(a) = item.into().native { + null_builder.append(true); + a + } else { + null_builder.append(false); + // this ensures that null items on the buffer are not arbitrary. + // This is important because fallible operations can use null values (e.g. a vectorized "add") + // which may panic (e.g. overflow if the number on the slots happen to be very large). 
+ T::Native::default() + } + }) + .collect(); + + let len = null_builder.len(); + + let data = unsafe { + ArrayData::new_unchecked( + T::DATA_TYPE, + len, + None, + Some(null_builder.into()), + 0, + vec![buffer], + vec![], + ) + }; + PrimitiveArray::from(data) + } +} + +impl PrimitiveArray { + /// Creates a [`PrimitiveArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter(iter: I) -> Self + where + P: std::borrow::Borrow::Native>>, + I: IntoIterator, + { + let iterator = iter.into_iter(); + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let (null, buffer) = trusted_len_unzip(iterator); + + let data = + ArrayData::new_unchecked(T::DATA_TYPE, len, None, Some(null), 0, vec![buffer], vec![]); + PrimitiveArray::from(data) + } +} + +// TODO: the macro is needed here because we'd get "conflicting implementations" error +// otherwise with both `From>` and `From>>`. +// We should revisit this in future. +macro_rules! def_numeric_from_vec { + ( $ty:ident ) => { + impl From::Native>> for PrimitiveArray<$ty> { + fn from(data: Vec<<$ty as ArrowPrimitiveType>::Native>) -> Self { + let array_data = ArrayData::builder($ty::DATA_TYPE) + .len(data.len()) + .add_buffer(Buffer::from_vec(data)); + let array_data = unsafe { array_data.build_unchecked() }; + PrimitiveArray::from(array_data) + } + } + + // Constructs a primitive array from a vector. Should only be used for testing. 
+ impl From::Native>>> for PrimitiveArray<$ty> { + fn from(data: Vec::Native>>) -> Self { + PrimitiveArray::from_iter(data.iter()) + } + } + }; +} + +def_numeric_from_vec!(Int8Type); +def_numeric_from_vec!(Int16Type); +def_numeric_from_vec!(Int32Type); +def_numeric_from_vec!(Int64Type); +def_numeric_from_vec!(UInt8Type); +def_numeric_from_vec!(UInt16Type); +def_numeric_from_vec!(UInt32Type); +def_numeric_from_vec!(UInt64Type); +def_numeric_from_vec!(Float16Type); +def_numeric_from_vec!(Float32Type); +def_numeric_from_vec!(Float64Type); +def_numeric_from_vec!(Decimal128Type); +def_numeric_from_vec!(Decimal256Type); + +def_numeric_from_vec!(Date32Type); +def_numeric_from_vec!(Date64Type); +def_numeric_from_vec!(Time32SecondType); +def_numeric_from_vec!(Time32MillisecondType); +def_numeric_from_vec!(Time64MicrosecondType); +def_numeric_from_vec!(Time64NanosecondType); +def_numeric_from_vec!(IntervalYearMonthType); +def_numeric_from_vec!(IntervalDayTimeType); +def_numeric_from_vec!(IntervalMonthDayNanoType); +def_numeric_from_vec!(DurationSecondType); +def_numeric_from_vec!(DurationMillisecondType); +def_numeric_from_vec!(DurationMicrosecondType); +def_numeric_from_vec!(DurationNanosecondType); +def_numeric_from_vec!(TimestampSecondType); +def_numeric_from_vec!(TimestampMillisecondType); +def_numeric_from_vec!(TimestampMicrosecondType); +def_numeric_from_vec!(TimestampNanosecondType); + +impl PrimitiveArray { + /// Construct a timestamp array from a vec of i64 values and an optional timezone + #[deprecated(note = "Use with_timezone_opt instead")] + pub fn from_vec(data: Vec, timezone: Option) -> Self + where + Self: From>, + { + Self::from(data).with_timezone_opt(timezone) + } + + /// Construct a timestamp array from a vec of `Option` values and an optional timezone + #[deprecated(note = "Use with_timezone_opt instead")] + pub fn from_opt_vec(data: Vec>, timezone: Option) -> Self + where + Self: From>>, + { + Self::from(data).with_timezone_opt(timezone) + } + + /// 
Returns the timezone of this array if any + pub fn timezone(&self) -> Option<&str> { + match self.data_type() { + DataType::Timestamp(_, tz) => tz.as_deref(), + _ => unreachable!(), + } + } + + /// Construct a timestamp array with new timezone + pub fn with_timezone(self, timezone: impl Into>) -> Self { + self.with_timezone_opt(Some(timezone.into())) + } + + /// Construct a timestamp array with UTC + pub fn with_timezone_utc(self) -> Self { + self.with_timezone("+00:00") + } + + /// Construct a timestamp array with an optional timezone + pub fn with_timezone_opt>>(self, timezone: Option) -> Self { + Self { + data_type: DataType::Timestamp(T::UNIT, timezone.map(Into::into)), + ..self + } + } +} + +/// Constructs a `PrimitiveArray` from an owned `ArrayData`; panics unless it holds exactly one buffer. +impl From for PrimitiveArray { + fn from(data: ArrayData) -> Self { + Self::assert_compatible(data.data_type()); + assert_eq!( + data.buffers().len(), + 1, + "PrimitiveArray data should contain a single buffer only (values buffer)" + ); + + let values = ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); + Self { + data_type: data.data_type().clone(), + values, + nulls: data.nulls().cloned(), + } + } +} + +impl PrimitiveArray { + /// Returns a Decimal array with the same data as self, with the + /// specified precision and scale. + /// + /// See [`validate_decimal_precision_and_scale`] + pub fn with_precision_and_scale(self, precision: u8, scale: i8) -> Result { + validate_decimal_precision_and_scale::(precision, scale)?; + Ok(Self { + data_type: T::TYPE_CONSTRUCTOR(precision, scale), + ..self + }) + } + + /// Validates values in this array can be properly interpreted + /// with the specified precision. 
+ pub fn validate_decimal_precision(&self, precision: u8) -> Result<(), ArrowError> { + (0..self.len()).try_for_each(|idx| { + if self.is_valid(idx) { + let decimal = unsafe { self.value_unchecked(idx) }; + T::validate_decimal_precision(decimal, precision) + } else { + Ok(()) + } + }) + } + + /// Returns a new array in which every value that does not fit the specified precision + /// is replaced with null + pub fn null_if_overflow_precision(&self, precision: u8) -> Self { + self.unary_opt::<_, T>(|v| T::is_valid_decimal_precision(v, precision).then_some(v)) + } + + /// Returns [`Self::value`] formatted as a string + pub fn value_as_string(&self, row: usize) -> String { + T::format_decimal(self.value(row), self.precision(), self.scale()) + } + + /// Returns the decimal precision of this array + pub fn precision(&self) -> u8 { + match T::BYTE_LENGTH { + 16 => { + if let DataType::Decimal128(p, _) = self.data_type() { + *p + } else { + unreachable!( + "Decimal128Array datatype is not DataType::Decimal128 but {}", + self.data_type() + ) + } + } + 32 => { + if let DataType::Decimal256(p, _) = self.data_type() { + *p + } else { + unreachable!( + "Decimal256Array datatype is not DataType::Decimal256 but {}", + self.data_type() + ) + } + } + other => unreachable!("Unsupported byte length for decimal array {}", other), + } + } + + /// Returns the decimal scale of this array + pub fn scale(&self) -> i8 { + match T::BYTE_LENGTH { + 16 => { + if let DataType::Decimal128(_, s) = self.data_type() { + *s + } else { + unreachable!( + "Decimal128Array datatype is not DataType::Decimal128 but {}", + self.data_type() + ) + } + } + 32 => { + if let DataType::Decimal256(_, s) = self.data_type() { + *s + } else { + unreachable!( + "Decimal256Array datatype is not DataType::Decimal256 but {}", + self.data_type() + ) + } + } + other => unreachable!("Unsupported byte length for decimal array {}", other), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use 
crate::builder::{Decimal128Builder, Decimal256Builder}; + use crate::cast::downcast_array; + use crate::BooleanArray; + use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; + use arrow_schema::TimeUnit; + + #[test] + fn test_primitive_array_from_vec() { + let buf = Buffer::from_slice_ref([0, 1, 2, 3, 4]); + let arr = Int32Array::from(vec![0, 1, 2, 3, 4]); + assert_eq!(&buf, arr.values.inner()); + assert_eq!(5, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + for i in 0..5 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i as i32, arr.value(i)); + } + } + + #[test] + fn test_primitive_array_from_vec_option() { + // Test building a primitive array with null values + let arr = Int32Array::from(vec![Some(0), None, Some(2), None, Some(4)]); + assert_eq!(5, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(2, arr.null_count()); + for i in 0..5 { + if i % 2 == 0 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i as i32, arr.value(i)); + } else { + assert!(arr.is_null(i)); + assert!(!arr.is_valid(i)); + } + } + } + + #[test] + fn test_date64_array_from_vec_option() { + // Test building a primitive array with null values + // we use Int32 and Int64 as a backing array, so all Int32 and Int64 conventions + // work + let arr: PrimitiveArray = + vec![Some(1550902545147), None, Some(1550902545147)].into(); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + for i in 0..3 { + if i % 2 == 0 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(1550902545147, arr.value(i)); + // roundtrip to and from datetime + assert_eq!( + 1550902545147, + arr.value_as_datetime(i) + .unwrap() + .and_utc() + .timestamp_millis() + ); + } else { + assert!(arr.is_null(i)); + assert!(!arr.is_valid(i)); + } + } + } + + #[test] + fn test_time32_millisecond_array_from_vec() { + // 1: 00:00:00.001 + // 37800005: 10:30:00.005 + // 86399210: 23:59:59.210 + let 
arr: PrimitiveArray = vec![1, 37_800_005, 86_399_210].into(); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + let formatted = ["00:00:00.001", "10:30:00.005", "23:59:59.210"]; + for (i, formatted) in formatted.iter().enumerate().take(3) { + // check that we can't create dates or datetimes from time instances + assert_eq!(None, arr.value_as_datetime(i)); + assert_eq!(None, arr.value_as_date(i)); + let time = arr.value_as_time(i).unwrap(); + assert_eq!(*formatted, time.format("%H:%M:%S%.3f").to_string()); + } + } + + #[test] + fn test_time64_nanosecond_array_from_vec() { + // Test building a primitive array with null values + // we use Int32 and Int64 as a backing array, so all Int32 and Int64 conventions + // work + + // 1e6: 00:00:00.001 + // 37800005e6: 10:30:00.005 + // 86399210e6: 23:59:59.210 + let arr: PrimitiveArray = + vec![1_000_000, 37_800_005_000_000, 86_399_210_000_000].into(); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + let formatted = ["00:00:00.001", "10:30:00.005", "23:59:59.210"]; + for (i, item) in formatted.iter().enumerate().take(3) { + // check that we can't create dates or datetimes from time instances + assert_eq!(None, arr.value_as_datetime(i)); + assert_eq!(None, arr.value_as_date(i)); + let time = arr.value_as_time(i).unwrap(); + assert_eq!(*item, time.format("%H:%M:%S%.3f").to_string()); + } + } + + #[test] + fn test_interval_array_from_vec() { + // intervals are currently not treated specially, but are Int32 and Int64 arrays + let arr = IntervalYearMonthArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + let v0 = IntervalDayTime { + days: 34, + milliseconds: 1, + }; + let v2 = IntervalDayTime { + 
days: -2, + milliseconds: -5, + }; + + let arr = IntervalDayTimeArray::from(vec![Some(v0), None, Some(v2)]); + + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(v0, arr.value(0)); + assert_eq!(v0, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(v2, arr.value(2)); + assert_eq!(v2, arr.values()[2]); + + let v0 = IntervalMonthDayNano { + months: 2, + days: 34, + nanoseconds: -1, + }; + let v2 = IntervalMonthDayNano { + months: -3, + days: -2, + nanoseconds: 4, + }; + + let arr = IntervalMonthDayNanoArray::from(vec![Some(v0), None, Some(v2)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(v0, arr.value(0)); + assert_eq!(v0, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(v2, arr.value(2)); + assert_eq!(v2, arr.values()[2]); + } + + #[test] + fn test_duration_array_from_vec() { + let arr = DurationSecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + let arr = DurationMillisecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + let arr = DurationMicrosecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + let arr = DurationNanosecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, 
arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + } + + #[test] + fn test_timestamp_array_from_vec() { + let arr = TimestampSecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + + let arr = TimestampMillisecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + + let arr = TimestampMicrosecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + + let arr = TimestampNanosecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + } + + #[test] + fn test_primitive_array_slice() { + let arr = Int32Array::from(vec![ + Some(0), + None, + Some(2), + None, + Some(4), + Some(5), + Some(6), + None, + None, + ]); + assert_eq!(9, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(4, arr.null_count()); + + let arr2 = arr.slice(2, 5); + assert_eq!(5, arr2.len()); + assert_eq!(1, arr2.null_count()); + + for i in 0..arr2.len() { + assert_eq!(i == 1, arr2.is_null(i)); + assert_eq!(i != 1, arr2.is_valid(i)); + } + let int_arr2 = arr2.as_any().downcast_ref::().unwrap(); + assert_eq!(2, int_arr2.values()[0]); + assert_eq!(&[4, 5, 6], &int_arr2.values()[2..5]); + + let arr3 = arr2.slice(2, 3); + assert_eq!(3, arr3.len()); + 
assert_eq!(0, arr3.null_count()); + + let int_arr3 = arr3.as_any().downcast_ref::().unwrap(); + assert_eq!(&[4, 5, 6], int_arr3.values()); + assert_eq!(4, int_arr3.value(0)); + assert_eq!(5, int_arr3.value(1)); + assert_eq!(6, int_arr3.value(2)); + } + + #[test] + fn test_boolean_array_slice() { + let arr = BooleanArray::from(vec![ + Some(true), + None, + Some(false), + None, + Some(true), + Some(false), + Some(true), + Some(false), + None, + Some(true), + ]); + + assert_eq!(10, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(3, arr.null_count()); + + let arr2 = arr.slice(3, 5); + assert_eq!(5, arr2.len()); + assert_eq!(3, arr2.offset()); + assert_eq!(1, arr2.null_count()); + + let bool_arr = arr2.as_any().downcast_ref::().unwrap(); + + assert!(!bool_arr.is_valid(0)); + + assert!(bool_arr.is_valid(1)); + assert!(bool_arr.value(1)); + + assert!(bool_arr.is_valid(2)); + assert!(!bool_arr.value(2)); + + assert!(bool_arr.is_valid(3)); + assert!(bool_arr.value(3)); + + assert!(bool_arr.is_valid(4)); + assert!(!bool_arr.value(4)); + } + + #[test] + fn test_int32_fmt_debug() { + let arr = Int32Array::from(vec![0, 1, 2, 3, 4]); + assert_eq!( + "PrimitiveArray\n[\n 0,\n 1,\n 2,\n 3,\n 4,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_fmt_debug_up_to_20_elements() { + (1..=20).for_each(|i| { + let values = (0..i).collect::>(); + let array_expected = format!( + "PrimitiveArray\n[\n{}\n]", + values + .iter() + .map(|v| { format!(" {v},") }) + .collect::>() + .join("\n") + ); + let array = Int16Array::from(values); + + assert_eq!(array_expected, format!("{array:?}")); + }) + } + + #[test] + fn test_int32_with_null_fmt_debug() { + let mut builder = Int32Array::builder(3); + builder.append_slice(&[0, 1]); + builder.append_null(); + builder.append_slice(&[3, 4]); + let arr = builder.finish(); + assert_eq!( + "PrimitiveArray\n[\n 0,\n 1,\n null,\n 3,\n 4,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_fmt_debug() { + let arr: PrimitiveArray = + 
TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00,\n 2018-12-31T00:00:00,\n 1921-01-02T00:00:00,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_utc_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone_utc(); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00+00:00,\n 2018-12-31T00:00:00+00:00,\n 1921-01-02T00:00:00+00:00,\n]", + format!("{arr:?}") + ); + } + + #[test] + #[cfg(feature = "chrono-tz")] + fn test_timestamp_with_named_tz_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("Asia/Taipei".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", + format!("{:?}", arr) + ); + } + + #[test] + #[cfg(not(feature = "chrono-tz"))] + fn test_timestamp_with_named_tz_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("Asia/Taipei".to_string()); + + println!("{arr:?}"); + + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_with_fixed_offset_tz_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("+08:00".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_with_incorrect_tz_fmt_debug() { + let arr: PrimitiveArray = + 
TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("xxx".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'xxx'),\n]", + format!("{arr:?}") + ); + } + + #[test] + #[cfg(feature = "chrono-tz")] + fn test_timestamp_with_tz_with_daylight_saving_fmt_debug() { + let arr: PrimitiveArray = TimestampMillisecondArray::from(vec![ + 1647161999000, + 1647162000000, + 1667717999000, + 1667718000000, + ]) + .with_timezone("America/Denver".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2022-03-13T01:59:59-07:00,\n 2022-03-13T03:00:00-06:00,\n 2022-11-06T00:59:59-06:00,\n 2022-11-06T01:00:00-06:00,\n]", + format!("{:?}", arr) + ); + } + + #[test] + fn test_date32_fmt_debug() { + let arr: PrimitiveArray = vec![12356, 13548, -365].into(); + assert_eq!( + "PrimitiveArray\n[\n 2003-10-31,\n 2007-02-04,\n 1969-01-01,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_time32second_fmt_debug() { + let arr: PrimitiveArray = vec![7201, 60054].into(); + assert_eq!( + "PrimitiveArray\n[\n 02:00:01,\n 16:40:54,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_time32second_invalid_neg() { + // chrono::NaiveDatetime::from_timestamp_opt returns None while input is invalid + let arr: PrimitiveArray = vec![-7201, -60054].into(); + assert_eq!( + "PrimitiveArray\n[\n Cast error: Failed to convert -7201 to temporal for Time32(Second),\n Cast error: Failed to convert -60054 to temporal for Time32(Second),\n]", + // "PrimitiveArray\n[\n null,\n null,\n]", + format!("{arr:?}") + ) + } + + #[test] + fn test_timestamp_micros_out_of_range() { + // replicate the issue from https://github.com/apache/arrow-datafusion/issues/3832 + let arr: PrimitiveArray = vec![9065525203050843594].into(); + assert_eq!( + "PrimitiveArray\n[\n null,\n]", + format!("{arr:?}") + ) + } + + #[test] + fn 
test_primitive_array_builder() { + // Test building a primitive array with ArrayData builder and offset + let buf = Buffer::from_slice_ref([0i32, 1, 2, 3, 4, 5, 6]); + let buf2 = buf.slice_with_length(8, 20); + let data = ArrayData::builder(DataType::Int32) + .len(5) + .offset(2) + .add_buffer(buf) + .build() + .unwrap(); + let arr = Int32Array::from(data); + assert_eq!(&buf2, arr.values.inner()); + assert_eq!(5, arr.len()); + assert_eq!(0, arr.null_count()); + for i in 0..3 { + assert_eq!((i + 2) as i32, arr.value(i)); + } + } + + #[test] + fn test_primitive_from_iter_values() { + // Test building a primitive array with from_iter_values + let arr: PrimitiveArray = PrimitiveArray::from_iter_values(0..10); + assert_eq!(10, arr.len()); + assert_eq!(0, arr.null_count()); + for i in 0..10i32 { + assert_eq!(i, arr.value(i as usize)); + } + } + + #[test] + fn test_primitive_array_from_unbound_iter() { + // iterator that doesn't declare (upper) size bound + let value_iter = (0..) + .scan(0usize, |pos, i| { + if *pos < 10 { + *pos += 1; + Some(Some(i)) + } else { + // actually returns up to 10 values + None + } + }) + // limited using take() + .take(100); + + let (_, upper_size_bound) = value_iter.size_hint(); + // the upper bound, defined by take above, is 100 + assert_eq!(upper_size_bound, Some(100)); + let primitive_array: PrimitiveArray = value_iter.collect(); + // but the actual number of items in the array should be 10 + assert_eq!(primitive_array.len(), 10); + } + + #[test] + fn test_primitive_array_from_non_null_iter() { + let iter = (0..10_i32).map(Some); + let primitive_array = PrimitiveArray::::from_iter(iter); + assert_eq!(primitive_array.len(), 10); + assert_eq!(primitive_array.null_count(), 0); + assert!(primitive_array.nulls().is_none()); + assert_eq!(primitive_array.values(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + } + + #[test] + #[should_panic(expected = "PrimitiveArray data should contain a single buffer only \ + (values buffer)")] + // Different error 
messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_primitive_array_invalid_buffer_len() { + let buffer = Buffer::from_slice_ref([0i32, 1, 2, 3, 4]); + let data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(buffer.clone()) + .add_buffer(buffer) + .len(5) + .build_unchecked() + }; + + drop(Int32Array::from(data)); + } + + #[test] + fn test_access_array_concurrently() { + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); + let ret = std::thread::spawn(move || a.value(3)).join(); + + assert!(ret.is_ok()); + assert_eq!(8, ret.ok().unwrap()); + } + + #[test] + fn test_primitive_array_creation() { + let array1: Int8Array = [10_i8, 11, 12, 13, 14].into_iter().collect(); + let array2: Int8Array = [10_i8, 11, 12, 13, 14].into_iter().map(Some).collect(); + + assert_eq!(array1, array2); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3" + )] + fn test_string_array_get_value_index_out_of_bound() { + let array: Int8Array = [10_i8, 11, 12].into_iter().collect(); + + array.value(4); + } + + #[test] + #[should_panic(expected = "PrimitiveArray expected data type Int64 got Int32")] + fn test_from_array_data_validation() { + let foo = PrimitiveArray::::from_iter([1, 2, 3]); + let _ = PrimitiveArray::::from(foo.into_data()); + } + + #[test] + fn test_decimal128() { + let values: Vec<_> = vec![0, 1, -1, i128::MIN, i128::MAX]; + let array: PrimitiveArray = + PrimitiveArray::from_iter(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array: PrimitiveArray = + PrimitiveArray::from_iter_values(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(values.clone()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(array.to_data()); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_decimal256() { + let values: Vec<_> = 
vec![i256::ZERO, i256::ONE, i256::MINUS_ONE, i256::MIN, i256::MAX]; + + let array: PrimitiveArray = + PrimitiveArray::from_iter(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array: PrimitiveArray = + PrimitiveArray::from_iter_values(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(values.clone()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(array.to_data()); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_decimal_array() { + // let val_8887: [u8; 16] = [192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + // let val_neg_8887: [u8; 16] = [64, 36, 75, 238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255]; + let values: [u8; 32] = [ + 192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 36, 75, 238, 253, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]; + let array_data = ArrayData::builder(DataType::Decimal128(38, 6)) + .len(2) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + let decimal_array = Decimal128Array::from(array_data); + assert_eq!(8_887_000_000_i128, decimal_array.value(0)); + assert_eq!(-8_887_000_000_i128, decimal_array.value(1)); + } + + #[test] + fn test_decimal_append_error_value() { + let mut decimal_builder = Decimal128Builder::with_capacity(10); + decimal_builder.append_value(123456); + decimal_builder.append_value(12345); + let result = decimal_builder.finish().with_precision_and_scale(5, 3); + assert!(result.is_ok()); + let arr = result.unwrap(); + assert_eq!("12.345", arr.value_as_string(1)); + + // Validate it explicitly + let result = arr.validate_decimal_precision(5); + let error = result.unwrap_err(); + assert_eq!( + "Invalid argument error: 123456 is too large to store in a Decimal128 of precision 5. 
Max is 99999", + error.to_string() + ); + + decimal_builder = Decimal128Builder::new(); + decimal_builder.append_value(100); + decimal_builder.append_value(99); + decimal_builder.append_value(-100); + decimal_builder.append_value(-99); + let result = decimal_builder.finish().with_precision_and_scale(2, 1); + assert!(result.is_ok()); + let arr = result.unwrap(); + assert_eq!("9.9", arr.value_as_string(1)); + assert_eq!("-9.9", arr.value_as_string(3)); + + // Validate it explicitly + let result = arr.validate_decimal_precision(2); + let error = result.unwrap_err(); + assert_eq!( + "Invalid argument error: 100 is too large to store in a Decimal128 of precision 2. Max is 99", + error.to_string() + ); + } + + #[test] + fn test_decimal_from_iter_values() { + let array = Decimal128Array::from_iter_values(vec![-100, 0, 101]); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); + assert_eq!(-100_i128, array.value(0)); + assert!(!array.is_null(0)); + assert_eq!(0_i128, array.value(1)); + assert!(!array.is_null(1)); + assert_eq!(101_i128, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_decimal_from_iter() { + let array: Decimal128Array = vec![Some(-100), None, Some(101)].into_iter().collect(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); + assert_eq!(-100_i128, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(101_i128, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_decimal_iter_sized() { + let data = vec![Some(-100), None, Some(101)]; + let array: Decimal128Array = data.into_iter().collect(); + let mut iter = array.into_iter(); + + // is exact sized + assert_eq!(array.len(), 3); + + // size_hint is reported correctly + assert_eq!(iter.size_hint(), (3, Some(3))); + iter.next().unwrap(); + assert_eq!(iter.size_hint(), (2, Some(2))); + iter.next().unwrap(); + iter.next().unwrap(); + 
assert_eq!(iter.size_hint(), (0, Some(0))); + assert!(iter.next().is_none()); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + + #[test] + fn test_decimal_array_value_as_string() { + let arr = [123450, -123450, 100, -100, 10, -10, 0] + .into_iter() + .map(Some) + .collect::() + .with_precision_and_scale(6, 3) + .unwrap(); + + assert_eq!("123.450", arr.value_as_string(0)); + assert_eq!("-123.450", arr.value_as_string(1)); + assert_eq!("0.100", arr.value_as_string(2)); + assert_eq!("-0.100", arr.value_as_string(3)); + assert_eq!("0.010", arr.value_as_string(4)); + assert_eq!("-0.010", arr.value_as_string(5)); + assert_eq!("0.000", arr.value_as_string(6)); + } + + #[test] + fn test_decimal_array_with_precision_and_scale() { + let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) + .with_precision_and_scale(20, 2) + .unwrap(); + + assert_eq!(arr.data_type(), &DataType::Decimal128(20, 2)); + assert_eq!(arr.precision(), 20); + assert_eq!(arr.scale(), 2); + + let actual: Vec<_> = (0..arr.len()).map(|i| arr.value_as_string(i)).collect(); + let expected = vec!["123.45", "4.56", "78.90", "-1232234234324.32"]; + + assert_eq!(actual, expected); + } + + #[test] + #[should_panic( + expected = "-123223423432432 is too small to store in a Decimal128 of precision 5. 
Min is -99999" + )] + fn test_decimal_array_with_precision_and_scale_out_of_range() { + let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) + // precision is too small to hold value + .with_precision_and_scale(5, 2) + .unwrap(); + arr.validate_decimal_precision(5).unwrap(); + } + + #[test] + #[should_panic(expected = "precision cannot be 0, has to be between [1, 38]")] + fn test_decimal_array_with_precision_zero() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(0, 2) + .unwrap(); + } + + #[test] + #[should_panic(expected = "precision 40 is greater than max 38")] + fn test_decimal_array_with_precision_and_scale_invalid_precision() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(40, 2) + .unwrap(); + } + + #[test] + #[should_panic(expected = "scale 40 is greater than max 38")] + fn test_decimal_array_with_precision_and_scale_invalid_scale() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(20, 40) + .unwrap(); + } + + #[test] + #[should_panic(expected = "scale 10 is greater than precision 4")] + fn test_decimal_array_with_precision_and_scale_invalid_precision_and_scale() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(4, 10) + .unwrap(); + } + + #[test] + fn test_decimal_array_set_null_if_overflow_with_precision() { + let array = Decimal128Array::from(vec![Some(123456), Some(123), None, Some(123456)]); + let result = array.null_if_overflow_precision(5); + let expected = Decimal128Array::from(vec![None, Some(123), None, None]); + assert_eq!(result, expected); + } + + #[test] + fn test_decimal256_iter() { + let mut builder = Decimal256Builder::with_capacity(30); + let decimal1 = i256::from_i128(12345); + builder.append_value(decimal1); + + builder.append_null(); + + let decimal2 = i256::from_i128(56789); + builder.append_value(decimal2); + + let array: Decimal256Array = builder.finish().with_precision_and_scale(76, 
6).unwrap(); + + let collected: Vec<_> = array.iter().collect(); + assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected); + } + + #[test] + fn test_from_iter_decimal256array() { + let value1 = i256::from_i128(12345); + let value2 = i256::from_i128(56789); + + let mut array: Decimal256Array = + vec![Some(value1), None, Some(value2)].into_iter().collect(); + array = array.with_precision_and_scale(76, 10).unwrap(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal256(76, 10)); + assert_eq!(value1, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(value2, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_from_iter_decimal128array() { + let mut array: Decimal128Array = vec![Some(-100), None, Some(101)].into_iter().collect(); + array = array.with_precision_and_scale(38, 10).unwrap(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); + assert_eq!(-100_i128, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(101_i128, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_unary_opt() { + let array = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7]); + let r = array.unary_opt::<_, Int32Type>(|x| (x % 2 != 0).then_some(x)); + + let expected = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5), None, Some(7)]); + assert_eq!(r, expected); + + let r = expected.unary_opt::<_, Int32Type>(|x| (x % 3 != 0).then_some(x)); + let expected = Int32Array::from(vec![Some(1), None, None, None, Some(5), None, Some(7)]); + assert_eq!(r, expected); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3" + )] + fn test_fixed_size_binary_array_get_value_index_out_of_bound() { + let array = Decimal128Array::from(vec![-100, 0, 101]); + array.value(4); + } + + #[test] + fn test_into_builder() { + let array: Int32Array = vec![1, 
2, 3].into_iter().map(Some).collect(); + + let boxed: ArrayRef = Arc::new(array); + let col: Int32Array = downcast_array(&boxed); + drop(boxed); + + let mut builder = col.into_builder().unwrap(); + + let slice = builder.values_slice_mut(); + assert_eq!(slice, &[1, 2, 3]); + + slice[0] = 4; + slice[1] = 2; + slice[2] = 1; + + let expected: Int32Array = vec![Some(4), Some(2), Some(1)].into_iter().collect(); + + let new_array = builder.finish(); + assert_eq!(expected, new_array); + } + + #[test] + fn test_into_builder_cloned_array() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + let boxed: ArrayRef = Arc::new(array); + + let col: Int32Array = PrimitiveArray::::from(boxed.to_data()); + let err = col.into_builder(); + + match err { + Ok(_) => panic!("Should not get builder from cloned array"), + Err(returned) => { + let expected: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert_eq!(expected, returned) + } + } + } + + #[test] + fn test_into_builder_on_sliced_array() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + let slice = array.slice(1, 2); + let col: Int32Array = downcast_array(&slice); + + drop(slice); + + col.into_builder() + .expect_err("Should not build builder from sliced array"); + } + + #[test] + fn test_unary_mut() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + let c = array.unary_mut(|x| x * 2 + 1).unwrap(); + let expected: Int32Array = vec![3, 5, 7].into_iter().map(Some).collect(); + + assert_eq!(expected, c); + + let array: Int32Array = Int32Array::from(vec![Some(5), Some(7), None]); + let c = array.unary_mut(|x| x * 2 + 1).unwrap(); + assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); + } + + #[test] + #[should_panic( + expected = "PrimitiveArray expected data type Interval(MonthDayNano) got Interval(DayTime)" + )] + fn test_invalid_interval_type() { + let array = IntervalDayTimeArray::from(vec![IntervalDayTime::ZERO]); + let _ = 
IntervalMonthDayNanoArray::from(array.into_data()); + } + + #[test] + fn test_timezone() { + let array = TimestampNanosecondArray::from_iter_values([1, 2]); + assert_eq!(array.timezone(), None); + + let array = array.with_timezone("+02:00"); + assert_eq!(array.timezone(), Some("+02:00")); + } + + #[test] + fn test_try_new() { + Int32Array::new(vec![1, 2, 3, 4].into(), None); + Int32Array::new(vec![1, 2, 3, 4].into(), Some(NullBuffer::new_null(4))); + + let err = Int32Array::try_new(vec![1, 2, 3, 4].into(), Some(NullBuffer::new_null(3))) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for PrimitiveArray, expected 4 got 3" + ); + + TimestampNanosecondArray::new(vec![1, 2, 3, 4].into(), None).with_data_type( + DataType::Timestamp(TimeUnit::Nanosecond, Some("03:00".into())), + ); + } + + #[test] + #[should_panic(expected = "PrimitiveArray expected data type Int32 got Date32")] + fn test_with_data_type() { + Int32Array::new(vec![1, 2, 3, 4].into(), None).with_data_type(DataType::Date32); + } + + #[test] + fn test_time_32second_output() { + let array: Time32SecondArray = vec![ + Some(-1), + Some(0), + Some(86_399), + Some(86_400), + Some(86_401), + None, + ] + .into(); + let debug_str = format!("{:?}", array); + assert_eq!("PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time32(Second),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400 to temporal for Time32(Second),\n Cast error: Failed to convert 86401 to temporal for Time32(Second),\n null,\n]", + debug_str + ); + } + + #[test] + fn test_time_32millisecond_debug_output() { + let array: Time32MillisecondArray = vec![ + Some(-1), + Some(0), + Some(86_399_000), + Some(86_400_000), + Some(86_401_000), + None, + ] + .into(); + let debug_str = format!("{:?}", array); + assert_eq!("PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time32(Millisecond),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 
86400000 to temporal for Time32(Millisecond),\n Cast error: Failed to convert 86401000 to temporal for Time32(Millisecond),\n null,\n]", + debug_str + ); + } + + #[test] + fn test_time_64nanosecond_debug_output() { + let array: Time64NanosecondArray = vec![ + Some(-1), + Some(0), + Some(86_399 * 1_000_000_000), + Some(86_400 * 1_000_000_000), + Some(86_401 * 1_000_000_000), + None, + ] + .into(); + let debug_str = format!("{:?}", array); + assert_eq!( + "PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time64(Nanosecond),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000000000 to temporal for Time64(Nanosecond),\n Cast error: Failed to convert 86401000000000 to temporal for Time64(Nanosecond),\n null,\n]", + debug_str + ); + } + + #[test] + fn test_time_64microsecond_debug_output() { + let array: Time64MicrosecondArray = vec![ + Some(-1), + Some(0), + Some(86_399 * 1_000_000), + Some(86_400 * 1_000_000), + Some(86_401 * 1_000_000), + None, + ] + .into(); + let debug_str = format!("{:?}", array); + assert_eq!("PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time64(Microsecond),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000000 to temporal for Time64(Microsecond),\n Cast error: Failed to convert 86401000000 to temporal for Time64(Microsecond),\n null,\n]", debug_str); + } + + #[test] + fn test_primitive_with_nulls_into_builder() { + let array: Int32Array = vec![ + Some(1), + None, + Some(3), + Some(4), + None, + Some(7), + None, + Some(8), + ] + .into_iter() + .collect(); + let _ = array.into_builder(); + } +} diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs new file mode 100644 index 000000000000..aa8bb259a0eb --- /dev/null +++ b/arrow-array/src/array/run_array.rs @@ -0,0 +1,1093 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, RunEndBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, Field}; + +use crate::{ + builder::StringRunBuilder, + make_array, + run_iterator::RunArrayIter, + types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, + Array, ArrayAccessor, ArrayRef, PrimitiveArray, +}; + +/// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout) +/// +/// This encoding is variation on [run-length encoding (RLE)](https://en.wikipedia.org/wiki/Run-length_encoding) +/// and is good for representing data containing same values repeated consecutively. +/// +/// [`RunArray`] contains `run_ends` array and `values` array of same length. +/// The `run_ends` array stores the indexes at which the run ends. The `values` array +/// stores the value of each run. 
Below example illustrates how a logical array is represented in +/// [`RunArray`] +/// +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐ +/// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ +/// │ │ A │ │ 2 │ │ │ A │ +/// ├─────────────────┤ ├─────────┤ ├─────────────────┤ +/// │ │ D │ │ 3 │ │ │ A │ run length of 'A' = runs_ends[0] - 0 = 2 +/// ├─────────────────┤ ├─────────┤ ├─────────────────┤ +/// │ │ B │ │ 6 │ │ │ D │ run length of 'D' = run_ends[1] - run_ends[0] = 1 +/// └─────────────────┘ └─────────┘ ├─────────────────┤ +/// │ values run_ends │ │ B │ +/// ├─────────────────┤ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ │ B │ +/// ├─────────────────┤ +/// RunArray │ B │ run length of 'B' = run_ends[2] - run_ends[1] = 3 +/// length = 3 └─────────────────┘ +/// +/// Logical array +/// Contents +/// ``` + +pub struct RunArray { + data_type: DataType, + run_ends: RunEndBuffer, + values: ArrayRef, +} + +impl Clone for RunArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + run_ends: self.run_ends.clone(), + values: self.values.clone(), + } + } +} + +impl RunArray { + /// Calculates the logical length of the array encoded + /// by the given run_ends array. + pub fn logical_len(run_ends: &PrimitiveArray) -> usize { + let len = run_ends.len(); + if len == 0 { + return 0; + } + run_ends.value(len - 1).as_usize() + } + + /// Attempts to create RunArray using given run_ends (index where a run ends) + /// and the values (value of the run). Returns an error if the given data is not compatible + /// with RunEndEncoded specification. 
+ pub fn try_new(run_ends: &PrimitiveArray, values: &dyn Array) -> Result { + let run_ends_type = run_ends.data_type().clone(); + let values_type = values.data_type().clone(); + let ree_array_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", run_ends_type, false)), + Arc::new(Field::new("values", values_type, true)), + ); + let len = RunArray::logical_len(run_ends); + let builder = ArrayDataBuilder::new(ree_array_type) + .len(len) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + + // `build_unchecked` is used to avoid recursive validation of child arrays. + let array_data = unsafe { builder.build_unchecked() }; + + // Safety: `validate_data` checks below + // 1. The given array data has exactly two child arrays. + // 2. The first child array (run_ends) has valid data type. + // 3. run_ends array does not have null values + // 4. run_ends array has non-zero and strictly increasing values. + // 5. The length of run_ends array and values array are the same. + array_data.validate_data()?; + + Ok(array_data.into()) + } + + /// Returns a reference to [`RunEndBuffer`] + pub fn run_ends(&self) -> &RunEndBuffer { + &self.run_ends + } + + /// Returns a reference to values array + /// + /// Note: any slicing of this [`RunArray`] array is not applied to the returned array + /// and must be handled separately + pub fn values(&self) -> &ArrayRef { + &self.values + } + + /// Returns the physical index at which the array slice starts. + pub fn get_start_physical_index(&self) -> usize { + self.run_ends.get_start_physical_index() + } + + /// Returns the physical index at which the array slice ends. 
+ pub fn get_end_physical_index(&self) -> usize { + self.run_ends.get_end_physical_index() + } + + /// Downcast this [`RunArray`] to a [`TypedRunArray`] + /// + /// ``` + /// use arrow_array::{Array, ArrayAccessor, RunArray, StringArray, types::Int32Type}; + /// + /// let orig = [Some("a"), Some("b"), None]; + /// let run_array = RunArray::::from_iter(orig); + /// let typed = run_array.downcast::().unwrap(); + /// assert_eq!(typed.value(0), "a"); + /// assert_eq!(typed.value(1), "b"); + /// assert!(typed.values().is_null(2)); + /// ``` + /// + pub fn downcast(&self) -> Option> { + let values = self.values.as_any().downcast_ref()?; + Some(TypedRunArray { + run_array: self, + values, + }) + } + + /// Returns index to the physical array for the given index to the logical array. + /// This function adjusts the input logical index based on `ArrayData::offset` + /// Performs a binary search on the run_ends array for the input index. + /// + /// The result is arbitrary if `logical_index >= self.len()` + pub fn get_physical_index(&self, logical_index: usize) -> usize { + self.run_ends.get_physical_index(logical_index) + } + + /// Returns the physical indices of the input logical indices. Returns error if any of the logical + /// index cannot be converted to physical index. The logical indices are sorted and iterated along + /// with run_ends array to find matching physical index. The approach used here was chosen over + /// finding physical index for each logical index using binary search using the function + /// `get_physical_index`. Running benchmarks on both approaches showed that the approach used here + /// scaled well for larger inputs. + /// See for more details. 
+ #[inline] + pub fn get_physical_indices(&self, logical_indices: &[I]) -> Result, ArrowError> + where + I: ArrowNativeType, + { + let len = self.run_ends().len(); + let offset = self.run_ends().offset(); + + let indices_len = logical_indices.len(); + + if indices_len == 0 { + return Ok(vec![]); + } + + // `ordered_indices` store index into `logical_indices` and can be used + // to iterate `logical_indices` in sorted order. + let mut ordered_indices: Vec = (0..indices_len).collect(); + + // Instead of sorting `logical_indices` directly, sort the `ordered_indices` + // whose values are index of `logical_indices` + ordered_indices.sort_unstable_by(|lhs, rhs| { + logical_indices[*lhs] + .partial_cmp(&logical_indices[*rhs]) + .unwrap() + }); + + // Return early if all the logical indices cannot be converted to physical indices. + let largest_logical_index = logical_indices[*ordered_indices.last().unwrap()].as_usize(); + if largest_logical_index >= len { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {largest_logical_index}.", + ))); + } + + // Skip some physical indices based on offset. + let skip_value = self.get_start_physical_index(); + + let mut physical_indices = vec![0; indices_len]; + + let mut ordered_index = 0_usize; + for (physical_index, run_end) in self.run_ends.values().iter().enumerate().skip(skip_value) + { + // Get the run end index (relative to offset) of current physical index + let run_end_value = run_end.as_usize() - offset; + + // All the `logical_indices` that are less than current run end index + // belongs to current physical index. 
+ while ordered_index < indices_len + && logical_indices[ordered_indices[ordered_index]].as_usize() < run_end_value + { + physical_indices[ordered_indices[ordered_index]] = physical_index; + ordered_index += 1; + } + } + + // If there are input values >= run_ends.last_value then we'll not be able to convert + // all logical indices to physical indices. + if ordered_index < logical_indices.len() { + let logical_index = logical_indices[ordered_indices[ordered_index]].as_usize(); + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {logical_index}.", + ))); + } + Ok(physical_indices) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + run_ends: self.run_ends.slice(offset, length), + values: self.values.clone(), + } + } +} + +impl From for RunArray { + // The method assumes the caller already validated the data using `ArrayData::validate_data()` + fn from(data: ArrayData) -> Self { + match data.data_type() { + DataType::RunEndEncoded(_, _) => {} + _ => { + panic!("Invalid data type for RunArray. 
The data type should be DataType::RunEndEncoded"); + } + } + + // Safety + // ArrayData is valid + let child = &data.child_data()[0]; + assert_eq!(child.data_type(), &R::DATA_TYPE, "Incorrect run ends type"); + let run_ends = unsafe { + let scalar = child.buffers()[0].clone().into(); + RunEndBuffer::new_unchecked(scalar, data.offset(), data.len()) + }; + + let values = make_array(data.child_data()[1].clone()); + Self { + data_type: data.data_type().clone(), + run_ends, + values, + } + } +} + +impl From> for ArrayData { + fn from(array: RunArray) -> Self { + let len = array.run_ends.len(); + let offset = array.run_ends.offset(); + + let run_ends = ArrayDataBuilder::new(R::DATA_TYPE) + .len(array.run_ends.values().len()) + .buffers(vec![array.run_ends.into_inner().into_inner()]); + + let run_ends = unsafe { run_ends.build_unchecked() }; + + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .offset(offset) + .child_data(vec![run_ends, array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl Array for RunArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.run_ends.len() + } + + fn is_empty(&self) -> bool { + self.run_ends.is_empty() + } + + fn offset(&self) -> usize { + self.run_ends.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + None + } + + fn logical_nulls(&self) -> Option { + let len = self.len(); + let nulls = self.values.logical_nulls()?; + let mut out = BooleanBufferBuilder::new(len); + let offset = self.run_ends.offset(); + let mut valid_start = 0; + let mut last_end = 0; + for (idx, end) in self.run_ends.values().iter().enumerate() { + let end = end.as_usize(); + if end < offset { + 
continue; + } + let end = (end - offset).min(len); + if nulls.is_null(idx) { + if valid_start < last_end { + out.append_n(last_end - valid_start, true); + } + out.append_n(end - last_end, false); + valid_start = end; + } + last_end = end; + if end == len { + break; + } + } + if valid_start < len { + out.append_n(len - valid_start, true) + } + // Sanity check + assert_eq!(out.len(), len); + Some(out.finish().into()) + } + + fn is_nullable(&self) -> bool { + !self.is_empty() && self.values.is_nullable() + } + + fn get_buffer_memory_size(&self) -> usize { + self.run_ends.inner().inner().capacity() + self.values.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + + self.run_ends.inner().inner().capacity() + + self.values.get_array_memory_size() + } +} + +impl std::fmt::Debug for RunArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!( + f, + "RunArray {{run_ends: {:?}, values: {:?}}}", + self.run_ends.values(), + self.values + ) + } +} + +/// Constructs a `RunArray` from an iterator of optional strings. +/// +/// # Example: +/// ``` +/// use arrow_array::{RunArray, PrimitiveArray, StringArray, types::Int16Type}; +/// +/// let test = vec!["a", "a", "b", "c", "c"]; +/// let array: RunArray = test +/// .iter() +/// .map(|&x| if x == "b" { None } else { Some(x) }) +/// .collect(); +/// assert_eq!( +/// "RunArray {run_ends: [2, 3, 5], values: StringArray\n[\n \"a\",\n null,\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: RunEndIndexType> FromIterator> for RunArray { + fn from_iter>>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringRunBuilder::with_capacity(lower, 256); + it.for_each(|i| { + builder.append_option(i); + }); + + builder.finish() + } +} + +/// Constructs a `RunArray` from an iterator of strings. 
+/// +/// # Example: +/// +/// ``` +/// use arrow_array::{RunArray, PrimitiveArray, StringArray, types::Int16Type}; +/// +/// let test = vec!["a", "a", "b", "c"]; +/// let array: RunArray = test.into_iter().collect(); +/// assert_eq!( +/// "RunArray {run_ends: [2, 3, 4], values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: RunEndIndexType> FromIterator<&'a str> for RunArray { + fn from_iter>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringRunBuilder::with_capacity(lower, 256); + it.for_each(|i| { + builder.append_value(i); + }); + + builder.finish() + } +} + +/// +/// A [`RunArray`] with `i16` run ends +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int16RunArray, Int16Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int16RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.run_ends().values(), &[2, 3, 5]); +/// assert_eq!(array.values(), &values); +/// ``` +pub type Int16RunArray = RunArray; + +/// +/// A [`RunArray`] with `i32` run ends +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int32RunArray, Int32Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int32RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.run_ends().values(), &[2, 3, 5]); +/// assert_eq!(array.values(), &values); +/// ``` +pub type Int32RunArray = RunArray; + +/// +/// A [`RunArray`] with `i64` run ends +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int64RunArray, Int64Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int64RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); +/// let values: Arc = 
Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.run_ends().values(), &[2, 3, 5]); +/// assert_eq!(array.values(), &values); +/// ``` +pub type Int64RunArray = RunArray; + +/// A [`RunArray`] typed typed on its child values array +/// +/// Implements [`ArrayAccessor`] and [`IntoIterator`] allowing fast access to its elements +/// +/// ``` +/// use arrow_array::{RunArray, StringArray, types::Int32Type}; +/// +/// let orig = ["a", "b", "a", "b"]; +/// let ree_array = RunArray::::from_iter(orig); +/// +/// // `TypedRunArray` allows you to access the values directly +/// let typed = ree_array.downcast::().unwrap(); +/// +/// for (maybe_val, orig) in typed.into_iter().zip(orig) { +/// assert_eq!(maybe_val.unwrap(), orig) +/// } +/// ``` +pub struct TypedRunArray<'a, R: RunEndIndexType, V> { + /// The run array + run_array: &'a RunArray, + + /// The values of the run_array + values: &'a V, +} + +// Manually implement `Clone` to avoid `V: Clone` type constraint +impl<'a, R: RunEndIndexType, V> Clone for TypedRunArray<'a, R, V> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, R: RunEndIndexType, V> Copy for TypedRunArray<'a, R, V> {} + +impl<'a, R: RunEndIndexType, V> std::fmt::Debug for TypedRunArray<'a, R, V> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f, "TypedRunArray({:?})", self.run_array) + } +} + +impl<'a, R: RunEndIndexType, V> TypedRunArray<'a, R, V> { + /// Returns the run_ends of this [`TypedRunArray`] + pub fn run_ends(&self) -> &'a RunEndBuffer { + self.run_array.run_ends() + } + + /// Returns the values of this [`TypedRunArray`] + pub fn values(&self) -> &'a V { + self.values + } + + /// Returns the run array of this [`TypedRunArray`] + pub fn run_array(&self) -> &'a RunArray { + self.run_array + } +} + +impl<'a, R: RunEndIndexType, V: Sync> Array for TypedRunArray<'a, R, V> { + fn as_any(&self) -> &dyn Any { + self.run_array + } + + fn to_data(&self) -> ArrayData { + self.run_array.to_data() 
+ } + + fn into_data(self) -> ArrayData { + self.run_array.into_data() + } + + fn data_type(&self) -> &DataType { + self.run_array.data_type() + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.run_array.slice(offset, length)) + } + + fn len(&self) -> usize { + self.run_array.len() + } + + fn is_empty(&self) -> bool { + self.run_array.is_empty() + } + + fn offset(&self) -> usize { + self.run_array.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.run_array.nulls() + } + + fn logical_nulls(&self) -> Option { + self.run_array.logical_nulls() + } + + fn is_nullable(&self) -> bool { + self.run_array.is_nullable() + } + + fn get_buffer_memory_size(&self) -> usize { + self.run_array.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.run_array.get_array_memory_size() + } +} + +// Array accessor converts the index of logical array to the index of the physical array +// using binary search. The time complexity is O(log N) where N is number of runs. 
+impl<'a, R, V> ArrayAccessor for TypedRunArray<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = <&'a V as ArrayAccessor>::Item; + + fn value(&self, logical_index: usize) -> Self::Item { + assert!( + logical_index < self.len(), + "Trying to access an element at index {} from a TypedRunArray of length {}", + logical_index, + self.len() + ); + unsafe { self.value_unchecked(logical_index) } + } + + unsafe fn value_unchecked(&self, logical_index: usize) -> Self::Item { + let physical_index = self.run_array.get_physical_index(logical_index); + self.values().value_unchecked(physical_index) + } +} + +impl<'a, R, V> IntoIterator for TypedRunArray<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = Option<<&'a V as ArrayAccessor>::Item>; + type IntoIter = RunArrayIter<'a, R, V>; + + fn into_iter(self) -> Self::IntoIter { + RunArrayIter::new(self) + } +} + +#[cfg(test)] +mod tests { + use rand::seq::SliceRandom; + use rand::thread_rng; + use rand::Rng; + + use super::*; + use crate::builder::PrimitiveRunBuilder; + use crate::cast::AsArray; + use crate::types::{Int8Type, UInt32Type}; + use crate::{Int32Array, StringArray}; + + fn build_input_array(size: usize) -> Vec> { + // The input array is created by shuffling and repeating + // the seed values random number of times. + let mut seed: Vec> = vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ]; + let mut result: Vec> = Vec::with_capacity(size); + let mut ix = 0; + let mut rng = thread_rng(); + // run length can go up to 8. Cap the max run length for smaller arrays to size / 2. + let max_run_length = 8_usize.min(1_usize.max(size / 2)); + while result.len() < size { + // shuffle the seed array if all the values are iterated. 
+ if ix == 0 { + seed.shuffle(&mut rng); + } + // repeat the items between 1 and 8 times. Cap the length for smaller sized arrays + let num = max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); + for _ in 0..num { + result.push(seed[ix]); + } + ix += 1; + if ix == seed.len() { + ix = 0 + } + } + result.resize(size, None); + result + } + + // Asserts that `logical_array[logical_indices[*]] == physical_array[physical_indices[*]]` + fn compare_logical_and_physical_indices( + logical_indices: &[u32], + logical_array: &[Option], + physical_indices: &[usize], + physical_array: &PrimitiveArray, + ) { + assert_eq!(logical_indices.len(), physical_indices.len()); + + // check value in logical index in the logical_array matches physical index in physical_array + logical_indices + .iter() + .map(|f| f.as_usize()) + .zip(physical_indices.iter()) + .for_each(|(logical_ix, physical_ix)| { + let expected = logical_array[logical_ix]; + match expected { + Some(val) => { + assert!(physical_array.is_valid(*physical_ix)); + let actual = physical_array.value(*physical_ix); + assert_eq!(val, actual); + } + None => { + assert!(physical_array.is_null(*physical_ix)) + } + }; + }); + } + #[test] + fn test_run_array() { + // Construct a value array + let value_data = + PrimitiveArray::::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + + // Construct a run_ends array: + let run_ends_values = [4_i16, 6, 7, 9, 13, 18, 20, 22]; + let run_ends_data = + PrimitiveArray::::from_iter_values(run_ends_values.iter().copied()); + + // Construct a run ends encoded array from the above two + let ree_array = RunArray::::try_new(&run_ends_data, &value_data).unwrap(); + + assert_eq!(ree_array.len(), 22); + assert_eq!(ree_array.null_count(), 0); + + let values = ree_array.values(); + assert_eq!(value_data.into_data(), values.to_data()); + assert_eq!(&DataType::Int8, values.data_type()); + + let run_ends = ree_array.run_ends(); + assert_eq!(run_ends.values(), &run_ends_values); + } + + 
#[test] + fn test_run_array_fmt_debug() { + let mut builder = PrimitiveRunBuilder::::with_capacity(3); + builder.append_value(12345678); + builder.append_null(); + builder.append_value(22345678); + let array = builder.finish(); + assert_eq!( + "RunArray {run_ends: [1, 2, 3], values: PrimitiveArray\n[\n 12345678,\n null,\n 22345678,\n]}\n", + format!("{array:?}") + ); + + let mut builder = PrimitiveRunBuilder::::with_capacity(20); + for _ in 0..20 { + builder.append_value(1); + } + let array = builder.finish(); + + assert_eq!(array.len(), 20); + assert_eq!(array.null_count(), 0); + + assert_eq!( + "RunArray {run_ends: [20], values: PrimitiveArray\n[\n 1,\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_run_array_from_iter() { + let test = vec!["a", "a", "b", "c"]; + let array: RunArray = test + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + assert_eq!( + "RunArray {run_ends: [2, 3, 4], values: StringArray\n[\n \"a\",\n null,\n \"c\",\n]}\n", + format!("{array:?}") + ); + + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + + let array: RunArray = test.into_iter().collect(); + assert_eq!( + "RunArray {run_ends: [2, 3, 4], values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_run_array_run_ends_as_primitive_array() { + let test = vec!["a", "b", "c", "a"]; + let array: RunArray = test.into_iter().collect(); + + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + + let run_ends = array.run_ends(); + assert_eq!(&[1, 2, 3, 4], run_ends.values()); + } + + #[test] + fn test_run_array_as_primitive_array_with_null() { + let test = vec![Some("a"), None, Some("b"), None, None, Some("a")]; + let array: RunArray = test.into_iter().collect(); + + assert_eq!(array.len(), 6); + assert_eq!(array.null_count(), 0); + + let run_ends = array.run_ends(); + assert_eq!(&[1, 2, 3, 5, 6], run_ends.values()); + + let values_data = array.values(); + assert_eq!(2, 
values_data.null_count()); + assert_eq!(5, values_data.len()); + } + + #[test] + fn test_run_array_all_nulls() { + let test = vec![None, None, None]; + let array: RunArray = test.into_iter().collect(); + + assert_eq!(array.len(), 3); + assert_eq!(array.null_count(), 0); + + let run_ends = array.run_ends(); + assert_eq!(3, run_ends.len()); + assert_eq!(&[3], run_ends.values()); + + let values_data = array.values(); + assert_eq!(1, values_data.null_count()); + } + + #[test] + fn test_run_array_try_new() { + let values: StringArray = [Some("foo"), Some("bar"), None, Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), Some(2), Some(3), Some(4)].into_iter().collect(); + + let array = RunArray::::try_new(&run_ends, &values).unwrap(); + assert_eq!(array.values().data_type(), &DataType::Utf8); + + assert_eq!(array.null_count(), 0); + assert_eq!(array.len(), 4); + assert_eq!(array.values().null_count(), 1); + + assert_eq!( + "RunArray {run_ends: [1, 2, 3, 4], values: StringArray\n[\n \"foo\",\n \"bar\",\n null,\n \"baz\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_run_array_int16_type_definition() { + let array: Int16RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); + let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); + assert_eq!(array.run_ends().values(), &[2, 3, 5]); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_run_array_empty_string() { + let array: Int16RunArray = vec!["a", "a", "", "", "c"].into_iter().collect(); + let values: Arc = Arc::new(StringArray::from(vec!["a", "", "c"])); + assert_eq!(array.run_ends().values(), &[2, 4, 5]); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_run_array_length_mismatch() { + let values: StringArray = [Some("foo"), Some("bar"), None, Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), Some(2), Some(3)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let 
expected = ArrowError::InvalidArgumentError("The run_ends array length should be the same as values array length. Run_ends array length is 3, values array length is 4".to_string()); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + fn test_run_array_run_ends_with_null() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), None, Some(3)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let expected = ArrowError::InvalidArgumentError( + "Found null values in run_ends array. The run_ends array should not have null values." + .to_string(), + ); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + fn test_run_array_run_ends_with_zeroes() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(0), Some(1), Some(3)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let expected = ArrowError::InvalidArgumentError("The values in run_ends array should be strictly positive. Found value 0 at index 0 that does not match the criteria.".to_string()); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + fn test_run_array_run_ends_non_increasing() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), Some(4), Some(4)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let expected = ArrowError::InvalidArgumentError("The values in run_ends array should be strictly increasing. 
Found value 4 at index 2 with previous value 4 that does not match the criteria.".to_string()); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + #[should_panic(expected = "Incorrect run ends type")] + fn test_run_array_run_ends_data_type_mismatch() { + let a = RunArray::::from_iter(["32"]); + let _ = RunArray::::from(a.into_data()); + } + + #[test] + fn test_ree_array_accessor() { + let input_array = build_input_array(256); + + // Encode the input_array to ree_array + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array.iter().copied()); + let run_array = builder.finish(); + let typed = run_array.downcast::>().unwrap(); + + // Access every index and check if the value in the input array matches returned value. + for (i, inp_val) in input_array.iter().enumerate() { + if let Some(val) = inp_val { + let actual = typed.value(i); + assert_eq!(*val, actual) + } else { + let physical_ix = run_array.get_physical_index(i); + assert!(typed.values().is_null(physical_ix)); + }; + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_get_physical_indices() { + // Test for logical lengths from 0 to 240 increasing by 10 + for logical_len in (0..250).step_by(10) { + let input_array = build_input_array(logical_len); + + // create run array using input_array + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend(input_array.clone().into_iter()); + + let run_array = builder.finish(); + let physical_values_array = run_array.values().as_primitive::(); + + // create an array consisting of all the indices repeated twice and shuffled.
+ let mut logical_indices: Vec = (0_u32..(logical_len as u32)).collect(); + // add same indices once more + logical_indices.append(&mut logical_indices.clone()); + let mut rng = thread_rng(); + logical_indices.shuffle(&mut rng); + + let physical_indices = run_array.get_physical_indices(&logical_indices).unwrap(); + + assert_eq!(logical_indices.len(), physical_indices.len()); + + // check value in logical index in the input_array matches physical index in typed_run_array + compare_logical_and_physical_indices( + &logical_indices, + &input_array, + &physical_indices, + physical_values_array, + ); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_get_physical_indices_sliced() { + let total_len = 80; + let input_array = build_input_array(total_len); + + // Encode the input_array to run array + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array.iter().copied()); + let run_array = builder.finish(); + let physical_values_array = run_array.values().as_primitive::(); + + // test for all slice lengths. + for slice_len in 1..=total_len { + // create an array consisting of all the indices repeated twice and shuffled. + let mut logical_indices: Vec = (0_u32..(slice_len as u32)).collect(); + // add same indices once more + logical_indices.append(&mut logical_indices.clone()); + let mut rng = thread_rng(); + logical_indices.shuffle(&mut rng); + + // test for offset = 0 and slice length = slice_len + // slice the input array using which the run array was built. + let sliced_input_array = &input_array[0..slice_len]; + + // slice the run array + let sliced_run_array: RunArray = + run_array.slice(0, slice_len).into_data().into(); + + // Get physical indices. 
+ let physical_indices = sliced_run_array + .get_physical_indices(&logical_indices) + .unwrap(); + + compare_logical_and_physical_indices( + &logical_indices, + sliced_input_array, + &physical_indices, + physical_values_array, + ); + + // test for offset = total_len - slice_len and slice length = slice_len + // slice the input array using which the run array was built. + let sliced_input_array = &input_array[total_len - slice_len..total_len]; + + // slice the run array + let sliced_run_array: RunArray = run_array + .slice(total_len - slice_len, slice_len) + .into_data() + .into(); + + // Get physical indices + let physical_indices = sliced_run_array + .get_physical_indices(&logical_indices) + .unwrap(); + + compare_logical_and_physical_indices( + &logical_indices, + sliced_input_array, + &physical_indices, + physical_values_array, + ); + } + } + + #[test] + fn test_logical_nulls() { + let run = Int32Array::from(vec![3, 6, 9, 12]); + let values = Int32Array::from(vec![Some(0), None, Some(1), None]); + let array = RunArray::try_new(&run, &values).unwrap(); + + let expected = [ + true, true, true, false, false, false, true, true, true, false, false, false, + ]; + + let n = array.logical_nulls().unwrap(); + assert_eq!(n.null_count(), 6); + + let slices = [(0, 12), (0, 2), (2, 5), (3, 0), (3, 3), (3, 4), (4, 8)]; + for (offset, length) in slices { + let a = array.slice(offset, length); + let n = a.logical_nulls().unwrap(); + let n = n.into_iter().collect::>(); + assert_eq!(&n, &expected[offset..offset + length], "{offset} {length}"); + } + } +} diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs new file mode 100644 index 000000000000..25581cfaa49d --- /dev/null +++ b/arrow-array/src/array/string_array.rs @@ -0,0 +1,566 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::GenericStringType; +use crate::{GenericBinaryArray, GenericByteArray, GenericListArray, OffsetSizeTrait}; +use arrow_schema::{ArrowError, DataType}; + +/// A [`GenericByteArray`] for storing `str` +pub type GenericStringArray = GenericByteArray>; + +impl GenericStringArray { + /// Get the data type of the array. + #[deprecated(note = "please use `Self::DATA_TYPE` instead")] + pub const fn get_data_type() -> DataType { + Self::DATA_TYPE + } + + /// Returns the number of `Unicode Scalar Value` in the string at index `i`. + /// # Performance + /// This function has `O(n)` time complexity where `n` is the string length. + /// If you can make sure that all chars in the string are in the range `U+0x0000` ~ `U+0x007F`, + /// please use the function [`value_length`](#method.value_length) which has O(1) time complexity. 
+ pub fn num_chars(&self, i: usize) -> usize { + self.value(i).chars().count() + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the indexes in the iterator are less than the `array.len()` + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } + + /// Fallibly creates a [`GenericStringArray`] from a [`GenericBinaryArray`] returning + /// an error if [`GenericBinaryArray`] contains invalid UTF-8 data + pub fn try_from_binary(v: GenericBinaryArray) -> Result { + let (offsets, values, nulls) = v.into_parts(); + Self::try_new(offsets, values, nulls) + } +} + +impl From> + for GenericStringArray +{ + fn from(v: GenericListArray) -> Self { + GenericBinaryArray::::from(v).into() + } +} + +impl From> + for GenericStringArray +{ + fn from(v: GenericBinaryArray) -> Self { + Self::try_from_binary(v).unwrap() + } +} + +impl From>> for GenericStringArray { + fn from(v: Vec>) -> Self { + v.into_iter().collect() + } +} + +impl From> for GenericStringArray { + fn from(v: Vec<&str>) -> Self { + Self::from_iter_values(v) + } +} + +impl From>> for GenericStringArray { + fn from(v: Vec>) -> Self { + v.into_iter().collect() + } +} + +impl From> for GenericStringArray { + fn from(v: Vec) -> Self { + Self::from_iter_values(v) + } +} + +/// A [`GenericStringArray`] of `str` using `i32` offsets +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::StringArray; +/// // Create from Vec> +/// let arr = StringArray::from(vec![Some("foo"), Some("bar"), 
None, Some("baz")]); +/// // Create from Vec<&str> +/// let arr = StringArray::from(vec!["foo", "bar", "baz"]); +/// // Create from iter/collect (requires Option<&str>) +/// let arr: StringArray = std::iter::repeat(Some("foo")).take(10).collect(); +/// ``` +/// +/// Construction and Access +/// +/// ``` +/// # use arrow_array::StringArray; +/// let array = StringArray::from(vec![Some("foo"), None, Some("bar")]); +/// assert_eq!(array.value(0), "foo"); +/// ``` +/// +/// See [`GenericByteArray`] for more information and examples +pub type StringArray = GenericStringArray; + +/// A [`GenericStringArray`] of `str` using `i64` offsets +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::LargeStringArray; +/// // Create from Vec> +/// let arr = LargeStringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); +/// // Create from Vec<&str> +/// let arr = LargeStringArray::from(vec!["foo", "bar", "baz"]); +/// // Create from iter/collect (requires Option<&str>) +/// let arr: LargeStringArray = std::iter::repeat(Some("foo")).take(10).collect(); +/// ``` +/// +/// Construction and Access +/// +/// ``` +/// use arrow_array::LargeStringArray; +/// let array = LargeStringArray::from(vec![Some("foo"), None, Some("bar")]); +/// assert_eq!(array.value(2), "bar"); +/// ``` +/// +/// See [`GenericByteArray`] for more information and examples +pub type LargeStringArray = GenericStringArray; + +#[cfg(test)] +mod tests { + use super::*; + use crate::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; + use crate::types::UInt8Type; + use crate::Array; + use arrow_buffer::Buffer; + use arrow_data::ArrayData; + use arrow_schema::Field; + use std::sync::Arc; + + #[test] + fn test_string_array_from_u8_slice() { + let values: Vec<&str> = vec!["hello", "", "A£ऀ𖼚𝌆৩ƐZ"]; + + // Array data: ["hello", "", "A£ऀ𖼚𝌆৩ƐZ"] + let string_array = StringArray::from(values); + + assert_eq!(3, string_array.len()); + assert_eq!(0, string_array.null_count()); + 
assert_eq!("hello", string_array.value(0)); + assert_eq!("hello", unsafe { string_array.value_unchecked(0) }); + assert_eq!("", string_array.value(1)); + assert_eq!("", unsafe { string_array.value_unchecked(1) }); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", string_array.value(2)); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", unsafe { + string_array.value_unchecked(2) + }); + assert_eq!(20, string_array.value_length(2)); // 1 + 2 + 3 + 4 + 4 + 3 + 2 + 1 + assert_eq!(8, string_array.num_chars(2)); + for i in 0..3 { + assert!(string_array.is_valid(i)); + assert!(!string_array.is_null(i)); + } + } + + #[test] + #[should_panic(expected = "StringArray expects DataType::Utf8")] + fn test_string_array_from_int() { + let array = LargeStringArray::from(vec!["a", "b"]); + drop(StringArray::from(array.into_data())); + } + + #[test] + fn test_large_string_array_from_u8_slice() { + let values: Vec<&str> = vec!["hello", "", "A£ऀ𖼚𝌆৩ƐZ"]; + + // Array data: ["hello", "", "A£ऀ𖼚𝌆৩ƐZ"] + let string_array = LargeStringArray::from(values); + + assert_eq!(3, string_array.len()); + assert_eq!(0, string_array.null_count()); + assert_eq!("hello", string_array.value(0)); + assert_eq!("hello", unsafe { string_array.value_unchecked(0) }); + assert_eq!("", string_array.value(1)); + assert_eq!("", unsafe { string_array.value_unchecked(1) }); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", string_array.value(2)); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", unsafe { + string_array.value_unchecked(2) + }); + assert_eq!(5, string_array.value_offsets()[2]); + assert_eq!(20, string_array.value_length(2)); // 1 + 2 + 3 + 4 + 4 + 3 + 2 + 1 + assert_eq!(8, string_array.num_chars(2)); + for i in 0..3 { + assert!(string_array.is_valid(i)); + assert!(!string_array.is_null(i)); + } + } + + #[test] + fn test_nested_string_array() { + let string_builder = StringBuilder::with_capacity(3, 10); + let mut list_of_string_builder = ListBuilder::new(string_builder); + + list_of_string_builder.values().append_value("foo"); + list_of_string_builder.values().append_value("bar"); + 
list_of_string_builder.append(true); + + list_of_string_builder.values().append_value("foobar"); + list_of_string_builder.append(true); + let list_of_strings = list_of_string_builder.finish(); + + assert_eq!(list_of_strings.len(), 2); + + let first_slot = list_of_strings.value(0); + let first_list = first_slot.as_any().downcast_ref::().unwrap(); + assert_eq!(first_list.len(), 2); + assert_eq!(first_list.value(0), "foo"); + assert_eq!(unsafe { first_list.value_unchecked(0) }, "foo"); + assert_eq!(first_list.value(1), "bar"); + assert_eq!(unsafe { first_list.value_unchecked(1) }, "bar"); + + let second_slot = list_of_strings.value(1); + let second_list = second_slot.as_any().downcast_ref::().unwrap(); + assert_eq!(second_list.len(), 1); + assert_eq!(second_list.value(0), "foobar"); + assert_eq!(unsafe { second_list.value_unchecked(0) }, "foobar"); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a StringArray of length 3" + )] + fn test_string_array_get_value_index_out_of_bound() { + let values: [u8; 12] = [ + b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', + ]; + let offsets: [i32; 4] = [0, 5, 5, 12]; + let array_data = ArrayData::builder(DataType::Utf8) + .len(3) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) + .build() + .unwrap(); + let string_array = StringArray::from(array_data); + string_array.value(4); + } + + #[test] + fn test_string_array_fmt_debug() { + let arr: StringArray = vec!["hello", "arrow"].into(); + assert_eq!( + "StringArray\n[\n \"hello\",\n \"arrow\",\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_large_string_array_fmt_debug() { + let arr: LargeStringArray = vec!["hello", "arrow"].into(); + assert_eq!( + "LargeStringArray\n[\n \"hello\",\n \"arrow\",\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_string_array_from_iter() { + let data = [Some("hello"), None, Some("arrow")]; + let data_vec = data.to_vec(); + // 
from Vec> + let array1 = StringArray::from(data_vec.clone()); + // from Iterator> + let array2: StringArray = data_vec.clone().into_iter().collect(); + // from Iterator> + let array3: StringArray = data_vec + .into_iter() + .map(|x| x.map(|s| s.to_string())) + .collect(); + // from Iterator<&Option<&str>> + let array4: StringArray = data.iter().collect::(); + + assert_eq!(array1, array2); + assert_eq!(array2, array3); + assert_eq!(array3, array4); + } + + #[test] + fn test_string_array_from_iter_values() { + let data = ["hello", "hello2"]; + let array1 = StringArray::from_iter_values(data.iter()); + + assert_eq!(array1.value(0), "hello"); + assert_eq!(array1.value(1), "hello2"); + + // Also works with String types. + let data2 = ["goodbye".to_string(), "goodbye2".to_string()]; + let array2 = StringArray::from_iter_values(data2.iter()); + + assert_eq!(array2.value(0), "goodbye"); + assert_eq!(array2.value(1), "goodbye2"); + } + + #[test] + fn test_string_array_from_unbound_iter() { + // iterator that doesn't declare (upper) size bound + let string_iter = (0..) 
+ .scan(0usize, |pos, i| { + if *pos < 10 { + *pos += 1; + Some(Some(format!("value {i}"))) + } else { + // actually returns up to 10 values + None + } + }) + // limited using take() + .take(100); + + let (_, upper_size_bound) = string_iter.size_hint(); + // the upper bound, defined by take above, is 100 + assert_eq!(upper_size_bound, Some(100)); + let string_array: StringArray = string_iter.collect(); + // but the actual number of items in the array should be 10 + assert_eq!(string_array.len(), 10); + } + + #[test] + fn test_string_array_all_null() { + let data: Vec> = vec![None]; + let array = StringArray::from(data); + array + .into_data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + fn test_large_string_array_all_null() { + let data: Vec> = vec![None]; + let array = LargeStringArray::from(data); + array + .into_data() + .validate_full() + .expect("All null array has valid array data"); + } + + fn _test_generic_string_array_from_list_array() { + let values = b"HelloArrowAndParquet"; + // "ArrowAndParquet" + let child_data = ArrayData::builder(DataType::UInt8) + .len(15) + .offset(5) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + + let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); + let null_buffer = Buffer::from_slice_ref([0b101]); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); + + // [None, Some("Parquet")] + let array_data = ArrayData::builder(data_type) + .len(2) + .offset(1) + .add_buffer(Buffer::from_slice_ref(offsets)) + .null_bit_buffer(Some(null_buffer)) + .add_child_data(child_data) + .build() + .unwrap(); + let list_array = GenericListArray::::from(array_data); + let string_array = GenericStringArray::::from(list_array); + + assert_eq!(2, string_array.len()); + assert_eq!(1, string_array.null_count()); + assert!(string_array.is_null(0)); + assert!(string_array.is_valid(1)); + assert_eq!("Parquet", 
string_array.value(1)); + } + + #[test] + fn test_string_array_from_list_array() { + _test_generic_string_array_from_list_array::(); + } + + #[test] + fn test_large_string_array_from_list_array() { + _test_generic_string_array_from_list_array::(); + } + + fn _test_generic_string_array_from_list_array_with_child_nulls_failed() { + let values = b"HelloArrow"; + let child_data = ArrayData::builder(DataType::UInt8) + .len(10) + .add_buffer(Buffer::from(&values[..])) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b1010101010]))) + .build() + .unwrap(); + + let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); + + // It is possible to create a null struct containing a non-nullable child + // see https://github.com/apache/arrow-rs/pull/3244 for details + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + true, + ))); + + // Two list slots covering bytes 0..5 and 5..10 of "HelloArrow"; the child array carries a null bitmap + let array_data = ArrayData::builder(data_type) + .len(2) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_child_data(child_data) + .build() + .unwrap(); + let list_array = GenericListArray::::from(array_data); + drop(GenericStringArray::::from(list_array)); + } + + #[test] + #[should_panic(expected = "The child array cannot contain null values.")] + fn test_string_array_from_list_array_with_child_nulls_failed() { + _test_generic_string_array_from_list_array_with_child_nulls_failed::(); + } + + #[test] + #[should_panic(expected = "The child array cannot contain null values.")] + fn test_large_string_array_from_list_array_with_child_nulls_failed() { + _test_generic_string_array_from_list_array_with_child_nulls_failed::(); + } + + fn _test_generic_string_array_from_list_array_wrong_type() { + let values = b"HelloArrow"; + let child_data = ArrayData::builder(DataType::UInt16) + .len(5) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + + let offsets = [0, 2, 3].map(|n| O::from_usize(n).unwrap()); + let data_type =
GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt16, + false, + ))); + + let array_data = ArrayData::builder(data_type) + .len(2) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_child_data(child_data) + .build() + .unwrap(); + let list_array = GenericListArray::::from(array_data); + drop(GenericStringArray::::from(list_array)); + } + + #[test] + #[should_panic( + expected = "BinaryArray can only be created from List arrays, mismatched data types." + )] + fn test_string_array_from_list_array_wrong_type() { + _test_generic_string_array_from_list_array_wrong_type::(); + } + + #[test] + #[should_panic( + expected = "BinaryArray can only be created from List arrays, mismatched data types." + )] + fn test_large_string_array_from_list_array_wrong_type() { + _test_generic_string_array_from_list_array_wrong_type::(); + } + + #[test] + #[should_panic( + expected = "Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 0" + )] + fn test_list_array_utf8_validation() { + let mut builder = ListBuilder::new(PrimitiveBuilder::::new()); + builder.values().append_value(0xFF); + builder.append(true); + let list = builder.finish(); + let _ = StringArray::from(list); + } + + #[test] + fn test_empty_offsets() { + let string = StringArray::from( + ArrayData::builder(DataType::Utf8) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + + let string = LargeStringArray::from( + ArrayData::builder(DataType::LargeUtf8) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + } + + #[test] + fn test_into_builder() { + let array: StringArray = vec!["hello", "arrow"].into(); + + // Append values + let mut builder = array.into_builder().unwrap(); + + builder.append_value("rust"); + + let expected: StringArray = vec!["hello", 
"arrow", "rust"].into(); + let array = builder.finish(); + assert_eq!(expected, array); + } + + #[test] + fn test_into_builder_err() { + let array: StringArray = vec!["hello", "arrow"].into(); + + // Clone it, so we cannot get a mutable builder back + let shared_array = array.clone(); + + let err_return = array.into_builder().unwrap_err(); + assert_eq!(&err_return, &shared_array); + } +} diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs new file mode 100644 index 000000000000..059bc0b5e65b --- /dev/null +++ b/arrow-array/src/array/struct_array.rs @@ -0,0 +1,734 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields}; +use std::sync::Arc; +use std::{any::Any, ops::Index}; + +/// An array of [structs](https://arrow.apache.org/docs/format/Columnar.html#struct-layout) +/// +/// Each child (called *field*) is represented by a separate array. 
+/// +/// # Comparison with [RecordBatch] +/// +/// Both [`RecordBatch`] and [`StructArray`] represent a collection of columns / arrays with the +/// same length. +/// +/// However, there are a couple of key differences: +/// +/// * [`StructArray`] can be nested within other [`Array`], including itself +/// * [`RecordBatch`] can contain top-level metadata on its associated [`Schema`][arrow_schema::Schema] +/// * [`StructArray`] can contain top-level nulls, i.e. `null` +/// * [`RecordBatch`] can only represent nulls in its child columns, i.e. `{"field": null}` +/// +/// [`StructArray`] is therefore a more general data container than [`RecordBatch`], and as such +/// code that needs to handle both will typically share an implementation in terms of +/// [`StructArray`] and convert to/from [`RecordBatch`] as necessary. +/// +/// [`From`] implementations are provided to facilitate this conversion, however, converting +/// from a [`StructArray`] containing top-level nulls to a [`RecordBatch`] will panic, as there +/// is no way to preserve them. 
+/// +/// # Example: Create an array from a vector of fields +/// +/// ``` +/// use std::sync::Arc; +/// use arrow_array::{Array, ArrayRef, BooleanArray, Int32Array, StructArray}; +/// use arrow_schema::{DataType, Field}; +/// +/// let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); +/// let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); +/// +/// let struct_array = StructArray::from(vec![ +/// ( +/// Arc::new(Field::new("b", DataType::Boolean, false)), +/// boolean.clone() as ArrayRef, +/// ), +/// ( +/// Arc::new(Field::new("c", DataType::Int32, false)), +/// int.clone() as ArrayRef, +/// ), +/// ]); +/// assert_eq!(struct_array.column(0).as_ref(), boolean.as_ref()); +/// assert_eq!(struct_array.column(1).as_ref(), int.as_ref()); +/// assert_eq!(4, struct_array.len()); +/// assert_eq!(0, struct_array.null_count()); +/// assert_eq!(0, struct_array.offset()); +/// ``` +#[derive(Clone)] +pub struct StructArray { + len: usize, + data_type: DataType, + nulls: Option, + fields: Vec, +} + +impl StructArray { + /// Create a new [`StructArray`] from the provided parts, panicking on failure + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(fields: Fields, arrays: Vec, nulls: Option) -> Self { + Self::try_new(fields, arrays, nulls).unwrap() + } + + /// Create a new [`StructArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// Errors if + /// + /// * `fields.len() != arrays.len()` + /// * `fields[i].data_type() != arrays[i].data_type()` + /// * `arrays[i].len() != arrays[j].len()` + /// * `arrays[i].len() != nulls.len()` + /// * `!fields[i].is_nullable() && !nulls.contains(arrays[i].nulls())` + pub fn try_new( + fields: Fields, + arrays: Vec, + nulls: Option, + ) -> Result { + if fields.len() != arrays.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect number of arrays for StructArray fields, expected {} got {}", + fields.len(), + 
arrays.len() + ))); + } + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect number of nulls for StructArray, expected {len} got {}", + n.len(), + ))); + } + } + + for (f, a) in fields.iter().zip(&arrays) { + if f.data_type() != a.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect datatype for StructArray field {:?}, expected {} got {}", + f.name(), + f.data_type(), + a.data_type() + ))); + } + + if a.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect array length for StructArray field {:?}, expected {} got {}", + f.name(), + len, + a.len() + ))); + } + + if !f.is_nullable() { + if let Some(a) = a.logical_nulls() { + if !nulls.as_ref().map(|n| n.contains(&a)).unwrap_or_default() { + return Err(ArrowError::InvalidArgumentError(format!( + "Found unmasked nulls for non-nullable StructArray field {:?}", + f.name() + ))); + } + } + } + } + + Ok(Self { + len, + data_type: DataType::Struct(fields), + nulls: nulls.filter(|n| n.null_count() > 0), + fields: arrays, + }) + } + + /// Create a new [`StructArray`] of length `len` where all values are null + pub fn new_null(fields: Fields, len: usize) -> Self { + let arrays = fields + .iter() + .map(|f| new_null_array(f.data_type(), len)) + .collect(); + + Self { + len, + data_type: DataType::Struct(fields), + nulls: Some(NullBuffer::new_null(len)), + fields: arrays, + } + } + + /// Create a new [`StructArray`] from the provided parts without validation + /// + /// # Safety + /// + /// Safe if [`Self::new`] would not panic with the given arguments + pub unsafe fn new_unchecked( + fields: Fields, + arrays: Vec, + nulls: Option, + ) -> Self { + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + Self { + len, + data_type: DataType::Struct(fields), + nulls, + fields: arrays, + } + } + + /// Create a new [`StructArray`] 
containing no fields + /// + /// # Panics + /// + /// If `len != nulls.len()` + pub fn new_empty_fields(len: usize, nulls: Option) -> Self { + if let Some(n) = &nulls { + assert_eq!(len, n.len()) + } + Self { + len, + data_type: DataType::Struct(Fields::empty()), + fields: vec![], + nulls, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (Fields, Vec, Option) { + let f = match self.data_type { + DataType::Struct(f) => f, + _ => unreachable!(), + }; + (f, self.fields, self.nulls) + } + + /// Returns the field at `pos`. + pub fn column(&self, pos: usize) -> &ArrayRef { + &self.fields[pos] + } + + /// Return the number of fields in this struct array + pub fn num_columns(&self) -> usize { + self.fields.len() + } + + /// Returns the fields of the struct array + pub fn columns(&self) -> &[ArrayRef] { + &self.fields + } + + /// Returns child array refs of the struct array + #[deprecated(note = "Use columns().to_vec()")] + pub fn columns_ref(&self) -> Vec { + self.columns().to_vec() + } + + /// Return field names in this struct array + pub fn column_names(&self) -> Vec<&str> { + match self.data_type() { + DataType::Struct(fields) => fields + .iter() + .map(|f| f.name().as_str()) + .collect::>(), + _ => unreachable!("Struct array's data type is not struct!"), + } + } + + /// Returns the [`Fields`] of this [`StructArray`] + pub fn fields(&self) -> &Fields { + match self.data_type() { + DataType::Struct(f) => f, + _ => unreachable!(), + } + } + + /// Return child array whose field name equals to column_name + /// + /// Note: A schema can currently have duplicate field names, in which case + /// the first field will always be selected. 
+ /// This issue will be addressed in [ARROW-11178](https://issues.apache.org/jira/browse/ARROW-11178) + pub fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> { + self.column_names() + .iter() + .position(|c| c == &column_name) + .map(|pos| self.column(pos)) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced StructArray cannot exceed the existing length" + ); + + let fields = self.fields.iter().map(|a| a.slice(offset, len)).collect(); + + Self { + len, + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)), + fields, + } + } +} + +impl From for StructArray { + fn from(data: ArrayData) -> Self { + let fields = data + .child_data() + .iter() + .map(|cd| make_array(cd.clone())) + .collect(); + + Self { + len: data.len(), + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), + fields, + } + } +} + +impl From for ArrayData { + fn from(array: StructArray) -> Self { + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.len) + .nulls(array.nulls) + .child_data(array.fields.iter().map(|x| x.to_data()).collect()); + + unsafe { builder.build_unchecked() } + } +} + +impl TryFrom> for StructArray { + type Error = ArrowError; + + /// builds a StructArray from a vector of names and arrays. 
+ fn try_from(values: Vec<(&str, ArrayRef)>) -> Result { + let (fields, arrays): (Vec<_>, _) = values + .into_iter() + .map(|(name, array)| { + ( + Field::new(name, array.data_type().clone(), array.is_nullable()), + array, + ) + }) + .unzip(); + + StructArray::try_new(fields.into(), arrays, None) + } +} + +impl Array for StructArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.fields.iter().map(|a| a.get_buffer_memory_size()).sum(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = self.fields.iter().map(|a| a.get_array_memory_size()).sum(); + size += std::mem::size_of::(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } +} + +impl From> for StructArray { + fn from(v: Vec<(FieldRef, ArrayRef)>) -> Self { + let (fields, arrays): (Vec<_>, _) = v.into_iter().unzip(); + StructArray::new(fields.into(), arrays, None) + } +} + +impl std::fmt::Debug for StructArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "StructArray\n[\n")?; + for (child_index, name) in self.column_names().iter().enumerate() { + let column = self.column(child_index); + writeln!( + f, + "-- child {}: \"{}\" ({:?})", + child_index, + name, + column.data_type() + )?; + std::fmt::Debug::fmt(column, f)?; + writeln!(f)?; + } + write!(f, "]") + } +} + +impl 
From<(Vec<(FieldRef, ArrayRef)>, Buffer)> for StructArray { + fn from(pair: (Vec<(FieldRef, ArrayRef)>, Buffer)) -> Self { + let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default(); + let (fields, arrays): (Vec<_>, Vec<_>) = pair.0.into_iter().unzip(); + let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len)); + Self::new(fields.into(), arrays, Some(nulls)) + } +} + +impl From for StructArray { + fn from(value: RecordBatch) -> Self { + Self { + len: value.num_rows(), + data_type: DataType::Struct(value.schema().fields().clone()), + nulls: None, + fields: value.columns().to_vec(), + } + } +} + +impl Index<&str> for StructArray { + type Output = ArrayRef; + + /// Get a reference to a column's array by name. + /// + /// Note: A schema can currently have duplicate field names, in which case + /// the first field will always be selected. + /// This issue will be addressed in [ARROW-11178](https://issues.apache.org/jira/browse/ARROW-11178) + /// + /// # Panics + /// + /// Panics if the name is not in the schema. 
+ fn index(&self, name: &str) -> &Self::Output { + self.column_by_name(name).unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray}; + use arrow_buffer::ToByteSlice; + + #[test] + fn test_struct_array_builder() { + let boolean_array = BooleanArray::from(vec![false, false, true, true]); + let int_array = Int64Array::from(vec![42, 28, 19, 31]); + + let fields = vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int64, false), + ]; + let struct_array_data = ArrayData::builder(DataType::Struct(fields.into())) + .len(4) + .add_child_data(boolean_array.to_data()) + .add_child_data(int_array.to_data()) + .build() + .unwrap(); + let struct_array = StructArray::from(struct_array_data); + + assert_eq!(struct_array.column(0).as_ref(), &boolean_array); + assert_eq!(struct_array.column(1).as_ref(), &int_array); + } + + #[test] + fn test_struct_array_from() { + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + assert_eq!(struct_array.column(0).as_ref(), boolean.as_ref()); + assert_eq!(struct_array.column(1).as_ref(), int.as_ref()); + assert_eq!(4, struct_array.len()); + assert_eq!(0, struct_array.null_count()); + assert_eq!(0, struct_array.offset()); + } + + /// validates that struct can be accessed using `column_name` as index i.e. `struct_array["column_name"]`. 
+ #[test] + fn test_struct_array_index_access() { + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + assert_eq!(struct_array["b"].as_ref(), boolean.as_ref()); + assert_eq!(struct_array["c"].as_ref(), int.as_ref()); + } + + /// validates that the in-memory representation follows [the spec](https://arrow.apache.org/docs/format/Columnar.html#struct-layout) + #[test] + fn test_struct_array_from_vec() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + Some("mark"), + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); + + let arr = + StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]).unwrap(); + + let struct_data = arr.into_data(); + assert_eq!(4, struct_data.len()); + assert_eq!(0, struct_data.null_count()); + + let expected_string_data = ArrayData::builder(DataType::Utf8) + .len(4) + .null_bit_buffer(Some(Buffer::from(&[9_u8]))) + .add_buffer(Buffer::from([0, 3, 3, 3, 7].to_byte_slice())) + .add_buffer(Buffer::from(b"joemark")) + .build() + .unwrap(); + + let expected_int_data = ArrayData::builder(DataType::Int32) + .len(4) + .null_bit_buffer(Some(Buffer::from(&[11_u8]))) + .add_buffer(Buffer::from([1, 2, 0, 4].to_byte_slice())) + .build() + .unwrap(); + + assert_eq!(expected_string_data, struct_data.child_data()[0]); + assert_eq!(expected_int_data, struct_data.child_data()[1]); + } + + #[test] + fn test_struct_array_from_vec_error() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + // 3 elements, not 4 + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, 
Some(4)])); + + let err = StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4" + ) + } + + #[test] + #[should_panic( + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" + )] + fn test_struct_array_from_mismatched_types_single() { + drop(StructArray::from(vec![( + Arc::new(Field::new("b", DataType::Int16, false)), + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, + )])); + } + + #[test] + #[should_panic( + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" + )] + fn test_struct_array_from_mismatched_types_multiple() { + drop(StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Int16, false)), + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, + ), + ( + Arc::new(Field::new("c", DataType::Utf8, false)), + Arc::new(Int32Array::from(vec![42, 28, 19, 31])), + ), + ])); + } + + #[test] + fn test_struct_array_slice() { + let boolean_data = ArrayData::builder(DataType::Boolean) + .len(5) + .add_buffer(Buffer::from([0b00010000])) + .null_bit_buffer(Some(Buffer::from([0b00010001]))) + .build() + .unwrap(); + let int_data = ArrayData::builder(DataType::Int32) + .len(5) + .add_buffer(Buffer::from([0, 28, 42, 0, 0].to_byte_slice())) + .null_bit_buffer(Some(Buffer::from([0b00000110]))) + .build() + .unwrap(); + + let field_types = vec![ + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int32, true), + ]; + let struct_array_data = ArrayData::builder(DataType::Struct(field_types.into())) + .len(5) + .add_child_data(boolean_data.clone()) + .add_child_data(int_data.clone()) + .null_bit_buffer(Some(Buffer::from([0b00010111]))) + .build() + .unwrap(); + let struct_array = StructArray::from(struct_array_data); + + assert_eq!(5, struct_array.len()); + 
assert_eq!(1, struct_array.null_count()); + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + assert!(struct_array.is_valid(2)); + assert!(struct_array.is_null(3)); + assert!(struct_array.is_valid(4)); + assert_eq!(boolean_data, struct_array.column(0).to_data()); + assert_eq!(int_data, struct_array.column(1).to_data()); + + let c0 = struct_array.column(0); + let c0 = c0.as_any().downcast_ref::().unwrap(); + assert_eq!(5, c0.len()); + assert_eq!(3, c0.null_count()); + assert!(c0.is_valid(0)); + assert!(!c0.value(0)); + assert!(c0.is_null(1)); + assert!(c0.is_null(2)); + assert!(c0.is_null(3)); + assert!(c0.is_valid(4)); + assert!(c0.value(4)); + + let c1 = struct_array.column(1); + let c1 = c1.as_any().downcast_ref::().unwrap(); + assert_eq!(5, c1.len()); + assert_eq!(3, c1.null_count()); + assert!(c1.is_null(0)); + assert!(c1.is_valid(1)); + assert_eq!(28, c1.value(1)); + assert!(c1.is_valid(2)); + assert_eq!(42, c1.value(2)); + assert!(c1.is_null(3)); + assert!(c1.is_null(4)); + + let sliced_array = struct_array.slice(2, 3); + let sliced_array = sliced_array.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_array.len()); + assert_eq!(1, sliced_array.null_count()); + assert!(sliced_array.is_valid(0)); + assert!(sliced_array.is_null(1)); + assert!(sliced_array.is_valid(2)); + + let sliced_c0 = sliced_array.column(0); + let sliced_c0 = sliced_c0.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_c0.len()); + assert!(sliced_c0.is_null(0)); + assert!(sliced_c0.is_null(1)); + assert!(sliced_c0.is_valid(2)); + assert!(sliced_c0.value(2)); + + let sliced_c1 = sliced_array.column(1); + let sliced_c1 = sliced_c1.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_c1.len()); + assert!(sliced_c1.is_valid(0)); + assert_eq!(42, sliced_c1.value(0)); + assert!(sliced_c1.is_null(1)); + assert!(sliced_c1.is_null(2)); + } + + #[test] + #[should_panic( + expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 
2" + )] + fn test_invalid_struct_child_array_lengths() { + drop(StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Float32, false)), + Arc::new(Float32Array::from(vec![1.1])) as Arc, + ), + ( + Arc::new(Field::new("c", DataType::Float64, false)), + Arc::new(Float64Array::from(vec![2.2, 3.3])), + ), + ])); + } + + #[test] + fn test_struct_array_from_empty() { + let sa = StructArray::from(vec![]); + assert!(sa.is_empty()) + } + + #[test] + #[should_panic(expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\"")] + fn test_struct_array_from_mismatched_nullability() { + drop(StructArray::from(vec![( + Arc::new(Field::new("c", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![Some(42), None, Some(19)])) as ArrayRef, + )])); + } +} diff --git a/arrow/src/array/array_union.rs b/arrow-array/src/array/union_array.rs similarity index 52% rename from arrow/src/array/array_union.rs rename to arrow-array/src/array/union_array.rs index b221239b2dbe..ea4853cd1528 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow-array/src/array/union_array.rs @@ -15,26 +15,26 @@ // specific language governing permissions and limitations // under the License. +use crate::{make_array, Array, ArrayRef}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::ScalarBuffer; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode}; /// Contains the `UnionArray` type. /// -use crate::array::{make_array, Array, ArrayData, ArrayRef}; -use crate::buffer::Buffer; -use crate::datatypes::*; -use crate::error::{ArrowError, Result}; - -use core::fmt; use std::any::Any; +use std::sync::Arc; -/// An Array that can represent slots of varying types. +/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout) /// /// Each slot in a [UnionArray] can have a value chosen from a number /// of types. 
Each of the possible types are named like the fields of -/// a [`StructArray`](crate::array::StructArray). A `UnionArray` can +/// a [`StructArray`](crate::StructArray). A `UnionArray` can /// have two possible memory layouts, "dense" or "sparse". For more /// information on please see the /// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout). /// -/// [UnionBuilder](crate::array::UnionBuilder) can be used to +/// [UnionBuilder](crate::builder::UnionBuilder) can be used to /// create [UnionArray]'s of primitive types. `UnionArray`'s of nested /// types are also supported but not via `UnionBuilder`, see the tests /// for examples. @@ -42,25 +42,30 @@ use std::any::Any; /// # Examples /// ## Create a dense UnionArray `[1, 3.2, 34]` /// ``` -/// use arrow::buffer::Buffer; -/// use arrow::datatypes::*; +/// use arrow_buffer::ScalarBuffer; +/// use arrow_schema::*; /// use std::sync::Arc; -/// use arrow::array::{Array, Int32Array, Float64Array, UnionArray}; +/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray}; /// /// let int_array = Int32Array::from(vec![1, 34]); /// let float_array = Float64Array::from(vec![3.2]); -/// let type_id_buffer = Buffer::from_slice_ref(&[0_i8, 1, 0]); -/// let value_offsets_buffer = Buffer::from_slice_ref(&[0_i32, 0, 1]); +/// let type_ids = [0, 1, 0].into_iter().collect::>(); +/// let offsets = [0, 0, 1].into_iter().collect::>(); +/// +/// let union_fields = [ +/// (0, Arc::new(Field::new("A", DataType::Int32, false))), +/// (1, Arc::new(Field::new("B", DataType::Float64, false))), +/// ].into_iter().collect::(); /// -/// let children: Vec<(Field, Arc)> = vec![ -/// (Field::new("A", DataType::Int32, false), Arc::new(int_array)), -/// (Field::new("B", DataType::Float64, false), Arc::new(float_array)), +/// let children = vec![ +/// Arc::new(int_array) as Arc, +/// Arc::new(float_array), /// ]; /// /// let array = UnionArray::try_new( -/// &vec![0, 1], -/// type_id_buffer, -/// 
Some(value_offsets_buffer), +/// union_fields, +/// type_ids, +/// Some(offsets), /// children, /// ).unwrap(); /// @@ -76,23 +81,28 @@ use std::any::Any; /// /// ## Create a sparse UnionArray `[1, 3.2, 34]` /// ``` -/// use arrow::buffer::Buffer; -/// use arrow::datatypes::*; +/// use arrow_buffer::ScalarBuffer; +/// use arrow_schema::*; /// use std::sync::Arc; -/// use arrow::array::{Array, Int32Array, Float64Array, UnionArray}; +/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray}; /// /// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]); /// let float_array = Float64Array::from(vec![None, Some(3.2), None]); -/// let type_id_buffer = Buffer::from_slice_ref(&[0_i8, 1, 0]); +/// let type_ids = [0_i8, 1, 0].into_iter().collect::>(); +/// +/// let union_fields = [ +/// (0, Arc::new(Field::new("A", DataType::Int32, false))), +/// (1, Arc::new(Field::new("B", DataType::Float64, false))), +/// ].into_iter().collect::(); /// -/// let children: Vec<(Field, Arc)> = vec![ -/// (Field::new("A", DataType::Int32, false), Arc::new(int_array)), -/// (Field::new("B", DataType::Float64, false), Arc::new(float_array)), +/// let children = vec![ +/// Arc::new(int_array) as Arc, +/// Arc::new(float_array), /// ]; /// /// let array = UnionArray::try_new( -/// &vec![0, 1], -/// type_id_buffer, +/// union_fields, +/// type_ids, /// None, /// children, /// ).unwrap(); @@ -106,9 +116,12 @@ use std::any::Any; /// let value = array.value(2).as_any().downcast_ref::().unwrap().value(0); /// assert_eq!(34, value); /// ``` +#[derive(Clone)] pub struct UnionArray { - data: ArrayData, - boxed_fields: Vec, + data_type: DataType, + type_ids: ScalarBuffer, + offsets: Option>, + fields: Vec>, } impl UnionArray { @@ -121,143 +134,143 @@ impl UnionArray { /// /// # Safety /// - /// The `type_ids` `Buffer` should contain `i8` values. These values should be greater than - /// zero and must be less than the number of children provided in `child_arrays`. 
These values - /// are used to index into the `child_arrays`. + /// The `type_ids` values should be positive and must match one of the type ids of the fields provided in `fields`. + /// These values are used to index into the `children` arrays. /// - /// The `value_offsets` `Buffer` is only provided in the case of a dense union, sparse unions - /// should use `None`. If provided the `value_offsets` `Buffer` should contain `i32` values. - /// The values in this array should be greater than zero and must be less than the length of the - /// overall array. + /// The `offsets` is provided in the case of a dense union, sparse unions should use `None`. + /// If provided the `offsets` values should be positive and must be less than the length of the + /// corresponding array. /// /// In both cases above we use signed integer types to maintain compatibility with other /// Arrow implementations. - /// - /// In both of the cases above we are accepting `Buffer`'s which are assumed to be representing - /// `i8` and `i32` values respectively. `Buffer` objects are untyped and no attempt is made - /// to ensure that the data provided is valid. 
pub unsafe fn new_unchecked( - field_type_ids: &[i8], - type_ids: Buffer, - value_offsets: Option, - child_arrays: Vec<(Field, ArrayRef)>, + fields: UnionFields, + type_ids: ScalarBuffer, + offsets: Option>, + children: Vec, ) -> Self { - let (field_types, field_values): (Vec<_>, Vec<_>) = - child_arrays.into_iter().unzip(); - let len = type_ids.len(); - - let mode = if value_offsets.is_some() { + let mode = if offsets.is_some() { UnionMode::Dense } else { UnionMode::Sparse }; - let builder = ArrayData::builder(DataType::Union( - field_types, - Vec::from(field_type_ids), - mode, - )) - .add_buffer(type_ids) - .child_data(field_values.into_iter().map(|a| a.into_data()).collect()) - .len(len); - - let data = match value_offsets { - Some(b) => builder.add_buffer(b).build_unchecked(), + let len = type_ids.len(); + let builder = ArrayData::builder(DataType::Union(fields, mode)) + .add_buffer(type_ids.into_inner()) + .child_data(children.into_iter().map(Array::into_data).collect()) + .len(len); + + let data = match offsets { + Some(offsets) => builder.add_buffer(offsets.into_inner()).build_unchecked(), None => builder.build_unchecked(), }; Self::from(data) } /// Attempts to create a new `UnionArray`, validating the inputs provided. + /// + /// The order of child arrays child array order must match the fields order pub fn try_new( - field_type_ids: &[i8], - type_ids: Buffer, - value_offsets: Option, - child_arrays: Vec<(Field, ArrayRef)>, - ) -> Result { - if let Some(b) = &value_offsets { - if ((type_ids.len()) * 4) != b.len() { + fields: UnionFields, + type_ids: ScalarBuffer, + offsets: Option>, + children: Vec, + ) -> Result { + // There must be a child array for every field. + if fields.len() != children.len() { + return Err(ArrowError::InvalidArgumentError( + "Union fields length must match child arrays length".to_string(), + )); + } + + // There must be an offset value for every type id value. 
+ if let Some(offsets) = &offsets { + if offsets.len() != type_ids.len() { return Err(ArrowError::InvalidArgumentError( - "Type Ids and Offsets represent a different number of array slots." - .to_string(), + "Type Ids and Offsets lengths must match".to_string(), )); } } - // Check the type_ids - let type_id_slice: &[i8] = type_ids.typed_data(); - let invalid_type_ids = type_id_slice - .iter() - .filter(|i| *i < &0) - .collect::>(); - if !invalid_type_ids.is_empty() { - return Err(ArrowError::InvalidArgumentError(format!( - "Type Ids must be positive and cannot be greater than the number of \ - child arrays, found:\n{:?}", - invalid_type_ids - ))); - } - - // Check the value offsets if provided - if let Some(offset_buffer) = &value_offsets { - let max_len = type_ids.len() as i32; - let offsets_slice: &[i32] = offset_buffer.typed_data(); - let invalid_offsets = offsets_slice - .iter() - .filter(|i| *i < &0 || *i > &max_len) - .collect::>(); - if !invalid_offsets.is_empty() { - return Err(ArrowError::InvalidArgumentError(format!( - "Offsets must be positive and within the length of the Array, \ - found:\n{:?}", - invalid_offsets - ))); + // Create mapping from type id to array lengths. + let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize; + let mut array_lens = vec![i32::MIN; max_id + 1]; + for (cd, (field_id, _)) in children.iter().zip(fields.iter()) { + array_lens[field_id as usize] = cd.len() as i32; + } + + // Type id values must match one of the fields. 
+ for id in &type_ids { + match array_lens.get(*id as usize) { + Some(x) if *x != i32::MIN => {} + _ => { + return Err(ArrowError::InvalidArgumentError( + "Type Ids values must match one of the field type ids".to_owned(), + )) + } } } - // Unsafe Justification: arguments were validated above (and - // re-revalidated as part of data().validate() below) - let new_self = unsafe { - Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays) - }; - new_self.data().validate()?; + // Check the value offsets are in bounds. + if let Some(offsets) = &offsets { + let mut iter = type_ids.iter().zip(offsets.iter()); + if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize]) + { + return Err(ArrowError::InvalidArgumentError( + "Offsets must be positive and within the length of the Array".to_owned(), + )); + } + } - Ok(new_self) + // Safety: + // - Arguments validated above. + let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) }; + Ok(union_array) } /// Accesses the child array for `type_id`. /// /// # Panics /// - /// Panics if the `type_id` provided is less than zero or greater than the number of types + /// Panics if the `type_id` provided is not present in the array's DataType /// in the `Union`. pub fn child(&self, type_id: i8) -> &ArrayRef { - assert!(0 <= type_id); - assert!((type_id as usize) < self.boxed_fields.len()); - &self.boxed_fields[type_id as usize] + assert!((type_id as usize) < self.fields.len()); + let boxed = &self.fields[type_id as usize]; + boxed.as_ref().expect("invalid type id") } /// Returns the `type_id` for the array slot at `index`. /// /// # Panics /// - /// Panics if `index` is greater than the length of the array. 
+ /// Panics if `index` is greater than or equal to the length of the array pub fn type_id(&self, index: usize) -> i8 { - assert!(index < self.len()); - self.data().buffers()[0].as_slice()[self.offset() + index] as i8 + assert!(index < self.type_ids.len()); + self.type_ids[index] + } + + /// Returns the `type_ids` buffer for this array + pub fn type_ids(&self) -> &ScalarBuffer { + &self.type_ids + } + + /// Returns the `offsets` buffer if this is a dense array + pub fn offsets(&self) -> Option<&ScalarBuffer> { + self.offsets.as_ref() } /// Returns the offset into the underlying values array for the array slot at `index`. /// /// # Panics /// - /// Panics if `index` is greater than the length of the array. - pub fn value_offset(&self, index: usize) -> i32 { + /// Panics if `index` is greater than or equal to the length of the array. + pub fn value_offset(&self, index: usize) -> usize { assert!(index < self.len()); - if self.is_dense() { - self.data().buffers()[1].typed_data::()[self.offset() + index] - } else { - (self.offset() + index) as i32 + match &self.offsets { + Some(offsets) => offsets[index] as usize, + None => self.offset() + index, } } @@ -266,17 +279,17 @@ impl UnionArray { /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { let type_id = self.type_id(i); - let value_offset = self.value_offset(i) as usize; - let child_data = self.boxed_fields[type_id as usize].clone(); - child_data.slice(value_offset, 1) + let value_offset = self.value_offset(i); + let child = self.child(type_id); + child.slice(value_offset, 1) } /// Returns the names of the types in the union.
pub fn type_names(&self) -> Vec<&str> { - match self.data.data_type() { - DataType::Union(fields, _, _) => fields + match self.data_type() { + DataType::Union(fields, _) => fields .iter() - .map(|f| f.name().as_str()) + .map(|(_, f)| f.name().as_str()) .collect::>(), _ => unreachable!("Union array's data type is not a union!"), } @@ -284,26 +297,148 @@ impl UnionArray { /// Returns whether the `UnionArray` is dense (or sparse if `false`). fn is_dense(&self) -> bool { - match self.data.data_type() { - DataType::Union(_, _, mode) => mode == &UnionMode::Dense, + match self.data_type() { + DataType::Union(_, mode) => mode == &UnionMode::Dense, _ => unreachable!("Union array's data type is not a union!"), } } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + let (offsets, fields) = match self.offsets.as_ref() { + // If dense union, slice offsets + Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()), + // Otherwise need to slice sparse children + None => { + let fields = self + .fields + .iter() + .map(|x| x.as_ref().map(|x| x.slice(offset, length))) + .collect(); + (None, fields) + } + }; + + Self { + data_type: self.data_type.clone(), + type_ids: self.type_ids.slice(offset, length), + offsets, + fields, + } + } + + /// Deconstruct this array into its constituent parts + /// + /// # Example + /// + /// ``` + /// # use arrow_array::array::UnionArray; + /// # use arrow_array::types::Int32Type; + /// # use arrow_array::builder::UnionBuilder; + /// # use arrow_buffer::ScalarBuffer; + /// # fn main() -> Result<(), arrow_schema::ArrowError> { + /// let mut builder = UnionBuilder::new_dense(); + /// builder.append::("a", 1).unwrap(); + /// let union_array = builder.build()?; + /// + /// // Deconstruct into parts + /// let (union_fields, type_ids, offsets, children) = union_array.into_parts(); + /// + /// // Reconstruct from parts + /// let union_array = 
UnionArray::try_new( + /// union_fields, + /// type_ids, + /// offsets, + /// children, + /// ); + /// # Ok(()) + /// # } + /// ``` + #[allow(clippy::type_complexity)] + pub fn into_parts( + self, + ) -> ( + UnionFields, + ScalarBuffer, + Option>, + Vec, + ) { + let Self { + data_type, + type_ids, + offsets, + mut fields, + } = self; + match data_type { + DataType::Union(union_fields, _) => { + let children = union_fields + .iter() + .map(|(type_id, _)| fields[type_id as usize].take().unwrap()) + .collect(); + (union_fields, type_ids, offsets, children) + } + _ => unreachable!(), + } + } } impl From for UnionArray { fn from(data: ArrayData) -> Self { - let mut boxed_fields = vec![]; - for cd in data.child_data() { - boxed_fields.push(make_array(cd.clone())); + let (fields, mode) = match data.data_type() { + DataType::Union(fields, mode) => (fields, *mode), + d => panic!("UnionArray expected ArrayData with type Union got {d}"), + }; + let (type_ids, offsets) = match mode { + UnionMode::Sparse => ( + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), + None, + ), + UnionMode::Dense => ( + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), + Some(ScalarBuffer::new( + data.buffers()[1].clone(), + data.offset(), + data.len(), + )), + ), + }; + + let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize; + let mut boxed_fields = vec![None; max_id + 1]; + for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) { + boxed_fields[field_id as usize] = Some(make_array(cd.clone())); + } + Self { + data_type: data.data_type().clone(), + type_ids, + offsets, + fields: boxed_fields, } - Self { data, boxed_fields } } } impl From for ArrayData { fn from(array: UnionArray) -> Self { - array.data + let len = array.len(); + let f = match &array.data_type { + DataType::Union(f, _) => f, + _ => unreachable!(), + }; + let buffers = match array.offsets { + Some(o) => vec![array.type_ids.into_inner(), 
o.into_inner()], + None => vec![array.type_ids.into_inner()], + }; + + let child = f + .iter() + .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data()) + .collect(); + + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .buffers(buffers) + .child_data(child); + unsafe { builder.build_unchecked() } } } @@ -312,14 +447,38 @@ impl Array for UnionArray { self } - fn data(&self) -> &ArrayData { - &self.data + fn to_data(&self) -> ArrayData { + self.clone().into() } fn into_data(self) -> ArrayData { self.into() } + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.type_ids.len() + } + + fn is_empty(&self) -> bool { + self.type_ids.is_empty() + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + None + } + /// Union types always return non null as there is no validity buffer. /// To check validity correctly you must check the underlying vector. 
fn is_null(&self, _index: usize) -> bool { @@ -337,35 +496,66 @@ impl Array for UnionArray { fn null_count(&self) -> usize { 0 } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.type_ids.inner().capacity(); + if let Some(o) = self.offsets.as_ref() { + sum += o.inner().capacity() + } + self.fields + .iter() + .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size())) + .sum::() + + sum + } + + fn get_array_memory_size(&self) -> usize { + let mut sum = self.type_ids.inner().capacity(); + if let Some(o) = self.offsets.as_ref() { + sum += o.inner().capacity() + } + std::mem::size_of::() + + self + .fields + .iter() + .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size())) + .sum::() + + sum + } } -impl fmt::Debug for UnionArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl std::fmt::Debug for UnionArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let header = if self.is_dense() { "UnionArray(Dense)\n[" } else { "UnionArray(Sparse)\n[" }; - writeln!(f, "{}", header)?; + writeln!(f, "{header}")?; writeln!(f, "-- type id buffer:")?; - writeln!(f, "{:?}", self.data().buffers()[0])?; + writeln!(f, "{:?}", self.type_ids)?; - if self.is_dense() { + if let Some(offsets) = &self.offsets { writeln!(f, "-- offsets buffer:")?; - writeln!(f, "{:?}", self.data().buffers()[1])?; + writeln!(f, "{:?}", offsets)?; } - for (child_index, name) in self.type_names().iter().enumerate() { - let column = &self.boxed_fields[child_index]; + let fields = match self.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!(), + }; + + for (type_id, field) in fields.iter() { + let child = self.child(type_id); writeln!( f, "-- child {}: \"{}\" ({:?})", - child_index, - *name, - column.data_type() + type_id, + field.name(), + field.data_type() )?; - fmt::Debug::fmt(column, f)?; + std::fmt::Debug::fmt(child, f)?; writeln!(f)?; } writeln!(f, "]") @@ -375,13 +565,16 @@ impl fmt::Debug for UnionArray { #[cfg(test)] 
mod tests { use super::*; + use std::collections::HashSet; - use std::sync::Arc; - - use crate::array::*; - use crate::buffer::Buffer; - use crate::datatypes::{DataType, Field}; - use crate::record_batch::RecordBatch; + use crate::array::Int8Type; + use crate::builder::UnionBuilder; + use crate::cast::AsArray; + use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type}; + use crate::RecordBatch; + use crate::{Float64Array, Int32Array, Int64Array, StringArray}; + use arrow_buffer::Buffer; + use arrow_schema::{Field, Schema}; #[test] fn test_dense_i32() { @@ -396,39 +589,33 @@ mod tests { let union = builder.build().unwrap(); let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1]; - let expected_value_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1]; + let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1]; let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; // Check type ids - assert_eq!( - union.data().buffers()[0], - Buffer::from_slice_ref(&expected_type_ids) - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets - assert_eq!( - union.data().buffers()[1], - Buffer::from_slice_ref(&expected_value_offsets) - ); - for (i, id) in expected_value_offsets.iter().enumerate() { - assert_eq!(&union.value_offset(i), id); + assert_eq!(*union.offsets().unwrap(), expected_offsets); + for (i, id) in expected_offsets.iter().enumerate() { + assert_eq!(union.value_offset(i), *id as usize); } // Check data assert_eq!( - union.data().child_data()[0].buffers()[0], - Buffer::from_slice_ref(&[1_i32, 4, 6]) + *union.child(0).as_primitive::().values(), + [1_i32, 4, 6] ); assert_eq!( - union.data().child_data()[1].buffers()[0], - Buffer::from_slice_ref(&[2_i32, 7]) + *union.child(1).as_primitive::().values(), + [2_i32, 7] ); assert_eq!( - union.data().child_data()[2].buffers()[0], - Buffer::from_slice_ref(&[3_i32, 5]), + *union.child(2).as_primitive::().values(), + [3_i32, 5] ); 
assert_eq!(expected_array_values.len(), union.len()); @@ -448,7 +635,7 @@ mod tests { let mut builder = UnionBuilder::new_dense(); let expected_type_ids = vec![0_i8; 1024]; - let expected_value_offsets: Vec<_> = (0..1024).collect(); + let expected_offsets: Vec<_> = (0..1024).collect(); let expected_array_values: Vec<_> = (1..=1024).collect(); expected_array_values @@ -458,27 +645,21 @@ mod tests { let union = builder.build().unwrap(); // Check type ids - assert_eq!( - union.data().buffers()[0], - Buffer::from_slice_ref(&expected_type_ids) - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets - assert_eq!( - union.data().buffers()[1], - Buffer::from_slice_ref(&expected_value_offsets) - ); - for (i, id) in expected_value_offsets.iter().enumerate() { - assert_eq!(&union.value_offset(i), id); + assert_eq!(*union.offsets().unwrap(), expected_offsets); + for (i, id) in expected_offsets.iter().enumerate() { + assert_eq!(union.value_offset(i), *id as usize); } for (i, expected_value) in expected_array_values.iter().enumerate() { assert!(!union.is_null(i)); let slot = union.value(i); - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(expected_value, &value); @@ -626,44 +807,38 @@ mod tests { let int_array = Int32Array::from(vec![5, 6]); let float_array = Float64Array::from(vec![10.0]); - let type_ids = [1_i8, 0, 0, 2, 0, 1]; - let value_offsets = [0_i32, 0, 1, 0, 2, 1]; - - let type_id_buffer = Buffer::from_slice_ref(&type_ids); - let value_offsets_buffer = Buffer::from_slice_ref(&value_offsets); - - let children: Vec<(Field, Arc)> = vec![ - ( - Field::new("A", DataType::Utf8, false), - Arc::new(string_array), - ), - (Field::new("B", DataType::Int32, false), Arc::new(int_array)), - ( - Field::new("C", DataType::Float64, false), - Arc::new(float_array), - ), - ]; 
- let array = UnionArray::try_new( - &[0, 1, 2], - type_id_buffer, - Some(value_offsets_buffer), - children, - ) - .unwrap(); + let type_ids = [1, 0, 0, 2, 0, 1].into_iter().collect::>(); + let offsets = [0, 0, 1, 0, 2, 1] + .into_iter() + .collect::>(); + + let fields = [ + (0, Arc::new(Field::new("A", DataType::Utf8, false))), + (1, Arc::new(Field::new("B", DataType::Int32, false))), + (2, Arc::new(Field::new("C", DataType::Float64, false))), + ] + .into_iter() + .collect::(); + let children = [ + Arc::new(string_array) as Arc, + Arc::new(int_array), + Arc::new(float_array), + ] + .into_iter() + .collect(); + let array = + UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap(); // Check type ids - assert_eq!(Buffer::from_slice_ref(&type_ids), array.data().buffers()[0]); + assert_eq!(*array.type_ids(), type_ids); for (i, id) in type_ids.iter().enumerate() { assert_eq!(id, &array.type_id(i)); } // Check offsets - assert_eq!( - Buffer::from_slice_ref(&value_offsets), - array.data().buffers()[1] - ); - for (i, id) in value_offsets.iter().enumerate() { - assert_eq!(id, &array.value_offset(i)); + assert_eq!(*array.offsets().unwrap(), offsets); + for (i, id) in offsets.iter().enumerate() { + assert_eq!(*id as usize, array.value_offset(i)); } // Check values @@ -726,29 +901,26 @@ mod tests { let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); // Check data assert_eq!( - union.data().child_data()[0].buffers()[0], - Buffer::from_slice_ref(&[1_i32, 0, 0, 4, 0, 6, 0]), + *union.child(0).as_primitive::().values(), + [1_i32, 0, 0, 4, 0, 6, 0], ); 
assert_eq!( - Buffer::from_slice_ref(&[0_i32, 2_i32, 0, 0, 0, 0, 7]), - union.data().child_data()[1].buffers()[0] + *union.child(1).as_primitive::().values(), + [0_i32, 2_i32, 0, 0, 0, 0, 7] ); assert_eq!( - Buffer::from_slice_ref(&[0_i32, 0, 3_i32, 0, 5, 0, 0]), - union.data().child_data()[2].buffers()[0] + *union.child(2).as_primitive::().values(), + [0_i32, 0, 3_i32, 0, 5, 0, 0] ); assert_eq!(expected_array_values.len(), union.len()); @@ -775,16 +947,13 @@ mod tests { let expected_type_ids = vec![0_i8, 1, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer, i.e. no offsets - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); for i in 0..union.len() { let slot = union.value(i); @@ -837,16 +1006,13 @@ mod tests { let expected_type_ids = vec![0_i8, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer, i.e. 
no offsets - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); for i in 0..union.len() { let slot = union.value(i); @@ -897,7 +1063,7 @@ mod tests { match i { 0 => assert!(slot.is_null(0)), 1 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -905,7 +1071,7 @@ mod tests { } 2 => assert!(slot.is_null(0)), 3 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -926,7 +1092,7 @@ mod tests { } #[test] - fn test_union_array_validaty() { + fn test_union_array_validity() { let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); @@ -953,7 +1119,13 @@ mod tests { let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1.0).unwrap(); let err = builder.append::("a", 1).unwrap_err().to_string(); - assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err); + assert!( + err.contains( + "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32" + ), + "{}", + err + ); } #[test] @@ -990,18 +1162,18 @@ mod tests { assert_eq!(union_slice.type_id(2), 1); let slot = union_slice.value(0); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_null(0)); let slot = union_slice.value(1); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_valid(0)); assert_eq!(array.value(0), 3.0); let slot = union_slice.value(2); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_null(0)); } @@ -1020,4 +1192,225 @@ mod tests 
{ let record_batch_slice = record_batch.slice(1, 3); test_slice_union(record_batch_slice); } + + #[test] + fn test_custom_type_ids() { + let data_type = DataType::Union( + UnionFields::new( + vec![8, 4, 9], + vec![ + Field::new("strings", DataType::Utf8, false), + Field::new("integers", DataType::Int32, false), + Field::new("floats", DataType::Float64, false), + ], + ), + UnionMode::Dense, + ); + + let string_array = StringArray::from(vec!["foo", "bar", "baz"]); + let int_array = Int32Array::from(vec![5, 6, 4]); + let float_array = Float64Array::from(vec![10.0]); + + let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]); + let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]); + + let data = ArrayData::builder(data_type) + .len(7) + .buffers(vec![type_ids, value_offsets]) + .child_data(vec![ + string_array.into_data(), + int_array.into_data(), + float_array.into_data(), + ]) + .build() + .unwrap(); + + let array = UnionArray::from(data); + + let v = array.value(0); + assert_eq!(v.data_type(), &DataType::Int32); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 5); + + let v = array.value(1); + assert_eq!(v.data_type(), &DataType::Utf8); + assert_eq!(v.len(), 1); + assert_eq!(v.as_string::().value(0), "foo"); + + let v = array.value(2); + assert_eq!(v.data_type(), &DataType::Int32); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 6); + + let v = array.value(3); + assert_eq!(v.data_type(), &DataType::Utf8); + assert_eq!(v.len(), 1); + assert_eq!(v.as_string::().value(0), "bar"); + + let v = array.value(4); + assert_eq!(v.data_type(), &DataType::Float64); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 10.0); + + let v = array.value(5); + assert_eq!(v.data_type(), &DataType::Int32); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 4); + + let v = array.value(6); + assert_eq!(v.data_type(), &DataType::Utf8); + assert_eq!(v.len(), 1); + assert_eq!(v.as_string::().value(0), 
"baz"); + } + + #[test] + fn into_parts() { + let mut builder = UnionBuilder::new_dense(); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("a", 3).unwrap(); + let dense_union = builder.build().unwrap(); + + let field = [ + &Arc::new(Field::new("a", DataType::Int32, false)), + &Arc::new(Field::new("b", DataType::Int8, false)), + ]; + let (union_fields, type_ids, offsets, children) = dense_union.into_parts(); + assert_eq!( + union_fields + .iter() + .map(|(_, field)| field) + .collect::>(), + field + ); + assert_eq!(type_ids, [0, 1, 0]); + assert!(offsets.is_some()); + assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]); + + let result = UnionArray::try_new(union_fields, type_ids, offsets, children); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 3); + + let mut builder = UnionBuilder::new_sparse(); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("a", 3).unwrap(); + let sparse_union = builder.build().unwrap(); + + let (union_fields, type_ids, offsets, children) = sparse_union.into_parts(); + assert_eq!(type_ids, [0, 1, 0]); + assert!(offsets.is_none()); + + let result = UnionArray::try_new(union_fields, type_ids, offsets, children); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 3); + } + + #[test] + fn into_parts_custom_type_ids() { + let set_field_type_ids: [i8; 3] = [8, 4, 9]; + let data_type = DataType::Union( + UnionFields::new( + set_field_type_ids, + [ + Field::new("strings", DataType::Utf8, false), + Field::new("integers", DataType::Int32, false), + Field::new("floats", DataType::Float64, false), + ], + ), + UnionMode::Dense, + ); + let string_array = StringArray::from(vec!["foo", "bar", "baz"]); + let int_array = Int32Array::from(vec![5, 6, 4]); + let float_array = Float64Array::from(vec![10.0]); + let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]); + let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]); + let data 
= ArrayData::builder(data_type) + .len(7) + .buffers(vec![type_ids, value_offsets]) + .child_data(vec![ + string_array.into_data(), + int_array.into_data(), + float_array.into_data(), + ]) + .build() + .unwrap(); + let array = UnionArray::from(data); + + let (union_fields, type_ids, offsets, children) = array.into_parts(); + assert_eq!( + type_ids.iter().collect::>(), + set_field_type_ids.iter().collect::>() + ); + let result = UnionArray::try_new(union_fields, type_ids, offsets, children); + assert!(result.is_ok()); + let array = result.unwrap(); + assert_eq!(array.len(), 7); + } + + #[test] + fn test_invalid() { + let fields = UnionFields::new( + [3, 2], + [ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ], + ); + let children = vec![ + Arc::new(StringArray::from_iter_values(["a", "b"])) as _, + Arc::new(StringArray::from_iter_values(["c", "d"])) as _, + ]; + + let type_ids = vec![3, 3, 2].into(); + UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap(); + + let type_ids = vec![1, 2].into(); + let err = + UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Type Ids values must match one of the field type ids" + ); + + let type_ids = vec![7, 2].into(); + let err = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Type Ids values must match one of the field type ids" + ); + + let children = vec![ + Arc::new(StringArray::from_iter_values(["a", "b"])) as _, + Arc::new(StringArray::from_iter_values(["c"])) as _, + ]; + let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]); + let offsets = Some(vec![0, 1, 0].into()); + UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone()).unwrap(); + + let offsets = Some(vec![0, 1, 1].into()); + let err = UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, 
children.clone()) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Offsets must be positive and within the length of the Array" + ); + + let offsets = Some(vec![0, 1].into()); + let err = + UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Type Ids and Offsets lengths must match" + ); + + let err = UnionArray::try_new(fields.clone(), type_ids, None, vec![]).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Union fields length must match child arrays length" + ); + } } diff --git a/arrow/src/array/builder/boolean_builder.rs b/arrow-array/src/builder/boolean_builder.rs similarity index 62% rename from arrow/src/array/builder/boolean_builder.rs rename to arrow-array/src/builder/boolean_builder.rs index eed14a55fd91..3a2caf80cff7 100644 --- a/arrow/src/array/builder/boolean_builder.rs +++ b/arrow-array/src/builder/boolean_builder.rs @@ -15,50 +15,45 @@ // specific language governing permissions and limitations // under the License. 
+use crate::builder::{ArrayBuilder, BooleanBufferBuilder}; +use crate::{ArrayRef, BooleanArray}; +use arrow_buffer::Buffer; +use arrow_buffer::NullBufferBuilder; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; -use crate::array::ArrayBuilder; -use crate::array::ArrayData; -use crate::array::ArrayRef; -use crate::array::BooleanArray; -use crate::datatypes::DataType; - -use crate::error::ArrowError; -use crate::error::Result; - -use super::BooleanBufferBuilder; -use super::NullBufferBuilder; - -/// Array builder for fixed-width primitive types +/// Builder for [`BooleanArray`] /// /// # Example /// /// Create a `BooleanArray` from a `BooleanBuilder` /// /// ``` -/// use arrow::array::{Array, BooleanArray, BooleanBuilder}; /// -/// let mut b = BooleanBuilder::new(); -/// b.append_value(true); -/// b.append_null(); -/// b.append_value(false); -/// b.append_value(true); -/// let arr = b.finish(); +/// # use arrow_array::{Array, BooleanArray, builder::BooleanBuilder}; /// -/// assert_eq!(4, arr.len()); -/// assert_eq!(1, arr.null_count()); -/// assert_eq!(true, arr.value(0)); -/// assert!(arr.is_valid(0)); -/// assert!(!arr.is_null(0)); -/// assert!(!arr.is_valid(1)); -/// assert!(arr.is_null(1)); -/// assert_eq!(false, arr.value(2)); -/// assert!(arr.is_valid(2)); -/// assert!(!arr.is_null(2)); -/// assert_eq!(true, arr.value(3)); -/// assert!(arr.is_valid(3)); -/// assert!(!arr.is_null(3)); +/// let mut b = BooleanBuilder::new(); +/// b.append_value(true); +/// b.append_null(); +/// b.append_value(false); +/// b.append_value(true); +/// let arr = b.finish(); +/// +/// assert_eq!(4, arr.len()); +/// assert_eq!(1, arr.null_count()); +/// assert_eq!(true, arr.value(0)); +/// assert!(arr.is_valid(0)); +/// assert!(!arr.is_null(0)); +/// assert!(!arr.is_valid(1)); +/// assert!(arr.is_null(1)); +/// assert_eq!(false, arr.value(2)); +/// assert!(arr.is_valid(2)); +/// assert!(!arr.is_null(2)); +/// assert_eq!(true, 
arr.value(3)); +/// assert!(arr.is_valid(3)); +/// assert!(!arr.is_null(3)); /// ``` #[derive(Debug)] pub struct BooleanBuilder { @@ -132,7 +127,7 @@ impl BooleanBuilder { /// /// Returns an error if the slices are of different lengths #[inline] - pub fn append_values(&mut self, values: &[bool], is_valid: &[bool]) -> Result<()> { + pub fn append_values(&mut self, values: &[bool], is_valid: &[bool]) -> Result<(), ArrowError> { if values.len() != is_valid.len() { Err(ArrowError::InvalidArgumentError( "Value and validity lengths must be equal".to_string(), @@ -150,12 +145,39 @@ impl BooleanBuilder { let null_bit_buffer = self.null_buffer_builder.finish(); let builder = ArrayData::builder(DataType::Boolean) .len(len) - .add_buffer(self.values_builder.finish()) - .null_bit_buffer(null_bit_buffer); + .add_buffer(self.values_builder.finish().into_inner()) + .nulls(null_bit_buffer); + + let array_data = unsafe { builder.build_unchecked() }; + BooleanArray::from(array_data) + } + + /// Builds the [BooleanArray] without resetting the builder. + pub fn finish_cloned(&self) -> BooleanArray { + let len = self.len(); + let nulls = self.null_buffer_builder.finish_cloned(); + let value_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); + let builder = ArrayData::builder(DataType::Boolean) + .len(len) + .add_buffer(value_buffer) + .nulls(nulls); let array_data = unsafe { builder.build_unchecked() }; BooleanArray::from(array_data) } + + /// Returns the current values buffer as a slice + /// + /// Boolean values are bit-packed into bytes. To extract the i-th boolean + /// from the bytes, you can use `arrow_buffer::bit_util::get_bit()`. 
+ pub fn values_slice(&self) -> &[u8] { + self.values_builder.as_slice() + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } } impl ArrayBuilder for BooleanBuilder { @@ -179,21 +201,30 @@ impl ArrayBuilder for BooleanBuilder { self.values_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.values_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl Extend> for BooleanBuilder { + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + self.append_option(v) + } + } } #[cfg(test)] mod tests { use super::*; - use crate::{array::Array, buffer::Buffer}; + use crate::Array; #[test] fn test_boolean_array_builder() { @@ -209,21 +240,20 @@ mod tests { } let arr = builder.finish(); - assert_eq!(&buf, arr.values()); + assert_eq!(&buf, arr.values().inner()); assert_eq!(10, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); for i in 0..10 { assert!(!arr.is_null(i)); assert!(arr.is_valid(i)); - assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {}", i) + assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {i}") } } #[test] fn test_boolean_array_builder_append_slice() { - let arr1 = - BooleanArray::from(vec![Some(true), Some(false), None, None, Some(false)]); + let arr1 = BooleanArray::from(vec![Some(true), Some(false), None, None, Some(false)]); let mut builder = BooleanArray::builder(0); builder.append_slice(&[true, false]); @@ -258,6 +288,41 @@ mod tests { let array = builder.finish(); assert_eq!(0, array.null_count()); - assert!(array.data().null_buffer().is_none()); + assert!(array.nulls().is_none()); + } + + #[test] + fn 
test_boolean_array_builder_finish_cloned() { + let mut builder = BooleanArray::builder(16); + builder.append_option(Some(true)); + builder.append_value(false); + builder.append_slice(&[true, false, true]); + let mut array = builder.finish_cloned(); + assert_eq!(3, array.true_count()); + assert_eq!(2, array.false_count()); + + builder + .append_values(&[false, false, true], &[true, true, true]) + .unwrap(); + + array = builder.finish(); + assert_eq!(4, array.true_count()); + assert_eq!(4, array.false_count()); + + assert_eq!(0, array.null_count()); + assert!(array.nulls().is_none()); + } + + #[test] + fn test_extend() { + let mut builder = BooleanBuilder::new(); + builder.extend([false, false, true, false, false].into_iter().map(Some)); + builder.extend([true, true, false].into_iter().map(Some)); + let array = builder.finish(); + let values = array.iter().map(|x| x.unwrap()).collect::>(); + assert_eq!( + &values, + &[false, false, true, false, false, true, true, false] + ) } } diff --git a/arrow-array/src/builder/buffer_builder.rs b/arrow-array/src/builder/buffer_builder.rs new file mode 100644 index 000000000000..ab67669febb8 --- /dev/null +++ b/arrow-array/src/builder/buffer_builder.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +pub use arrow_buffer::BufferBuilder; +use half::f16; + +use crate::types::*; + +/// Buffer builder for signed 8-bit integer type. +pub type Int8BufferBuilder = BufferBuilder; +/// Buffer builder for signed 16-bit integer type. +pub type Int16BufferBuilder = BufferBuilder; +/// Buffer builder for signed 32-bit integer type. +pub type Int32BufferBuilder = BufferBuilder; +/// Buffer builder for signed 64-bit integer type. +pub type Int64BufferBuilder = BufferBuilder; +/// Buffer builder for unsigned 8-bit integer type. +pub type UInt8BufferBuilder = BufferBuilder; +/// Buffer builder for unsigned 16-bit integer type. +pub type UInt16BufferBuilder = BufferBuilder; +/// Buffer builder for unsigned 32-bit integer type. +pub type UInt32BufferBuilder = BufferBuilder; +/// Buffer builder for unsigned 64-bit integer type. +pub type UInt64BufferBuilder = BufferBuilder; +/// Buffer builder for 16-bit floating point type. +pub type Float16BufferBuilder = BufferBuilder; +/// Buffer builder for 32-bit floating point type. +pub type Float32BufferBuilder = BufferBuilder; +/// Buffer builder for 64-bit floating point type. +pub type Float64BufferBuilder = BufferBuilder; + +/// Buffer builder for 128-bit decimal type. +pub type Decimal128BufferBuilder = BufferBuilder<::Native>; +/// Buffer builder for 256-bit decimal type. +pub type Decimal256BufferBuilder = BufferBuilder<::Native>; + +/// Buffer builder for timestamp type of second unit. +pub type TimestampSecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for timestamp type of millisecond unit. +pub type TimestampMillisecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for timestamp type of microsecond unit. +pub type TimestampMicrosecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for timestamp type of nanosecond unit.
+pub type TimestampNanosecondBufferBuilder = + BufferBuilder<::Native>; + +/// Buffer builder for 32-bit date type. +pub type Date32BufferBuilder = BufferBuilder<::Native>; +/// Buffer builder for 64-bit date type. +pub type Date64BufferBuilder = BufferBuilder<::Native>; + +/// Buffer builder for 32-bit elapsed time since midnight of second unit. +pub type Time32SecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for 32-bit elapsed time since midnight of millisecond unit. +pub type Time32MillisecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for 64-bit elapsed time since midnight of microsecond unit. +pub type Time64MicrosecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for 64-bit elapsed time since midnight of nanosecond unit. +pub type Time64NanosecondBufferBuilder = + BufferBuilder<::Native>; + +/// Buffer builder for “calendar” interval in months. +pub type IntervalYearMonthBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for “calendar” interval in days and milliseconds. +pub type IntervalDayTimeBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for “calendar” interval in months, days, and nanoseconds. +pub type IntervalMonthDayNanoBufferBuilder = + BufferBuilder<::Native>; + +/// Buffer builder for elapsed time of second unit. +pub type DurationSecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for elapsed time of milliseconds unit. +pub type DurationMillisecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for elapsed time of microseconds unit. +pub type DurationMicrosecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for elapsed time of nanoseconds unit.
+pub type DurationNanosecondBufferBuilder = + BufferBuilder<::Native>; + +#[cfg(test)] +mod tests { + use crate::builder::{ArrayBuilder, Int32BufferBuilder, Int8Builder, UInt8BufferBuilder}; + use crate::Array; + + #[test] + fn test_builder_i32_empty() { + let mut b = Int32BufferBuilder::new(5); + assert_eq!(0, b.len()); + assert_eq!(16, b.capacity()); + let a = b.finish(); + assert_eq!(0, a.len()); + } + + #[test] + fn test_builder_i32_alloc_zero_bytes() { + let mut b = Int32BufferBuilder::new(0); + b.append(123); + let a = b.finish(); + assert_eq!(4, a.len()); + } + + #[test] + fn test_builder_i32() { + let mut b = Int32BufferBuilder::new(5); + for i in 0..5 { + b.append(i); + } + assert_eq!(16, b.capacity()); + let a = b.finish(); + assert_eq!(20, a.len()); + } + + #[test] + fn test_builder_i32_grow_buffer() { + let mut b = Int32BufferBuilder::new(2); + assert_eq!(16, b.capacity()); + for i in 0..20 { + b.append(i); + } + assert_eq!(32, b.capacity()); + let a = b.finish(); + assert_eq!(80, a.len()); + } + + #[test] + fn test_builder_finish() { + let mut b = Int32BufferBuilder::new(5); + assert_eq!(16, b.capacity()); + for i in 0..10 { + b.append(i); + } + let mut a = b.finish(); + assert_eq!(40, a.len()); + assert_eq!(0, b.len()); + assert_eq!(0, b.capacity()); + + // Try build another buffer after cleaning up. 
+ for i in 0..20 { + b.append(i) + } + assert_eq!(32, b.capacity()); + a = b.finish(); + assert_eq!(80, a.len()); + } + + #[test] + fn test_reserve() { + let mut b = UInt8BufferBuilder::new(2); + assert_eq!(64, b.capacity()); + b.reserve(64); + assert_eq!(64, b.capacity()); + b.reserve(65); + assert_eq!(128, b.capacity()); + + let mut b = Int32BufferBuilder::new(2); + assert_eq!(16, b.capacity()); + b.reserve(16); + assert_eq!(16, b.capacity()); + b.reserve(17); + assert_eq!(32, b.capacity()); + } + + #[test] + fn test_append_slice() { + let mut b = UInt8BufferBuilder::new(0); + b.append_slice(b"Hello, "); + b.append_slice(b"World!"); + let buffer = b.finish(); + assert_eq!(13, buffer.len()); + + let mut b = Int32BufferBuilder::new(0); + b.append_slice(&[32, 54]); + let buffer = b.finish(); + assert_eq!(8, buffer.len()); + } + + #[test] + fn test_append_values() { + let mut a = Int8Builder::new(); + a.append_value(1); + a.append_null(); + a.append_value(-2); + assert_eq!(a.len(), 3); + + // append values + let values = &[1, 2, 3, 4]; + let is_valid = &[true, true, false, true]; + a.append_values(values, is_valid); + + assert_eq!(a.len(), 7); + let array = a.finish(); + assert_eq!(array.value(0), 1); + assert!(array.is_null(1)); + assert_eq!(array.value(2), -2); + assert_eq!(array.value(3), 1); + assert_eq!(array.value(4), 2); + assert!(array.is_null(5)); + assert_eq!(array.value(6), 4); + } +} diff --git a/arrow/src/array/builder/fixed_size_binary_builder.rs b/arrow-array/src/builder/fixed_size_binary_builder.rs similarity index 64% rename from arrow/src/array/builder/fixed_size_binary_builder.rs rename to arrow-array/src/builder/fixed_size_binary_builder.rs index 30c25e0a62b9..65072a09f603 100644 --- a/arrow/src/array/builder/fixed_size_binary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_builder.rs @@ -15,16 +15,31 @@ // specific language governing permissions and limitations // under the License. 
-use crate::array::{ - ArrayBuilder, ArrayData, ArrayRef, FixedSizeBinaryArray, UInt8BufferBuilder, -}; -use crate::datatypes::DataType; -use crate::error::{ArrowError, Result}; +use crate::builder::{ArrayBuilder, UInt8BufferBuilder}; +use crate::{ArrayRef, FixedSizeBinaryArray}; +use arrow_buffer::Buffer; +use arrow_buffer::NullBufferBuilder; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; -use super::NullBufferBuilder; - +/// Builder for [`FixedSizeBinaryArray`] +/// ``` +/// # use arrow_array::builder::FixedSizeBinaryBuilder; +/// # use arrow_array::Array; +/// # +/// let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 5); +/// // [b"hello", null, b"arrow"] +/// builder.append_value(b"hello").unwrap(); +/// builder.append_null(); +/// builder.append_value(b"arrow").unwrap(); +/// +/// let array = builder.finish(); +/// assert_eq!(array.value(0), b"hello"); +/// assert!(array.is_null(1)); +/// assert_eq!(array.value(2), b"arrow"); +/// ``` #[derive(Debug)] pub struct FixedSizeBinaryBuilder { values_builder: UInt8BufferBuilder, @@ -43,8 +58,7 @@ impl FixedSizeBinaryBuilder { pub fn with_capacity(capacity: usize, byte_width: i32) -> Self { assert!( byte_width >= 0, - "value length ({}) of the array must >= 0", - byte_width + "value length ({byte_width}) of the array must >= 0" ); Self { values_builder: UInt8BufferBuilder::new(capacity * byte_width as usize), @@ -58,10 +72,11 @@ impl FixedSizeBinaryBuilder { /// Automatically update the null buffer to delimit the slice appended in as a /// distinct value element. 
#[inline] - pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<()> { + pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<(), ArrowError> { if self.value_length != value.as_ref().len() as i32 { Err(ArrowError::InvalidArgumentError( - "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths".to_string() + "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths" + .to_string(), )) } else { self.values_builder.append_slice(value.as_ref()); @@ -81,14 +96,30 @@ impl FixedSizeBinaryBuilder { /// Builds the [`FixedSizeBinaryArray`] and reset this builder. pub fn finish(&mut self) -> FixedSizeBinaryArray { let array_length = self.len(); - let array_data_builder = - ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) - .add_buffer(self.values_builder.finish()) - .null_bit_buffer(self.null_buffer_builder.finish()) - .len(array_length); + let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) + .add_buffer(self.values_builder.finish()) + .nulls(self.null_buffer_builder.finish()) + .len(array_length); let array_data = unsafe { array_data_builder.build_unchecked() }; FixedSizeBinaryArray::from(array_data) } + + /// Builds the [`FixedSizeBinaryArray`] without resetting the builder. 
+ pub fn finish_cloned(&self) -> FixedSizeBinaryArray { + let array_length = self.len(); + let values_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); + let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) + .add_buffer(values_buffer) + .nulls(self.null_buffer_builder.finish_cloned()) + .len(array_length); + let array_data = unsafe { array_data_builder.build_unchecked() }; + FixedSizeBinaryArray::from(array_data) + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } } impl ArrayBuilder for FixedSizeBinaryBuilder { @@ -112,24 +143,22 @@ impl ArrayBuilder for FixedSizeBinaryBuilder { self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.null_buffer_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + + /// Builds the array without resetting the builder. 
+ fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } } #[cfg(test)] mod tests { use super::*; - use crate::array::Array; - use crate::array::FixedSizeBinaryArray; - use crate::datatypes::DataType; + use crate::Array; #[test] fn test_fixed_size_binary_builder() { @@ -148,6 +177,36 @@ mod tests { assert_eq!(5, array.value_length()); } + #[test] + fn test_fixed_size_binary_builder_finish_cloned() { + let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 5); + + // [b"hello", null, "arrow"] + builder.append_value(b"hello").unwrap(); + builder.append_null(); + builder.append_value(b"arrow").unwrap(); + let mut array: FixedSizeBinaryArray = builder.finish_cloned(); + + assert_eq!(&DataType::FixedSizeBinary(5), array.data_type()); + assert_eq!(3, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(10, array.value_offset(2)); + assert_eq!(5, array.value_length()); + + // [b"finis", null, "clone"] + builder.append_value(b"finis").unwrap(); + builder.append_null(); + builder.append_value(b"clone").unwrap(); + + array = builder.finish(); + + assert_eq!(&DataType::FixedSizeBinary(5), array.data_type()); + assert_eq!(6, array.len()); + assert_eq!(2, array.null_count()); + assert_eq!(25, array.value_offset(5)); + assert_eq!(5, array.value_length()); + } + #[test] fn test_fixed_size_binary_builder_with_zero_value_length() { let mut builder = FixedSizeBinaryBuilder::new(0); diff --git a/arrow-array/src/builder/fixed_size_list_builder.rs b/arrow-array/src/builder/fixed_size_list_builder.rs new file mode 100644 index 000000000000..5dff67650687 --- /dev/null +++ b/arrow-array/src/builder/fixed_size_list_builder.rs @@ -0,0 +1,492 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::ArrayBuilder; +use crate::{ArrayRef, FixedSizeListArray}; +use arrow_buffer::NullBufferBuilder; +use arrow_schema::{Field, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`FixedSizeListArray`] +/// ``` +/// use arrow_array::{builder::{Int32Builder, FixedSizeListBuilder}, Array, Int32Array}; +/// let values_builder = Int32Builder::new(); +/// let mut builder = FixedSizeListBuilder::new(values_builder, 3); +/// +/// // [[0, 1, 2], null, [3, null, 5], [6, 7, null]] +/// builder.values().append_value(0); +/// builder.values().append_value(1); +/// builder.values().append_value(2); +/// builder.append(true); +/// builder.values().append_null(); +/// builder.values().append_null(); +/// builder.values().append_null(); +/// builder.append(false); +/// builder.values().append_value(3); +/// builder.values().append_null(); +/// builder.values().append_value(5); +/// builder.append(true); +/// builder.values().append_value(6); +/// builder.values().append_value(7); +/// builder.values().append_null(); +/// builder.append(true); +/// let list_array = builder.finish(); +/// assert_eq!( +/// *list_array.value(0), +/// Int32Array::from(vec![Some(0), Some(1), Some(2)]) +/// ); +/// assert!(list_array.is_null(1)); +/// assert_eq!( +/// *list_array.value(2), +/// Int32Array::from(vec![Some(3), None, Some(5)]) +/// ); +/// assert_eq!( +/// 
*list_array.value(3), +/// Int32Array::from(vec![Some(6), Some(7), None]) +/// ) +/// ``` +/// +#[derive(Debug)] +pub struct FixedSizeListBuilder { + null_buffer_builder: NullBufferBuilder, + values_builder: T, + list_len: i32, + field: Option, +} + +impl FixedSizeListBuilder { + /// Creates a new [`FixedSizeListBuilder`] from a given values array builder + /// `value_length` is the number of values within each array + pub fn new(values_builder: T, value_length: i32) -> Self { + let capacity = values_builder + .len() + .checked_div(value_length as _) + .unwrap_or_default(); + + Self::with_capacity(values_builder, value_length, capacity) + } + + /// Creates a new [`FixedSizeListBuilder`] from a given values array builder + /// `value_length` is the number of values within each array + /// `capacity` is the number of items to pre-allocate space for in this builder + pub fn with_capacity(values_builder: T, value_length: i32, capacity: usize) -> Self { + Self { + null_buffer_builder: NullBufferBuilder::new(capacity), + values_builder, + list_len: value_length, + field: None, + } + } + + /// Override the field passed to [`FixedSizeListArray::new`] + /// + /// By default, a nullable field is created with the name `item` + /// + /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the + /// field's data type does not match that of `T` + pub fn with_field(self, field: impl Into) -> Self { + Self { + field: Some(field.into()), + ..self + } + } +} + +impl ArrayBuilder for FixedSizeListBuilder +where + T: 'static, +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. 
+ fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl FixedSizeListBuilder +where + T: 'static, +{ + /// Returns the child array builder as a mutable reference. + /// + /// This mutable reference can be used to append values into the child array builder, + /// but you must call [`append`](#method.append) to delimit each distinct list value. + pub fn values(&mut self) -> &mut T { + &mut self.values_builder + } + + /// Returns the length of the list + pub fn value_length(&self) -> i32 { + self.list_len + } + + /// Finish the current fixed-length list array slot + #[inline] + pub fn append(&mut self, is_valid: bool) { + self.null_buffer_builder.append(is_valid); + } + + /// Builds the [`FixedSizeListArray`] and reset this builder. + pub fn finish(&mut self) -> FixedSizeListArray { + let len = self.len(); + let values = self.values_builder.finish(); + let nulls = self.null_buffer_builder.finish(); + + assert_eq!( + values.len(), len * self.list_len as usize, + "Length of the child array ({}) must be the multiple of the value length ({}) and the array length ({}).", + values.len(), + self.list_len, + len, + ); + + let field = self + .field + .clone() + .unwrap_or_else(|| Arc::new(Field::new("item", values.data_type().clone(), true))); + + FixedSizeListArray::new(field, self.list_len, values, nulls) + } + + /// Builds the [`FixedSizeListArray`] without resetting the builder.
+ pub fn finish_cloned(&self) -> FixedSizeListArray { + let len = self.len(); + let values = self.values_builder.finish_cloned(); + let nulls = self.null_buffer_builder.finish_cloned(); + + assert_eq!( + values.len(), len * self.list_len as usize, + "Length of the child array ({}) must be the multiple of the value length ({}) and the array length ({}).", + values.len(), + self.list_len, + len, + ); + + let field = self + .field + .clone() + .unwrap_or_else(|| Arc::new(Field::new("item", values.data_type().clone(), true))); + + FixedSizeListArray::new(field, self.list_len, values, nulls) + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::DataType; + + use crate::builder::Int32Builder; + use crate::Array; + use crate::Int32Array; + + fn make_list_builder( + include_null_element: bool, + include_null_in_values: bool, + ) -> FixedSizeListBuilder> { + let values_builder = Int32Builder::new(); + let mut builder = FixedSizeListBuilder::new(values_builder, 3); + + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + + builder.values().append_value(2); + builder.values().append_value(3); + builder.values().append_value(4); + builder.append(true); + + if include_null_element { + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); + } else { + builder.values().append_value(2); + builder.values().append_value(3); + builder.values().append_value(4); + builder.append(true); + } + + if include_null_in_values { + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + builder.append(true); + } else { + builder.values().append_value(3); + builder.values().append_value(4); + builder.values().append_value(5); + 
builder.append(true); + } + + builder + } + + #[test] + fn test_fixed_size_list_array_builder() { + let mut builder = make_list_builder(true, true); + + let list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + fn test_fixed_size_list_array_builder_with_field() { + let builder = make_list_builder(false, false); + let mut builder = builder.with_field(Field::new("list_element", DataType::Int32, false)); + let list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + fn test_fixed_size_list_array_builder_with_field_and_null() { + let builder = make_list_builder(true, false); + let mut builder = builder.with_field(Field::new("list_element", DataType::Int32, false)); + let list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + #[should_panic(expected = "Found unmasked nulls for non-nullable FixedSizeListArray")] + fn test_fixed_size_list_array_builder_with_field_null_panic() { + let builder = make_list_builder(true, true); + let mut builder = builder.with_field(Field::new("list_item", DataType::Int32, false)); + + builder.finish(); + } + + #[test] + #[should_panic(expected = "FixedSizeListArray expected data type Int64 got Int32")] + fn test_fixed_size_list_array_builder_with_field_type_panic() { + let values_builder = Int32Builder::new(); + let builder = FixedSizeListBuilder::new(values_builder, 3); + let mut builder = 
builder.with_field(Field::new("list_item", DataType::Int64, true)); + + // [[0, 1, 2], null, [3, null, 5], [6, 7, null]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); + builder.values().append_value(3); + builder.values().append_value(4); + builder.values().append_value(5); + builder.append(true); + + builder.finish(); + } + + #[test] + fn test_fixed_size_list_array_builder_cloned_with_field() { + let builder = make_list_builder(true, true); + let builder = builder.with_field(Field::new("list_element", DataType::Int32, true)); + + let list_array = builder.finish_cloned(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + #[should_panic(expected = "Found unmasked nulls for non-nullable FixedSizeListArray")] + fn test_fixed_size_list_array_builder_cloned_with_field_null_panic() { + let builder = make_list_builder(true, true); + let builder = builder.with_field(Field::new("list_item", DataType::Int32, false)); + + builder.finish_cloned(); + } + + #[test] + fn test_fixed_size_list_array_builder_cloned_with_field_and_null() { + let builder = make_list_builder(true, false); + let mut builder = builder.with_field(Field::new("list_element", DataType::Int32, false)); + let list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + #[should_panic(expected = "FixedSizeListArray expected data type Int64 got Int32")] + fn 
test_fixed_size_list_array_builder_cloned_with_field_type_panic() { + let builder = make_list_builder(false, false); + let builder = builder.with_field(Field::new("list_item", DataType::Int64, true)); + + builder.finish_cloned(); + } + + #[test] + fn test_fixed_size_list_array_builder_finish_cloned() { + let mut builder = make_list_builder(true, true); + + let mut list_array = builder.finish_cloned(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(3, list_array.value_length()); + + builder.values().append_value(6); + builder.values().append_value(7); + builder.values().append_null(); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); + list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(6, list_array.len()); + assert_eq!(2, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + fn test_fixed_size_list_array_builder_with_field_empty() { + let values_builder = Int32Array::builder(0); + let mut builder = FixedSizeListBuilder::new(values_builder, 3).with_field(Field::new( + "list_item", + DataType::Int32, + false, + )); + assert!(builder.is_empty()); + let arr = builder.finish(); + assert_eq!(0, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + fn test_fixed_size_list_array_builder_cloned_with_field_empty() { + let values_builder = Int32Array::builder(0); + let builder = FixedSizeListBuilder::new(values_builder, 3).with_field(Field::new( + "list_item", + DataType::Int32, + false, + )); + assert!(builder.is_empty()); + let arr = builder.finish_cloned(); + assert_eq!(0, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + fn test_fixed_size_list_array_builder_empty() { + let values_builder = Int32Array::builder(5); + let mut 
builder = FixedSizeListBuilder::new(values_builder, 3); + assert!(builder.is_empty()); + let arr = builder.finish(); + assert_eq!(0, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + fn test_fixed_size_list_array_builder_finish() { + let values_builder = Int32Array::builder(5); + let mut builder = FixedSizeListBuilder::new(values_builder, 3); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + + let mut arr = builder.finish(); + assert_eq!(2, arr.len()); + assert_eq!(0, builder.len()); + + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); + arr = builder.finish(); + assert_eq!(1, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + #[should_panic( + expected = "Length of the child array (10) must be the multiple of the value length (3) and the array length (3)." + )] + fn test_fixed_size_list_array_builder_fail() { + let values_builder = Int32Array::builder(5); + let mut builder = FixedSizeListBuilder::new(values_builder, 3); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + builder.values().append_slice(&[7, 8, 9, 10]); + builder.append(true); + + builder.finish(); + } +} diff --git a/arrow-array/src/builder/generic_byte_run_builder.rs b/arrow-array/src/builder/generic_byte_run_builder.rs new file mode 100644 index 000000000000..3cde76c4a039 --- /dev/null +++ b/arrow-array/src/builder/generic_byte_run_builder.rs @@ -0,0 +1,514 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::bytes::ByteArrayNativeType; +use std::{any::Any, sync::Arc}; + +use crate::{ + types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, Utf8Type}, + ArrayRef, ArrowPrimitiveType, RunArray, +}; + +use super::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; + +use arrow_buffer::ArrowNativeType; + +/// Builder for [`RunArray`] of [`GenericByteArray`](crate::array::GenericByteArray) +/// +/// # Example: +/// +/// ``` +/// +/// # use arrow_array::builder::GenericByteRunBuilder; +/// # use arrow_array::{GenericByteArray, BinaryArray}; +/// # use arrow_array::types::{BinaryType, Int16Type}; +/// # use arrow_array::{Array, Int16Array}; +/// # use arrow_array::cast::AsArray; +/// +/// let mut builder = +/// GenericByteRunBuilder::::new(); +/// builder.extend([Some(b"abc"), Some(b"abc"), None, Some(b"def")].into_iter()); +/// builder.append_value(b"def"); +/// builder.append_null(); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[2, 3, 5, 6]); +/// +/// let av = array.values(); +/// +/// assert!(!av.is_null(0)); +/// assert!(av.is_null(1)); +/// assert!(!av.is_null(2)); +/// assert!(av.is_null(3)); +/// +/// // Values are polymorphic and so require a downcast. 
+/// let ava: &BinaryArray = av.as_binary(); +/// +/// assert_eq!(ava.value(0), b"abc"); +/// assert_eq!(ava.value(2), b"def"); +/// ``` +#[derive(Debug)] +pub struct GenericByteRunBuilder +where + R: ArrowPrimitiveType, + V: ByteArrayType, +{ + run_ends_builder: PrimitiveBuilder, + values_builder: GenericByteBuilder, + current_value: Vec, + has_current_value: bool, + current_run_end_index: usize, + prev_run_end_index: usize, +} + +impl Default for GenericByteRunBuilder +where + R: ArrowPrimitiveType, + V: ByteArrayType, +{ + fn default() -> Self { + Self::new() + } +} + +impl GenericByteRunBuilder +where + R: ArrowPrimitiveType, + V: ByteArrayType, +{ + /// Creates a new `GenericByteRunBuilder` + pub fn new() -> Self { + Self { + run_ends_builder: PrimitiveBuilder::new(), + values_builder: GenericByteBuilder::::new(), + current_value: Vec::new(), + has_current_value: false, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } + + /// Creates a new `GenericByteRunBuilder` with the provided capacity + /// + /// `capacity`: the expected number of run-end encoded values. + /// `data_capacity`: the expected number of bytes of run end encoded values + pub fn with_capacity(capacity: usize, data_capacity: usize) -> Self { + Self { + run_ends_builder: PrimitiveBuilder::with_capacity(capacity), + values_builder: GenericByteBuilder::::with_capacity(capacity, data_capacity), + current_value: Vec::new(), + has_current_value: false, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } +} + +impl ArrayBuilder for GenericByteRunBuilder +where + R: RunEndIndexType, + V: ByteArrayType, +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. 
+ fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the length of logical array encoded by + /// the eventual runs array. + fn len(&self) -> usize { + self.current_run_end_index + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl GenericByteRunBuilder +where + R: RunEndIndexType, + V: ByteArrayType, +{ + /// Appends optional value to the logical array encoded by the RunArray. + pub fn append_option(&mut self, input_value: Option>) { + match input_value { + Some(value) => self.append_value(value), + None => self.append_null(), + } + } + + /// Appends value to the logical array encoded by the RunArray. + pub fn append_value(&mut self, input_value: impl AsRef) { + let value: &[u8] = input_value.as_ref().as_ref(); + if !self.has_current_value { + self.append_run_end(); + self.current_value.extend_from_slice(value); + self.has_current_value = true; + } else if self.current_value.as_slice() != value { + self.append_run_end(); + self.current_value.clear(); + self.current_value.extend_from_slice(value); + } + self.current_run_end_index += 1; + } + + /// Appends null to the logical array encoded by the RunArray. + pub fn append_null(&mut self) { + if self.has_current_value { + self.append_run_end(); + self.current_value.clear(); + self.has_current_value = false; + } + self.current_run_end_index += 1; + } + + /// Creates the RunArray and resets the builder. + /// Panics if RunArray cannot be built. + pub fn finish(&mut self) -> RunArray { + // write the last run end to the array. + self.append_run_end(); + + // reset the run end index to zero. 
+ self.current_value.clear(); + self.has_current_value = false; + self.current_run_end_index = 0; + self.prev_run_end_index = 0; + + // build the run encoded array by adding run_ends and values array as its children. + let run_ends_array = self.run_ends_builder.finish(); + let values_array = self.values_builder.finish(); + RunArray::::try_new(&run_ends_array, &values_array).unwrap() + } + + /// Creates the RunArray and without resetting the builder. + /// Panics if RunArray cannot be built. + pub fn finish_cloned(&self) -> RunArray { + let mut run_ends_array = self.run_ends_builder.finish_cloned(); + let mut values_array = self.values_builder.finish_cloned(); + + // Add current run if one exists + if self.prev_run_end_index != self.current_run_end_index { + let mut run_end_builder = run_ends_array.into_builder().unwrap(); + let mut values_builder = values_array.into_builder().unwrap(); + self.append_run_end_with_builders(&mut run_end_builder, &mut values_builder); + run_ends_array = run_end_builder.finish(); + values_array = values_builder.finish(); + } + + RunArray::::try_new(&run_ends_array, &values_array).unwrap() + } + + // Appends the current run to the array. + fn append_run_end(&mut self) { + // empty array or the function called without appending any value. + if self.prev_run_end_index == self.current_run_end_index { + return; + } + let run_end_index = self.run_end_index_as_native(); + self.run_ends_builder.append_value(run_end_index); + if self.has_current_value { + let slice = self.current_value.as_slice(); + let native = unsafe { + // Safety: + // As self.current_value is created from V::Native. The value V::Native can be + // built back from the bytes without validations + V::Native::from_bytes_unchecked(slice) + }; + self.values_builder.append_value(native); + } else { + self.values_builder.append_null(); + } + self.prev_run_end_index = self.current_run_end_index; + } + + // Similar to `append_run_end` but on custom builders. 
+ // Used in `finish_cloned` which is not suppose to mutate `self`. + fn append_run_end_with_builders( + &self, + run_ends_builder: &mut PrimitiveBuilder, + values_builder: &mut GenericByteBuilder, + ) { + let run_end_index = self.run_end_index_as_native(); + run_ends_builder.append_value(run_end_index); + if self.has_current_value { + let slice = self.current_value.as_slice(); + let native = unsafe { + // Safety: + // As self.current_value is created from V::Native. The value V::Native can be + // built back from the bytes without validations + V::Native::from_bytes_unchecked(slice) + }; + values_builder.append_value(native); + } else { + values_builder.append_null(); + } + } + + fn run_end_index_as_native(&self) -> R::Native { + R::Native::from_usize(self.current_run_end_index).unwrap_or_else(|| { + panic!( + "Cannot convert the value {} from `usize` to native form of arrow datatype {}", + self.current_run_end_index, + R::DATA_TYPE + ) + }) + } +} + +impl Extend> for GenericByteRunBuilder +where + R: RunEndIndexType, + V: ByteArrayType, + S: AsRef, +{ + fn extend>>(&mut self, iter: T) { + for elem in iter { + self.append_option(elem); + } + } +} + +/// Builder for [`RunArray`] of [`StringArray`](crate::array::StringArray) +/// +/// ``` +/// // Create a run-end encoded array with run-end indexes data type as `i16`. +/// // The encoded values are Strings. +/// +/// # use arrow_array::builder::StringRunBuilder; +/// # use arrow_array::{Int16Array, StringArray}; +/// # use arrow_array::types::Int16Type; +/// # use arrow_array::cast::AsArray; +/// # +/// let mut builder = StringRunBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append_value("abc"); +/// builder.append_null(); +/// builder.extend([Some("def"), Some("def"), Some("abc")]); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[1, 2, 4, 5]); +/// +/// // Values are polymorphic and so require a downcast. 
+/// let av = array.values(); +/// let ava: &StringArray = av.as_string::(); +/// +/// assert_eq!(ava.value(0), "abc"); +/// assert!(av.is_null(1)); +/// assert_eq!(ava.value(2), "def"); +/// assert_eq!(ava.value(3), "abc"); +/// +/// ``` +pub type StringRunBuilder = GenericByteRunBuilder; + +/// Builder for [`RunArray`] of [`LargeStringArray`](crate::array::LargeStringArray) +pub type LargeStringRunBuilder = GenericByteRunBuilder; + +/// Builder for [`RunArray`] of [`BinaryArray`](crate::array::BinaryArray) +/// +/// ``` +/// // Create a run-end encoded array with run-end indexes data type as `i16`. +/// // The encoded data is binary values. +/// +/// # use arrow_array::builder::BinaryRunBuilder; +/// # use arrow_array::{BinaryArray, Int16Array}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::types::Int16Type; +/// +/// let mut builder = BinaryRunBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append_value(b"abc"); +/// builder.append_null(); +/// builder.extend([Some(b"def"), Some(b"def"), Some(b"abc")]); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[1, 2, 4, 5]); +/// +/// // Values are polymorphic and so require a downcast. 
+/// let av = array.values(); +/// let ava: &BinaryArray = av.as_binary(); +/// +/// assert_eq!(ava.value(0), b"abc"); +/// assert!(av.is_null(1)); +/// assert_eq!(ava.value(2), b"def"); +/// assert_eq!(ava.value(3), b"abc"); +/// +/// ``` +pub type BinaryRunBuilder = GenericByteRunBuilder; + +/// Builder for [`RunArray`] of [`LargeBinaryArray`](crate::array::LargeBinaryArray) +pub type LargeBinaryRunBuilder = GenericByteRunBuilder; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::Array; + use crate::cast::AsArray; + use crate::types::{Int16Type, Int32Type}; + use crate::GenericByteArray; + use crate::Int16RunArray; + + fn test_bytes_run_builder(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteRunBuilder::::new(); + builder.append_value(values[0]); + builder.append_value(values[0]); + builder.append_value(values[0]); + builder.append_null(); + builder.append_null(); + builder.append_value(values[1]); + builder.append_value(values[1]); + builder.append_value(values[2]); + builder.append_value(values[2]); + builder.append_value(values[2]); + builder.append_value(values[2]); + let array = builder.finish(); + + assert_eq!(array.len(), 11); + assert_eq!(array.null_count(), 0); + + assert_eq!(array.run_ends().values(), &[3, 5, 7, 11]); + + // Values are polymorphic and so require a downcast. 
+ let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(*ava.value(0), *values[0]); + assert!(ava.is_null(1)); + assert_eq!(*ava.value(2), *values[1]); + assert_eq!(*ava.value(3), *values[2]); + } + + #[test] + fn test_string_run_builder() { + test_bytes_run_builder::(vec!["abc", "def", "ghi"]); + } + + #[test] + fn test_string_run_builder_with_empty_strings() { + test_bytes_run_builder::(vec!["abc", "", "ghi"]); + } + + #[test] + fn test_binary_run_builder() { + test_bytes_run_builder::(vec![b"abc", b"def", b"ghi"]); + } + + fn test_bytes_run_builder_finish_cloned(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteRunBuilder::::new(); + + builder.append_value(values[0]); + builder.append_null(); + builder.append_value(values[1]); + builder.append_value(values[1]); + builder.append_value(values[0]); + let mut array: Int16RunArray = builder.finish_cloned(); + + assert_eq!(array.len(), 5); + assert_eq!(array.null_count(), 0); + + assert_eq!(array.run_ends().values(), &[1, 2, 4, 5]); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava.value(0), values[0]); + assert!(ava.is_null(1)); + assert_eq!(ava.value(2), values[1]); + assert_eq!(ava.value(3), values[0]); + + // Append last value before `finish_cloned` (`value[0]`) again and ensure it has only + // one entry in final output. + builder.append_value(values[0]); + builder.append_value(values[0]); + builder.append_value(values[1]); + array = builder.finish(); + + assert_eq!(array.len(), 8); + assert_eq!(array.null_count(), 0); + + assert_eq!(array.run_ends().values(), &[1, 2, 4, 7, 8]); + + // Values are polymorphic and so require a downcast. 
+ let av2 = array.values(); + let ava2: &GenericByteArray = + av2.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava2.value(0), values[0]); + assert!(ava2.is_null(1)); + assert_eq!(ava2.value(2), values[1]); + // The value appended before and after `finish_cloned` has only one entry. + assert_eq!(ava2.value(3), values[0]); + assert_eq!(ava2.value(4), values[1]); + } + + #[test] + fn test_string_run_builder_finish_cloned() { + test_bytes_run_builder_finish_cloned::(vec!["abc", "def", "ghi"]); + } + + #[test] + fn test_binary_run_builder_finish_cloned() { + test_bytes_run_builder_finish_cloned::(vec![b"abc", b"def", b"ghi"]); + } + + #[test] + fn test_extend() { + let mut builder = StringRunBuilder::::new(); + builder.extend(["a", "a", "a", "", "", "b", "b"].into_iter().map(Some)); + builder.extend(["b", "cupcakes", "cupcakes"].into_iter().map(Some)); + let array = builder.finish(); + + assert_eq!(array.len(), 10); + assert_eq!(array.run_ends().values(), &[3, 5, 8, 10]); + + let str_array = array.values().as_string::(); + assert_eq!(str_array.value(0), "a"); + assert_eq!(str_array.value(1), ""); + assert_eq!(str_array.value(2), "b"); + assert_eq!(str_array.value(3), "cupcakes"); + } +} diff --git a/arrow-array/src/builder/generic_bytes_builder.rs b/arrow-array/src/builder/generic_bytes_builder.rs new file mode 100644 index 000000000000..a465f3e4d60e --- /dev/null +++ b/arrow-array/src/builder/generic_bytes_builder.rs @@ -0,0 +1,543 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{ArrayBuilder, BufferBuilder, UInt8BufferBuilder}; +use crate::types::{ByteArrayType, GenericBinaryType, GenericStringType}; +use crate::{ArrayRef, GenericByteArray, OffsetSizeTrait}; +use arrow_buffer::NullBufferBuilder; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_data::ArrayDataBuilder; +use std::any::Any; +use std::fmt::Write; +use std::sync::Arc; + +/// Builder for [`GenericByteArray`] +/// +/// For building strings, see docs on [`GenericStringBuilder`]. +/// For building binary, see docs on [`GenericBinaryBuilder`]. +pub struct GenericByteBuilder { + value_builder: UInt8BufferBuilder, + offsets_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, +} + +impl GenericByteBuilder { + /// Creates a new [`GenericByteBuilder`]. + pub fn new() -> Self { + Self::with_capacity(1024, 1024) + } + + /// Creates a new [`GenericByteBuilder`]. + /// + /// - `item_capacity` is the number of items to pre-allocate. + /// The size of the preallocated buffer of offsets is the number of items plus one. + /// - `data_capacity` is the total number of bytes of data to pre-allocate + /// (for all items, not per item). 
+ pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_builder = BufferBuilder::::new(item_capacity + 1); + offsets_builder.append(T::Offset::from_usize(0).unwrap()); + Self { + value_builder: UInt8BufferBuilder::new(data_capacity), + offsets_builder, + null_buffer_builder: NullBufferBuilder::new(item_capacity), + } + } + + /// Creates a new [`GenericByteBuilder`] from buffers. + /// + /// # Safety + /// + /// This doesn't verify buffer contents as it assumes the buffers are from + /// existing and valid [`GenericByteArray`]. + pub unsafe fn new_from_buffer( + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, + null_buffer: Option, + ) -> Self { + let offsets_builder = BufferBuilder::::new_from_buffer(offsets_buffer); + let value_builder = BufferBuilder::::new_from_buffer(value_buffer); + + let null_buffer_builder = null_buffer + .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, offsets_builder.len() - 1)) + .unwrap_or_else(|| NullBufferBuilder::new_with_len(offsets_builder.len() - 1)); + + Self { + offsets_builder, + value_builder, + null_buffer_builder, + } + } + + #[inline] + fn next_offset(&self) -> T::Offset { + T::Offset::from_usize(self.value_builder.len()).expect("byte array offset overflow") + } + + /// Appends a value into the builder. + /// + /// See the [GenericStringBuilder] documentation for examples of + /// incrementally building string values with multiple `write!` calls. + /// + /// # Panics + /// + /// Panics if the resulting length of [`Self::values_slice`] would exceed + /// `T::Offset::MAX` bytes. 
+ /// + /// For example, this can happen with [`StringArray`] or [`BinaryArray`] + /// where the total length of all values exceeds 2GB + /// + /// [`StringArray`]: crate::StringArray + /// [`BinaryArray`]: crate::BinaryArray + #[inline] + pub fn append_value(&mut self, value: impl AsRef) { + self.value_builder.append_slice(value.as_ref().as_ref()); + self.null_buffer_builder.append(true); + self.offsets_builder.append(self.next_offset()); + } + + /// Append an `Option` value into the builder. + /// + /// - A `None` value will append a null value. + /// - A `Some` value will append the value. + /// + /// See [`Self::append_value`] for more panic information. + #[inline] + pub fn append_option(&mut self, value: Option>) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + }; + } + + /// Append a null value into the builder. + #[inline] + pub fn append_null(&mut self) { + self.null_buffer_builder.append(false); + self.offsets_builder.append(self.next_offset()); + } + + /// Builds the [`GenericByteArray`] and reset this builder. + pub fn finish(&mut self) -> GenericByteArray { + let array_type = T::DATA_TYPE; + let array_builder = ArrayDataBuilder::new(array_type) + .len(self.len()) + .add_buffer(self.offsets_builder.finish()) + .add_buffer(self.value_builder.finish()) + .nulls(self.null_buffer_builder.finish()); + + self.offsets_builder.append(self.next_offset()); + let array_data = unsafe { array_builder.build_unchecked() }; + GenericByteArray::from(array_data) + } + + /// Builds the [`GenericByteArray`] without resetting the builder. 
+ pub fn finish_cloned(&self) -> GenericByteArray { + let array_type = T::DATA_TYPE; + let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + let value_buffer = Buffer::from_slice_ref(self.value_builder.as_slice()); + let array_builder = ArrayDataBuilder::new(array_type) + .len(self.len()) + .add_buffer(offset_buffer) + .add_buffer(value_buffer) + .nulls(self.null_buffer_builder.finish_cloned()); + + let array_data = unsafe { array_builder.build_unchecked() }; + GenericByteArray::from(array_data) + } + + /// Returns the current values buffer as a slice + pub fn values_slice(&self) -> &[u8] { + self.value_builder.as_slice() + } + + /// Returns the current offsets buffer as a slice + pub fn offsets_slice(&self) -> &[T::Offset] { + self.offsets_builder.as_slice() + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } + + /// Returns the current null buffer as a mutable slice + pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> { + self.null_buffer_builder.as_slice_mut() + } +} + +impl std::fmt::Debug for GenericByteBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}{}Builder", T::Offset::PREFIX, T::PREFIX)?; + f.debug_struct("") + .field("value_builder", &self.value_builder) + .field("offsets_builder", &self.offsets_builder) + .field("null_buffer_builder", &self.null_buffer_builder) + .finish() + } +} + +impl Default for GenericByteBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ArrayBuilder for GenericByteBuilder { + /// Returns the number of binary slots in the builder + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. 
+ fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } + + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } +} + +impl> Extend> for GenericByteBuilder { + #[inline] + fn extend>>(&mut self, iter: I) { + for v in iter { + self.append_option(v) + } + } +} + +/// Array builder for [`GenericStringArray`][crate::GenericStringArray] +/// +/// Values can be appended using [`GenericByteBuilder::append_value`], and nulls with +/// [`GenericByteBuilder::append_null`]. +/// +/// This builder also implements [`std::fmt::Write`] with any written data +/// included in the next appended value. This allows using [`std::fmt::Display`] +/// with standard Rust idioms like `write!` and `writeln!` to write data +/// directly to the builder without intermediate allocations. +/// +/// # Example writing strings with `append_value` +/// ``` +/// # use arrow_array::builder::GenericStringBuilder; +/// let mut builder = GenericStringBuilder::::new(); +/// +/// // Write one string value +/// builder.append_value("foobarbaz"); +/// +/// // Write a second string +/// builder.append_value("v2"); +/// +/// let array = builder.finish(); +/// assert_eq!(array.value(0), "foobarbaz"); +/// assert_eq!(array.value(1), "v2"); +/// ``` +/// +/// # Example incrementally writing strings with `std::fmt::Write` +/// +/// ``` +/// # use std::fmt::Write; +/// # use arrow_array::builder::GenericStringBuilder; +/// let mut builder = GenericStringBuilder::::new(); +/// +/// // Write data in multiple `write!` calls +/// write!(builder, "foo").unwrap(); +/// write!(builder, "bar").unwrap(); +/// // The next call to append_value finishes the current string +/// // including all previously written strings. 
+/// builder.append_value("baz"); +/// +/// // Write second value with a single write call +/// write!(builder, "v2").unwrap(); +/// // finish the value by calling append_value with an empty string +/// builder.append_value(""); +/// +/// let array = builder.finish(); +/// assert_eq!(array.value(0), "foobarbaz"); +/// assert_eq!(array.value(1), "v2"); +/// ``` +/// +pub type GenericStringBuilder = GenericByteBuilder>; + +impl Write for GenericStringBuilder { + fn write_str(&mut self, s: &str) -> std::fmt::Result { + self.value_builder.append_slice(s.as_bytes()); + Ok(()) + } +} + +/// Array builder for [`GenericBinaryArray`][crate::GenericBinaryArray] +/// +/// Values can be appended using [`GenericByteBuilder::append_value`], and nulls with +/// [`GenericByteBuilder::append_null`]. +/// +/// # Example +/// ``` +/// # use arrow_array::builder::GenericBinaryBuilder; +/// let mut builder = GenericBinaryBuilder::::new(); +/// +/// // Write data +/// builder.append_value("foo"); +/// +/// // Write second value +/// builder.append_value(&[0,1,2]); +/// +/// let array = builder.finish(); +/// // binary values +/// assert_eq!(array.value(0), b"foo"); +/// assert_eq!(array.value(1), b"\x00\x01\x02"); +/// ``` +pub type GenericBinaryBuilder = GenericByteBuilder>; + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::Array; + use crate::GenericStringArray; + + fn _test_generic_binary_builder() { + let mut builder = GenericBinaryBuilder::::new(); + + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"rust"); + + let array = builder.finish(); + + assert_eq!(4, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(b"hello", array.value(0)); + assert_eq!([] as [u8; 0], array.value(1)); + assert!(array.is_null(2)); + assert_eq!(b"rust", array.value(3)); + assert_eq!(O::from_usize(5).unwrap(), array.value_offsets()[2]); + assert_eq!(O::from_usize(4).unwrap(), array.value_length(3)); + } + + #[test] + 
fn test_binary_builder() { + _test_generic_binary_builder::() + } + + #[test] + fn test_large_binary_builder() { + _test_generic_binary_builder::() + } + + fn _test_generic_binary_builder_all_nulls() { + let mut builder = GenericBinaryBuilder::::new(); + builder.append_null(); + builder.append_null(); + builder.append_null(); + assert_eq!(3, builder.len()); + assert!(!builder.is_empty()); + + let array = builder.finish(); + assert_eq!(3, array.null_count()); + assert_eq!(3, array.len()); + assert!(array.is_null(0)); + assert!(array.is_null(1)); + assert!(array.is_null(2)); + } + + #[test] + fn test_binary_builder_all_nulls() { + _test_generic_binary_builder_all_nulls::() + } + + #[test] + fn test_large_binary_builder_all_nulls() { + _test_generic_binary_builder_all_nulls::() + } + + fn _test_generic_binary_builder_reset() { + let mut builder = GenericBinaryBuilder::::new(); + + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"rust"); + builder.finish(); + + assert!(builder.is_empty()); + + builder.append_value(b"parquet"); + builder.append_null(); + builder.append_value(b"arrow"); + builder.append_value(b""); + let array = builder.finish(); + + assert_eq!(4, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(b"parquet", array.value(0)); + assert!(array.is_null(1)); + assert_eq!(b"arrow", array.value(2)); + assert_eq!(b"", array.value(1)); + assert_eq!(O::zero(), array.value_offsets()[0]); + assert_eq!(O::from_usize(7).unwrap(), array.value_offsets()[2]); + assert_eq!(O::from_usize(5).unwrap(), array.value_length(2)); + } + + #[test] + fn test_binary_builder_reset() { + _test_generic_binary_builder_reset::() + } + + #[test] + fn test_large_binary_builder_reset() { + _test_generic_binary_builder_reset::() + } + + fn _test_generic_string_array_builder() { + let mut builder = GenericStringBuilder::::new(); + let owned = "arrow".to_owned(); + + builder.append_value("hello"); + 
builder.append_value(""); + builder.append_value(&owned); + builder.append_null(); + builder.append_option(Some("rust")); + builder.append_option(None::<&str>); + builder.append_option(None::); + assert_eq!(7, builder.len()); + + assert_eq!( + GenericStringArray::::from(vec![ + Some("hello"), + Some(""), + Some("arrow"), + None, + Some("rust"), + None, + None + ]), + builder.finish() + ); + } + + #[test] + fn test_string_array_builder() { + _test_generic_string_array_builder::() + } + + #[test] + fn test_large_string_array_builder() { + _test_generic_string_array_builder::() + } + + fn _test_generic_string_array_builder_finish() { + let mut builder = GenericStringBuilder::::with_capacity(3, 11); + + builder.append_value("hello"); + builder.append_value("rust"); + builder.append_null(); + + builder.finish(); + assert!(builder.is_empty()); + assert_eq!(&[O::zero()], builder.offsets_slice()); + + builder.append_value("arrow"); + builder.append_value("parquet"); + let arr = builder.finish(); + // array should not have null buffer because there is not `null` value. 
+ assert!(arr.nulls().is_none()); + assert_eq!(GenericStringArray::::from(vec!["arrow", "parquet"]), arr,) + } + + #[test] + fn test_string_array_builder_finish() { + _test_generic_string_array_builder_finish::() + } + + #[test] + fn test_large_string_array_builder_finish() { + _test_generic_string_array_builder_finish::() + } + + fn _test_generic_string_array_builder_finish_cloned() { + let mut builder = GenericStringBuilder::::with_capacity(3, 11); + + builder.append_value("hello"); + builder.append_value("rust"); + builder.append_null(); + + let mut arr = builder.finish_cloned(); + assert!(!builder.is_empty()); + assert_eq!(3, arr.len()); + + builder.append_value("arrow"); + builder.append_value("parquet"); + arr = builder.finish(); + + assert!(arr.nulls().is_some()); + assert_eq!(&[O::zero()], builder.offsets_slice()); + assert_eq!(5, arr.len()); + } + + #[test] + fn test_string_array_builder_finish_cloned() { + _test_generic_string_array_builder_finish_cloned::() + } + + #[test] + fn test_large_string_array_builder_finish_cloned() { + _test_generic_string_array_builder_finish_cloned::() + } + + #[test] + fn test_extend() { + let mut builder = GenericStringBuilder::::new(); + builder.extend(["a", "b", "c", "", "a", "b", "c"].into_iter().map(Some)); + builder.extend(["d", "cupcakes", "hello"].into_iter().map(Some)); + let array = builder.finish(); + assert_eq!(array.value_offsets(), &[0, 1, 2, 3, 3, 4, 5, 6, 7, 15, 20]); + assert_eq!(array.value_data(), b"abcabcdcupcakeshello"); + } + + #[test] + fn test_write() { + let mut builder = GenericStringBuilder::::new(); + write!(builder, "foo").unwrap(); + builder.append_value(""); + writeln!(builder, "bar").unwrap(); + builder.append_value(""); + write!(builder, "fiz").unwrap(); + write!(builder, "buz").unwrap(); + builder.append_value(""); + let a = builder.finish(); + let r: Vec<_> = a.iter().flatten().collect(); + assert_eq!(r, &["foo", "bar\n", "fizbuz"]) + } +} diff --git 
a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs new file mode 100644 index 000000000000..285a4f035e24 --- /dev/null +++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs @@ -0,0 +1,630 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; +use crate::types::{ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType}; +use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray}; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType}; +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`DictionaryArray`] of [`GenericByteArray`] +/// +/// For example to map a set of byte indices to String values. Note that +/// the use of a `HashMap` here will not scale to very large arrays or +/// result in an ordered dictionary. 
+#[derive(Debug)] +pub struct GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + state: ahash::RandomState, + /// Used to provide a lookup from string value to key type + /// + /// Note: usize's hash implementation is not used, instead the raw entry + /// API is used to store keys w.r.t the hash of the strings themselves + /// + dedup: HashMap, + + keys_builder: PrimitiveBuilder, + values_builder: GenericByteBuilder, +} + +impl Default for GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + fn default() -> Self { + Self::new() + } +} + +impl GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + /// Creates a new `GenericByteDictionaryBuilder` + pub fn new() -> Self { + let keys_builder = PrimitiveBuilder::new(); + let values_builder = GenericByteBuilder::::new(); + Self { + state: Default::default(), + dedup: HashMap::with_capacity_and_hasher(keys_builder.capacity(), ()), + keys_builder, + values_builder, + } + } + + /// Creates a new `GenericByteDictionaryBuilder` with the provided capacities + /// + /// `keys_capacity`: the number of keys, i.e. length of array to build + /// `value_capacity`: the number of distinct dictionary values, i.e. size of dictionary + /// `data_capacity`: the total number of bytes of all distinct bytes in the dictionary + pub fn with_capacity( + keys_capacity: usize, + value_capacity: usize, + data_capacity: usize, + ) -> Self { + Self { + state: Default::default(), + dedup: Default::default(), + keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), + values_builder: GenericByteBuilder::::with_capacity(value_capacity, data_capacity), + } + } + + /// Creates a new `GenericByteDictionaryBuilder` from a keys capacity and a dictionary + /// which is initialized with the given values. + /// The indices of those dictionary values are used as keys. 
+ /// + /// # Example + /// + /// ``` + /// # use arrow_array::builder::StringDictionaryBuilder; + /// # use arrow_array::{Int16Array, StringArray}; + /// + /// let dictionary_values = StringArray::from(vec![None, Some("abc"), Some("def")]); + /// + /// let mut builder = StringDictionaryBuilder::new_with_dictionary(3, &dictionary_values).unwrap(); + /// builder.append("def").unwrap(); + /// builder.append_null(); + /// builder.append("abc").unwrap(); + /// + /// let dictionary_array = builder.finish(); + /// + /// let keys = dictionary_array.keys(); + /// + /// assert_eq!(keys, &Int16Array::from(vec![Some(2), None, Some(1)])); + /// ``` + pub fn new_with_dictionary( + keys_capacity: usize, + dictionary_values: &GenericByteArray, + ) -> Result { + let state = ahash::RandomState::default(); + let dict_len = dictionary_values.len(); + + let mut dedup = HashMap::with_capacity_and_hasher(dict_len, ()); + + let values_len = dictionary_values.value_data().len(); + let mut values_builder = GenericByteBuilder::::with_capacity(dict_len, values_len); + + K::Native::from_usize(dictionary_values.len()) + .ok_or(ArrowError::DictionaryKeyOverflowError)?; + + for (idx, maybe_value) in dictionary_values.iter().enumerate() { + match maybe_value { + Some(value) => { + let value_bytes: &[u8] = value.as_ref(); + let hash = state.hash_one(value_bytes); + + let entry = dedup.raw_entry_mut().from_hash(hash, |idx: &usize| { + value_bytes == get_bytes(&values_builder, *idx) + }); + + if let RawEntryMut::Vacant(v) = entry { + v.insert_with_hasher(hash, idx, (), |idx| { + state.hash_one(get_bytes(&values_builder, *idx)) + }); + } + + values_builder.append_value(value); + } + None => values_builder.append_null(), + } + } + + Ok(Self { + state, + dedup, + keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), + values_builder, + }) + } +} + +impl ArrayBuilder for GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + /// Returns the builder as an 
non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as an mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.keys_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + /// Append a value to the array. Return an existing index + /// if already present in the values array or a new index if the + /// value is appended to the values array. + /// + /// Returns an error if the new index would overflow the key type. + pub fn append(&mut self, value: impl AsRef) -> Result { + let value_native: &T::Native = value.as_ref(); + let value_bytes: &[u8] = value_native.as_ref(); + + let state = &self.state; + let storage = &mut self.values_builder; + let hash = state.hash_one(value_bytes); + + let entry = self + .dedup + .raw_entry_mut() + .from_hash(hash, |idx| value_bytes == get_bytes(storage, *idx)); + + let key = match entry { + RawEntryMut::Occupied(entry) => K::Native::usize_as(*entry.into_key()), + RawEntryMut::Vacant(entry) => { + let idx = storage.len(); + storage.append_value(value); + + entry.insert_with_hasher(hash, idx, (), |idx| { + state.hash_one(get_bytes(storage, *idx)) + }); + + K::Native::from_usize(idx).ok_or(ArrowError::DictionaryKeyOverflowError)? 
+ } + }; + self.keys_builder.append_value(key); + + Ok(key) + } + + /// Infallibly append a value to this builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + pub fn append_value(&mut self, value: impl AsRef) { + self.append(value).expect("dictionary key overflow"); + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.keys_builder.append_null() + } + + /// Append an `Option` value into the builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + #[inline] + pub fn append_option(&mut self, value: Option>) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + }; + } + + /// Builds the `DictionaryArray` and reset this builder. + pub fn finish(&mut self) -> DictionaryArray { + self.dedup.clear(); + let values = self.values_builder.finish(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + + /// Builds the `DictionaryArray` without resetting the builder. 
+ pub fn finish_cloned(&self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish_cloned(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.keys_builder.validity_slice() + } +} + +impl> Extend> + for GenericByteDictionaryBuilder +{ + #[inline] + fn extend>>(&mut self, iter: I) { + for v in iter { + self.append_option(v) + } + } +} + +fn get_bytes(values: &GenericByteBuilder, idx: usize) -> &[u8] { + let offsets = values.offsets_slice(); + let values = values.values_slice(); + + let end_offset = offsets[idx + 1].as_usize(); + let start_offset = offsets[idx].as_usize(); + + &values[start_offset..end_offset] +} + +/// Builder for [`DictionaryArray`] of [`StringArray`](crate::array::StringArray) +/// +/// ``` +/// // Create a dictionary array indexed by bytes whose values are Strings. +/// // It can thus hold up to 256 distinct string values. +/// +/// # use arrow_array::builder::StringDictionaryBuilder; +/// # use arrow_array::{Int8Array, StringArray}; +/// # use arrow_array::types::Int8Type; +/// +/// let mut builder = StringDictionaryBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append("abc").unwrap(); +/// builder.append_null(); +/// builder.append("def").unwrap(); +/// builder.append("def").unwrap(); +/// builder.append("abc").unwrap(); +/// let array = builder.finish(); +/// +/// assert_eq!( +/// array.keys(), +/// &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) +/// ); +/// +/// // Values are polymorphic and so require a downcast. 
+/// let av = array.values(); +/// let ava: &StringArray = av.as_any().downcast_ref::().unwrap(); +/// +/// assert_eq!(ava.value(0), "abc"); +/// assert_eq!(ava.value(1), "def"); +/// +/// ``` +pub type StringDictionaryBuilder = GenericByteDictionaryBuilder>; + +/// Builder for [`DictionaryArray`] of [`LargeStringArray`](crate::array::LargeStringArray) +pub type LargeStringDictionaryBuilder = GenericByteDictionaryBuilder>; + +/// Builder for [`DictionaryArray`] of [`BinaryArray`](crate::array::BinaryArray) +/// +/// ``` +/// // Create a dictionary array indexed by bytes whose values are binary. +/// // It can thus hold up to 256 distinct binary values. +/// +/// # use arrow_array::builder::BinaryDictionaryBuilder; +/// # use arrow_array::{BinaryArray, Int8Array}; +/// # use arrow_array::types::Int8Type; +/// +/// let mut builder = BinaryDictionaryBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append(b"abc").unwrap(); +/// builder.append_null(); +/// builder.append(b"def").unwrap(); +/// builder.append(b"def").unwrap(); +/// builder.append(b"abc").unwrap(); +/// let array = builder.finish(); +/// +/// assert_eq!( +/// array.keys(), +/// &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) +/// ); +/// +/// // Values are polymorphic and so require a downcast. 
+/// let av = array.values(); +/// let ava: &BinaryArray = av.as_any().downcast_ref::().unwrap(); +/// +/// assert_eq!(ava.value(0), b"abc"); +/// assert_eq!(ava.value(1), b"def"); +/// +/// ``` +pub type BinaryDictionaryBuilder = GenericByteDictionaryBuilder>; + +/// Builder for [`DictionaryArray`] of [`LargeBinaryArray`](crate::array::LargeBinaryArray) +pub type LargeBinaryDictionaryBuilder = GenericByteDictionaryBuilder>; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::Int8Array; + use crate::types::{Int16Type, Int32Type, Int8Type, Utf8Type}; + use crate::{BinaryArray, StringArray}; + + fn test_bytes_dictionary_builder(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + let array = builder.finish(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) + ); + + // Values are polymorphic and so require a downcast. 
+ let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(*ava.value(0), *values[0]); + assert_eq!(*ava.value(1), *values[1]); + } + + #[test] + fn test_string_dictionary_builder() { + test_bytes_dictionary_builder::>(vec!["abc", "def"]); + } + + #[test] + fn test_binary_dictionary_builder() { + test_bytes_dictionary_builder::>(vec![b"abc", b"def"]); + } + + fn test_bytes_dictionary_builder_finish_cloned(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteDictionaryBuilder::::new(); + + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + let mut array = builder.finish_cloned(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) + ); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava.value(0), values[0]); + assert_eq!(ava.value(1), values[1]); + + builder.append(values[0]).unwrap(); + builder.append(values[2]).unwrap(); + builder.append(values[1]).unwrap(); + + array = builder.finish(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![ + Some(0), + None, + Some(1), + Some(1), + Some(0), + Some(0), + Some(2), + Some(1) + ]) + ); + + // Values are polymorphic and so require a downcast. 
+ let av2 = array.values(); + let ava2: &GenericByteArray = + av2.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava2.value(0), values[0]); + assert_eq!(ava2.value(1), values[1]); + assert_eq!(ava2.value(2), values[2]); + } + + #[test] + fn test_string_dictionary_builder_finish_cloned() { + test_bytes_dictionary_builder_finish_cloned::>(vec![ + "abc", "def", "ghi", + ]); + } + + #[test] + fn test_binary_dictionary_builder_finish_cloned() { + test_bytes_dictionary_builder_finish_cloned::>(vec![ + b"abc", b"def", b"ghi", + ]); + } + + fn test_bytes_dictionary_builder_with_existing_dictionary( + dictionary: GenericByteArray, + values: Vec<&T::Native>, + ) where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = + GenericByteDictionaryBuilder::::new_with_dictionary(6, &dictionary) + .unwrap(); + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + builder.append(values[2]).unwrap(); + let array = builder.finish(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![Some(2), None, Some(1), Some(1), Some(2), Some(3)]) + ); + + // Values are polymorphic and so require a downcast. 
+ let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert!(!ava.is_valid(0)); + assert_eq!(ava.value(1), values[1]); + assert_eq!(ava.value(2), values[0]); + assert_eq!(ava.value(3), values[2]); + } + + #[test] + fn test_string_dictionary_builder_with_existing_dictionary() { + test_bytes_dictionary_builder_with_existing_dictionary::>( + StringArray::from(vec![None, Some("def"), Some("abc")]), + vec!["abc", "def", "ghi"], + ); + } + + #[test] + fn test_binary_dictionary_builder_with_existing_dictionary() { + let values: Vec> = vec![None, Some(b"def"), Some(b"abc")]; + test_bytes_dictionary_builder_with_existing_dictionary::>( + BinaryArray::from(values), + vec![b"abc", b"def", b"ghi"], + ); + } + + fn test_bytes_dictionary_builder_with_reserved_null_value( + dictionary: GenericByteArray, + values: Vec<&T::Native>, + ) where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = + GenericByteDictionaryBuilder::::new_with_dictionary(4, &dictionary) + .unwrap(); + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + let array = builder.finish(); + + assert!(array.is_null(1)); + assert!(!array.is_valid(1)); + + let keys = array.keys(); + + assert_eq!(keys.value(0), 1); + assert!(keys.is_null(1)); + // zero initialization is currently guaranteed by Buffer allocation and resizing + assert_eq!(keys.value(1), 0); + assert_eq!(keys.value(2), 2); + assert_eq!(keys.value(3), 1); + } + + #[test] + fn test_string_dictionary_builder_with_reserved_null_value() { + let v: Vec> = vec![None]; + test_bytes_dictionary_builder_with_reserved_null_value::>( + StringArray::from(v), + vec!["abc", "def"], + ); + } + + #[test] + fn test_binary_dictionary_builder_with_reserved_null_value() { + let values: Vec> = vec![None]; + test_bytes_dictionary_builder_with_reserved_null_value::>( + BinaryArray::from(values), + 
vec![b"abc", b"def"], + ); + } + + #[test] + fn test_extend() { + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.extend(["a", "b", "c", "a", "b", "c"].into_iter().map(Some)); + builder.extend(["c", "d", "a"].into_iter().map(Some)); + let dict = builder.finish(); + assert_eq!(dict.keys().values(), &[0, 1, 2, 0, 1, 2, 2, 3, 0]); + assert_eq!(dict.values().len(), 4); + } +} diff --git a/arrow-array/src/builder/generic_bytes_view_builder.rs b/arrow-array/src/builder/generic_bytes_view_builder.rs new file mode 100644 index 000000000000..d12c2b7db468 --- /dev/null +++ b/arrow-array/src/builder/generic_bytes_view_builder.rs @@ -0,0 +1,733 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::marker::PhantomData; +use std::sync::Arc; + +use arrow_buffer::{Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer}; +use arrow_data::ByteView; +use arrow_schema::ArrowError; +use hashbrown::hash_table::Entry; +use hashbrown::HashTable; + +use crate::builder::ArrayBuilder; +use crate::types::bytes::ByteArrayNativeType; +use crate::types::{BinaryViewType, ByteViewType, StringViewType}; +use crate::{ArrayRef, GenericByteViewArray}; + +const STARTING_BLOCK_SIZE: u32 = 8 * 1024; // 8KiB +const MAX_BLOCK_SIZE: u32 = 2 * 1024 * 1024; // 2MiB + +enum BlockSizeGrowthStrategy { + Fixed { size: u32 }, + Exponential { current_size: u32 }, +} + +impl BlockSizeGrowthStrategy { + fn next_size(&mut self) -> u32 { + match self { + Self::Fixed { size } => *size, + Self::Exponential { current_size } => { + if *current_size < MAX_BLOCK_SIZE { + // we have fixed start/end block sizes, so we can't overflow + *current_size = current_size.saturating_mul(2); + *current_size + } else { + MAX_BLOCK_SIZE + } + } + } + } +} + +/// A builder for [`GenericByteViewArray`] +/// +/// A [`GenericByteViewArray`] consists of a list of data blocks containing string data, +/// and a list of views into those buffers. +/// +/// See examples on [`StringViewBuilder`] and [`BinaryViewBuilder`] +/// +/// This builder can be used in two ways +/// +/// # Append Values +/// +/// To avoid bump allocating, this builder allocates data in fixed size blocks, configurable +/// using [`GenericByteViewBuilder::with_fixed_block_size`]. [`GenericByteViewBuilder::append_value`] +/// writes values larger than 12 bytes to the current in-progress block, with values smaller +/// than 12 bytes inlined into the views. 
If a value is appended that will not fit in the +/// in-progress block, it will be closed, and a new block of sufficient size allocated +/// +/// # Append Views +/// +/// Some use-cases may wish to reuse an existing allocation containing string data, for example, +/// when parsing data from a parquet data page. In such a case entire blocks can be appended +/// using [`GenericByteViewBuilder::append_block`] and then views into this block appended +/// using [`GenericByteViewBuilder::try_append_view`] +pub struct GenericByteViewBuilder { + views_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, + completed: Vec, + in_progress: Vec, + block_size: BlockSizeGrowthStrategy, + /// Some if deduplicating strings + /// map ` -> ` + string_tracker: Option<(HashTable, ahash::RandomState)>, + phantom: PhantomData, +} + +impl GenericByteViewBuilder { + /// Creates a new [`GenericByteViewBuilder`]. + pub fn new() -> Self { + Self::with_capacity(1024) + } + + /// Creates a new [`GenericByteViewBuilder`] with space for `capacity` string values. + pub fn with_capacity(capacity: usize) -> Self { + Self { + views_builder: BufferBuilder::new(capacity), + null_buffer_builder: NullBufferBuilder::new(capacity), + completed: vec![], + in_progress: vec![], + block_size: BlockSizeGrowthStrategy::Exponential { + current_size: STARTING_BLOCK_SIZE, + }, + string_tracker: None, + phantom: Default::default(), + } + } + + /// Set a fixed buffer size for variable length strings + /// + /// The block size is the size of the buffer used to store values greater + /// than 12 bytes. The builder allocates new buffers when the current + /// buffer is full. + /// + /// By default the builder balances buffer size and buffer count by + /// growing buffer size exponentially from 8KB up to 2MB. The + /// first buffer allocated is 8KB, then 16KB, then 32KB, etc up to 2MB. + /// + /// If this method is used, any new buffers allocated are + /// exactly this size. 
This can be useful for advanced users + /// that want to control the memory usage and buffer count. + /// + /// See for more details on the implications. + pub fn with_fixed_block_size(self, block_size: u32) -> Self { + debug_assert!(block_size > 0, "Block size must be greater than 0"); + Self { + block_size: BlockSizeGrowthStrategy::Fixed { size: block_size }, + ..self + } + } + + /// Override the size of buffers to allocate for holding string data + /// Use `with_fixed_block_size` instead. + #[deprecated(note = "Use `with_fixed_block_size` instead")] + pub fn with_block_size(self, block_size: u32) -> Self { + self.with_fixed_block_size(block_size) + } + + /// Deduplicate strings while building the array + /// + /// This will potentially decrease the memory usage if the array have repeated strings + /// It will also increase the time to build the array as it needs to hash the strings + pub fn with_deduplicate_strings(self) -> Self { + Self { + string_tracker: Some(( + HashTable::with_capacity(self.views_builder.capacity()), + Default::default(), + )), + ..self + } + } + + /// Append a new data block returning the new block offset + /// + /// Note: this will first flush any in-progress block + /// + /// This allows appending views from blocks added using [`Self::append_block`]. 
See + /// [`Self::append_value`] for appending individual values + /// + /// ``` + /// # use arrow_array::builder::StringViewBuilder; + /// let mut builder = StringViewBuilder::new(); + /// + /// let block = builder.append_block(b"helloworldbingobongo".into()); + /// + /// builder.try_append_view(block, 0, 5).unwrap(); + /// builder.try_append_view(block, 5, 5).unwrap(); + /// builder.try_append_view(block, 10, 5).unwrap(); + /// builder.try_append_view(block, 15, 5).unwrap(); + /// builder.try_append_view(block, 0, 15).unwrap(); + /// let array = builder.finish(); + /// + /// let actual: Vec<_> = array.iter().flatten().collect(); + /// let expected = &["hello", "world", "bingo", "bongo", "helloworldbingo"]; + /// assert_eq!(actual, expected); + /// ``` + pub fn append_block(&mut self, buffer: Buffer) -> u32 { + assert!(buffer.len() < u32::MAX as usize); + + self.flush_in_progress(); + let offset = self.completed.len(); + self.push_completed(buffer); + offset as u32 + } + + /// Append a view of the given `block`, `offset` and `length` + /// + /// # Safety + /// (1) The block must have been added using [`Self::append_block`] + /// (2) The range `offset..offset+length` must be within the bounds of the block + /// (3) The data in the block must be valid of type `T` + pub unsafe fn append_view_unchecked(&mut self, block: u32, offset: u32, len: u32) { + let b = self.completed.get_unchecked(block as usize); + let start = offset as usize; + let end = start.saturating_add(len as usize); + let b = b.get_unchecked(start..end); + + let view = make_view(b, block, offset); + self.views_builder.append(view); + self.null_buffer_builder.append_non_null(); + } + + /// Try to append a view of the given `block`, `offset` and `length` + /// + /// See [`Self::append_block`] + pub fn try_append_view(&mut self, block: u32, offset: u32, len: u32) -> Result<(), ArrowError> { + let b = self.completed.get(block as usize).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("No block 
found with index {block}")) + })?; + let start = offset as usize; + let end = start.saturating_add(len as usize); + + let b = b.get(start..end).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Range {start}..{end} out of bounds for block of length {}", + b.len() + )) + })?; + + if T::Native::from_bytes_checked(b).is_none() { + return Err(ArrowError::InvalidArgumentError( + "Invalid view data".to_string(), + )); + } + + unsafe { + self.append_view_unchecked(block, offset, len); + } + Ok(()) + } + + /// Flushes the in progress block if any + #[inline] + fn flush_in_progress(&mut self) { + if !self.in_progress.is_empty() { + let f = Buffer::from_vec(std::mem::take(&mut self.in_progress)); + self.push_completed(f) + } + } + + /// Append a block to `self.completed`, checking for overflow + #[inline] + fn push_completed(&mut self, block: Buffer) { + assert!(block.len() < u32::MAX as usize, "Block too large"); + assert!(self.completed.len() < u32::MAX as usize, "Too many blocks"); + self.completed.push(block); + } + + /// Returns the value at the given index + /// Useful if we want to know what value has been inserted to the builder + /// The index has to be smaller than `self.len()`, otherwise it will panic + pub fn get_value(&self, index: usize) -> &[u8] { + let view = self.views_builder.as_slice().get(index).unwrap(); + let len = *view as u32; + if len <= 12 { + // # Safety + // The view is valid from the builder + unsafe { GenericByteViewArray::::inline_value(view, len as usize) } + } else { + let view = ByteView::from(*view); + if view.buffer_index < self.completed.len() as u32 { + let block = &self.completed[view.buffer_index as usize]; + &block[view.offset as usize..view.offset as usize + view.length as usize] + } else { + &self.in_progress[view.offset as usize..view.offset as usize + view.length as usize] + } + } + } + + /// Appends a value into the builder + /// + /// # Panics + /// + /// Panics if + /// - String buffer count exceeds `u32::MAX` + 
/// - String length exceeds `u32::MAX` + #[inline] + pub fn append_value(&mut self, value: impl AsRef) { + let v: &[u8] = value.as_ref().as_ref(); + let length: u32 = v.len().try_into().unwrap(); + if length <= 12 { + let mut view_buffer = [0; 16]; + view_buffer[0..4].copy_from_slice(&length.to_le_bytes()); + view_buffer[4..4 + v.len()].copy_from_slice(v); + self.views_builder.append(u128::from_le_bytes(view_buffer)); + self.null_buffer_builder.append_non_null(); + return; + } + + // Deduplication if: + // (1) deduplication is enabled. + // (2) len > 12 + if let Some((mut ht, hasher)) = self.string_tracker.take() { + let hash_val = hasher.hash_one(v); + let hasher_fn = |v: &_| hasher.hash_one(v); + + let entry = ht.entry( + hash_val, + |idx| { + let stored_value = self.get_value(*idx); + v == stored_value + }, + hasher_fn, + ); + match entry { + Entry::Occupied(occupied) => { + // If the string already exists, we will directly use the view + let idx = occupied.get(); + self.views_builder + .append(self.views_builder.as_slice()[*idx]); + self.null_buffer_builder.append_non_null(); + self.string_tracker = Some((ht, hasher)); + return; + } + Entry::Vacant(vacant) => { + // o.w. 
we insert the (string hash -> view index) + // the idx is current length of views_builder, as we are inserting a new view + vacant.insert(self.views_builder.len()); + } + } + self.string_tracker = Some((ht, hasher)); + } + + let required_cap = self.in_progress.len() + v.len(); + if self.in_progress.capacity() < required_cap { + self.flush_in_progress(); + let to_reserve = v.len().max(self.block_size.next_size() as usize); + self.in_progress.reserve(to_reserve); + }; + let offset = self.in_progress.len() as u32; + self.in_progress.extend_from_slice(v); + + let view = ByteView { + length, + prefix: u32::from_le_bytes(v[0..4].try_into().unwrap()), + buffer_index: self.completed.len() as u32, + offset, + }; + self.views_builder.append(view.into()); + self.null_buffer_builder.append_non_null(); + } + + /// Append an `Option` value into the builder + #[inline] + pub fn append_option(&mut self, value: Option>) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + }; + } + + /// Append a null value into the builder + #[inline] + pub fn append_null(&mut self) { + self.null_buffer_builder.append_null(); + self.views_builder.append(0); + } + + /// Builds the [`GenericByteViewArray`] and reset this builder + pub fn finish(&mut self) -> GenericByteViewArray { + self.flush_in_progress(); + let completed = std::mem::take(&mut self.completed); + let len = self.views_builder.len(); + let views = ScalarBuffer::new(self.views_builder.finish(), 0, len); + let nulls = self.null_buffer_builder.finish(); + if let Some((ref mut ht, _)) = self.string_tracker.as_mut() { + ht.clear(); + } + // SAFETY: valid by construction + unsafe { GenericByteViewArray::new_unchecked(views, completed, nulls) } + } + + /// Builds the [`GenericByteViewArray`] without resetting the builder + pub fn finish_cloned(&self) -> GenericByteViewArray { + let mut completed = self.completed.clone(); + if !self.in_progress.is_empty() { + 
completed.push(Buffer::from_slice_ref(&self.in_progress)); + } + let len = self.views_builder.len(); + let views = Buffer::from_slice_ref(self.views_builder.as_slice()); + let views = ScalarBuffer::new(views, 0, len); + let nulls = self.null_buffer_builder.finish_cloned(); + // SAFETY: valid by construction + unsafe { GenericByteViewArray::new_unchecked(views, completed, nulls) } + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } + + /// Return the allocated size of this builder in bytes, useful for memory accounting. + pub fn allocated_size(&self) -> usize { + let views = self.views_builder.capacity() * std::mem::size_of::(); + let null = self.null_buffer_builder.allocated_size(); + let buffer_size = self.completed.iter().map(|b| b.capacity()).sum::(); + let in_progress = self.in_progress.capacity(); + let tracker = match &self.string_tracker { + Some((ht, _)) => ht.capacity() * std::mem::size_of::(), + None => 0, + }; + buffer_size + in_progress + tracker + views + null + } +} + +impl Default for GenericByteViewBuilder { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for GenericByteViewBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}ViewBuilder", T::PREFIX)?; + f.debug_struct("") + .field("views_builder", &self.views_builder) + .field("in_progress", &self.in_progress) + .field("completed", &self.completed) + .field("null_buffer_builder", &self.null_buffer_builder) + .finish() + } +} + +impl ArrayBuilder for GenericByteViewBuilder { + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn into_box_any(self: Box) -> Box { + self + } +} + 
+impl> Extend> + for GenericByteViewBuilder +{ + #[inline] + fn extend>>(&mut self, iter: I) { + for v in iter { + self.append_option(v) + } + } +} + +/// Array builder for [`StringViewArray`][crate::StringViewArray] +/// +/// Values can be appended using [`GenericByteViewBuilder::append_value`], and nulls with +/// [`GenericByteViewBuilder::append_null`] as normal. +/// +/// # Example +/// ``` +/// # use arrow_array::builder::StringViewBuilder; +/// # use arrow_array::StringViewArray; +/// let mut builder = StringViewBuilder::new(); +/// builder.append_value("hello"); +/// builder.append_null(); +/// builder.append_value("world"); +/// let array = builder.finish(); +/// +/// let expected = vec![Some("hello"), None, Some("world")]; +/// let actual: Vec<_> = array.iter().collect(); +/// assert_eq!(expected, actual); +/// ``` +pub type StringViewBuilder = GenericByteViewBuilder; + +/// Array builder for [`BinaryViewArray`][crate::BinaryViewArray] +/// +/// Values can be appended using [`GenericByteViewBuilder::append_value`], and nulls with +/// [`GenericByteViewBuilder::append_null`] as normal. 
+/// +/// # Example +/// ``` +/// # use arrow_array::builder::BinaryViewBuilder; +/// use arrow_array::BinaryViewArray; +/// let mut builder = BinaryViewBuilder::new(); +/// builder.append_value("hello"); +/// builder.append_null(); +/// builder.append_value("world"); +/// let array = builder.finish(); +/// +/// let expected: Vec> = vec![Some(b"hello"), None, Some(b"world")]; +/// let actual: Vec<_> = array.iter().collect(); +/// assert_eq!(expected, actual); +/// ``` +/// +pub type BinaryViewBuilder = GenericByteViewBuilder; + +/// Creates a view from a fixed length input (the compiler can generate +/// specialized code for this) +fn make_inlined_view(data: &[u8]) -> u128 { + let mut view_buffer = [0; 16]; + view_buffer[0..4].copy_from_slice(&(LEN as u32).to_le_bytes()); + view_buffer[4..4 + LEN].copy_from_slice(&data[..LEN]); + u128::from_le_bytes(view_buffer) +} + +/// Create a view based on the given data, block id and offset. +/// +/// Note that the code below is carefully examined with x86_64 assembly code: +/// The goal is to avoid calling into `ptr::copy_non_interleave`, which makes function call (i.e., not inlined), +/// which slows down things. +#[inline(never)] +pub fn make_view(data: &[u8], block_id: u32, offset: u32) -> u128 { + let len = data.len(); + + // Generate specialized code for each potential small string length + // to improve performance + match len { + 0 => make_inlined_view::<0>(data), + 1 => make_inlined_view::<1>(data), + 2 => make_inlined_view::<2>(data), + 3 => make_inlined_view::<3>(data), + 4 => make_inlined_view::<4>(data), + 5 => make_inlined_view::<5>(data), + 6 => make_inlined_view::<6>(data), + 7 => make_inlined_view::<7>(data), + 8 => make_inlined_view::<8>(data), + 9 => make_inlined_view::<9>(data), + 10 => make_inlined_view::<10>(data), + 11 => make_inlined_view::<11>(data), + 12 => make_inlined_view::<12>(data), + // When string is longer than 12 bytes, it can't be inlined, we create a ByteView instead. 
+ _ => { + let view = ByteView { + length: len as u32, + prefix: u32::from_le_bytes(data[0..4].try_into().unwrap()), + buffer_index: block_id, + offset, + }; + view.as_u128() + } + } +} + +#[cfg(test)] +mod tests { + use core::str; + + use super::*; + use crate::Array; + + #[test] + fn test_string_view_deduplicate() { + let value_1 = "long string to test string view"; + let value_2 = "not so similar string but long"; + + let mut builder = StringViewBuilder::new() + .with_deduplicate_strings() + .with_fixed_block_size(value_1.len() as u32 * 2); // so that we will have multiple buffers + + let values = vec![ + Some(value_1), + Some(value_2), + Some("short"), + Some(value_1), + None, + Some(value_2), + Some(value_1), + ]; + builder.extend(values.clone()); + + let array = builder.finish_cloned(); + array.to_data().validate_full().unwrap(); + assert_eq!(array.data_buffers().len(), 1); // without duplication we would need 3 buffers. + let actual: Vec<_> = array.iter().collect(); + assert_eq!(actual, values); + + let view0 = array.views().first().unwrap(); + let view3 = array.views().get(3).unwrap(); + let view6 = array.views().get(6).unwrap(); + + assert_eq!(view0, view3); + assert_eq!(view0, view6); + + assert_eq!(array.views().get(1), array.views().get(5)); + } + + #[test] + fn test_string_view_deduplicate_after_finish() { + let mut builder = StringViewBuilder::new().with_deduplicate_strings(); + + let value_1 = "long string to test string view"; + let value_2 = "not so similar string but long"; + builder.append_value(value_1); + let _array = builder.finish(); + builder.append_value(value_2); + let _array = builder.finish(); + builder.append_value(value_1); + let _array = builder.finish(); + } + + #[test] + fn test_string_view() { + let b1 = Buffer::from(b"world\xFFbananas\xF0\x9F\x98\x81"); + let b2 = Buffer::from(b"cupcakes"); + let b3 = Buffer::from(b"Many strings are here contained of great length and verbosity"); + + let mut v = StringViewBuilder::new(); + 
assert_eq!(v.append_block(b1), 0); + + v.append_value("This is a very long string that exceeds the inline length"); + v.append_value("This is another very long string that exceeds the inline length"); + + assert_eq!(v.append_block(b2), 2); + assert_eq!(v.append_block(b3), 3); + + // Test short strings + v.try_append_view(0, 0, 5).unwrap(); // world + v.try_append_view(0, 6, 7).unwrap(); // bananas + v.try_append_view(2, 3, 5).unwrap(); // cake + v.try_append_view(2, 0, 3).unwrap(); // cup + v.try_append_view(2, 0, 8).unwrap(); // cupcakes + v.try_append_view(0, 13, 4).unwrap(); // 😁 + v.try_append_view(0, 13, 0).unwrap(); // + + // Test longer strings + v.try_append_view(3, 0, 16).unwrap(); // Many strings are + v.try_append_view(1, 0, 19).unwrap(); // This is a very long + v.try_append_view(3, 13, 27).unwrap(); // here contained of great length + + v.append_value("I do so like long strings"); + + let array = v.finish_cloned(); + array.to_data().validate_full().unwrap(); + assert_eq!(array.data_buffers().len(), 5); + let actual: Vec<_> = array.iter().flatten().collect(); + assert_eq!( + actual, + &[ + "This is a very long string that exceeds the inline length", + "This is another very long string that exceeds the inline length", + "world", + "bananas", + "cakes", + "cup", + "cupcakes", + "😁", + "", + "Many strings are", + "This is a very long", + "are here contained of great", + "I do so like long strings" + ] + ); + + let err = v.try_append_view(0, u32::MAX, 1).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Range 4294967295..4294967296 out of bounds for block of length 17"); + + let err = v.try_append_view(0, 1, u32::MAX).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Range 1..4294967296 out of bounds for block of length 17" + ); + + let err = v.try_append_view(0, 13, 2).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Invalid view data"); + + let err = v.try_append_view(0, 40, 0).unwrap_err(); + 
assert_eq!( + err.to_string(), + "Invalid argument error: Range 40..40 out of bounds for block of length 17" + ); + + let err = v.try_append_view(5, 0, 0).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: No block found with index 5" + ); + } + + #[test] + fn test_string_view_with_block_size_growth() { + let mut exp_builder = StringViewBuilder::new(); + let mut fixed_builder = StringViewBuilder::new().with_fixed_block_size(STARTING_BLOCK_SIZE); + + let long_string = str::from_utf8(&[b'a'; STARTING_BLOCK_SIZE as usize]).unwrap(); + + for i in 0..9 { + // 8k, 16k, 32k, 64k, 128k, 256k, 512k, 1M, 2M + for _ in 0..(2_u32.pow(i)) { + exp_builder.append_value(long_string); + fixed_builder.append_value(long_string); + } + exp_builder.flush_in_progress(); + fixed_builder.flush_in_progress(); + + // Every step only add one buffer, but the buffer size is much larger + assert_eq!(exp_builder.completed.len(), i as usize + 1); + assert_eq!( + exp_builder.completed[i as usize].len(), + STARTING_BLOCK_SIZE as usize * 2_usize.pow(i) + ); + + // This step we added 2^i blocks, the sum of blocks should be 2^(i+1) - 1 + assert_eq!(fixed_builder.completed.len(), 2_usize.pow(i + 1) - 1); + + // Every buffer is fixed size + assert!(fixed_builder + .completed + .iter() + .all(|b| b.len() == STARTING_BLOCK_SIZE as usize)); + } + + // Add one more value, and the buffer stop growing. + exp_builder.append_value(long_string); + exp_builder.flush_in_progress(); + assert_eq!( + exp_builder.completed.last().unwrap().capacity(), + MAX_BLOCK_SIZE as usize + ); + } +} diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs new file mode 100644 index 000000000000..6ff5f20df684 --- /dev/null +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -0,0 +1,806 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow_buffer::NullBufferBuilder; +use arrow_buffer::{Buffer, OffsetBuffer}; +use arrow_schema::{Field, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`GenericListArray`] +/// +/// Use [`ListBuilder`] to build [`ListArray`]s and [`LargeListBuilder`] to build [`LargeListArray`]s. 
+/// +/// # Example +/// +/// Here is code that constructs a ListArray with the contents: +/// `[[A,B,C], [], NULL, [D], [NULL, F]]` +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{builder::ListBuilder, builder::StringBuilder, ArrayRef, StringArray, Array}; +/// # +/// let values_builder = StringBuilder::new(); +/// let mut builder = ListBuilder::new(values_builder); +/// +/// // [A, B, C] +/// builder.values().append_value("A"); +/// builder.values().append_value("B"); +/// builder.values().append_value("C"); +/// builder.append(true); +/// +/// // [ ] (empty list) +/// builder.append(true); +/// +/// // Null +/// builder.values().append_value("?"); // irrelevant +/// builder.append(false); +/// +/// // [D] +/// builder.values().append_value("D"); +/// builder.append(true); +/// +/// // [NULL, F] +/// builder.values().append_null(); +/// builder.values().append_value("F"); +/// builder.append(true); +/// +/// // Build the array +/// let array = builder.finish(); +/// +/// // Values is a string array +/// // "A", "B" "C", "?", "D", NULL, "F" +/// assert_eq!( +/// array.values().as_ref(), +/// &StringArray::from(vec![ +/// Some("A"), Some("B"), Some("C"), +/// Some("?"), Some("D"), None, +/// Some("F") +/// ]) +/// ); +/// +/// // Offsets are indexes into the values array +/// assert_eq!( +/// array.value_offsets(), +/// &[0, 3, 3, 4, 5, 7] +/// ); +/// ``` +/// +/// [`ListBuilder`]: crate::builder::ListBuilder +/// [`ListArray`]: crate::array::ListArray +/// [`LargeListBuilder`]: crate::builder::LargeListBuilder +/// [`LargeListArray`]: crate::array::LargeListArray +#[derive(Debug)] +pub struct GenericListBuilder { + offsets_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, + values_builder: T, + field: Option, +} + +impl Default for GenericListBuilder { + fn default() -> Self { + Self::new(T::default()) + } +} + +impl GenericListBuilder { + /// Creates a new [`GenericListBuilder`] from a given values array builder + pub fn 
new(values_builder: T) -> Self { + let capacity = values_builder.len(); + Self::with_capacity(values_builder, capacity) + } + + /// Creates a new [`GenericListBuilder`] from a given values array builder + /// `capacity` is the number of items to pre-allocate space for in this builder + pub fn with_capacity(values_builder: T, capacity: usize) -> Self { + let mut offsets_builder = BufferBuilder::::new(capacity + 1); + offsets_builder.append(OffsetSize::zero()); + Self { + offsets_builder, + null_buffer_builder: NullBufferBuilder::new(capacity), + values_builder, + field: None, + } + } + + /// Override the field passed to [`GenericListArray::new`] + /// + /// By default a nullable field is created with the name `item` + /// + /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the + /// field's data type does not match that of `T` + pub fn with_field(self, field: impl Into) -> Self { + Self { + field: Some(field.into()), + ..self + } + } +} + +impl ArrayBuilder + for GenericListBuilder +where + T: 'static, +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl GenericListBuilder +where + T: 'static, +{ + /// Returns the child array builder as a mutable reference. 
+ /// + /// This mutable reference can be used to append values into the child array builder, + /// but you must call [`append`](#method.append) to delimit each distinct list value. + pub fn values(&mut self) -> &mut T { + &mut self.values_builder + } + + /// Returns the child array builder as an immutable reference + pub fn values_ref(&self) -> &T { + &self.values_builder + } + + /// Finish the current variable-length list array slot + /// + /// # Panics + /// + /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` + #[inline] + pub fn append(&mut self, is_valid: bool) { + self.offsets_builder.append(self.next_offset()); + self.null_buffer_builder.append(is_valid); + } + + /// Returns the next offset + /// + /// # Panics + /// + /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` + #[inline] + fn next_offset(&self) -> OffsetSize { + OffsetSize::from_usize(self.values_builder.len()).unwrap() + } + + /// Append a value to this [`GenericListBuilder`] + /// + /// ``` + /// # use arrow_array::builder::{Int32Builder, ListBuilder}; + /// # use arrow_array::cast::AsArray; + /// # use arrow_array::{Array, Int32Array}; + /// # use arrow_array::types::Int32Type; + /// let mut builder = ListBuilder::new(Int32Builder::new()); + /// + /// builder.append_value([Some(1), Some(2), Some(3)]); + /// builder.append_value([]); + /// builder.append_value([None]); + /// + /// let array = builder.finish(); + /// assert_eq!(array.len(), 3); + /// + /// assert_eq!(array.value_offsets(), &[0, 3, 3, 4]); + /// let values = array.values().as_primitive::(); + /// assert_eq!(values, &Int32Array::from(vec![Some(1), Some(2), Some(3), None])); + /// ``` + /// + /// This is an alternative API to appending directly to [`Self::values`] and + /// delimiting the result with [`Self::append`] + /// + /// ``` + /// # use arrow_array::builder::{Int32Builder, ListBuilder}; + /// # use arrow_array::cast::AsArray; + /// # use arrow_array::{Array, Int32Array}; + /// # use 
arrow_array::types::Int32Type; + /// let mut builder = ListBuilder::new(Int32Builder::new()); + /// + /// builder.values().append_value(1); + /// builder.values().append_value(2); + /// builder.values().append_value(3); + /// builder.append(true); + /// builder.append(true); + /// builder.values().append_null(); + /// builder.append(true); + /// + /// let array = builder.finish(); + /// assert_eq!(array.len(), 3); + /// + /// assert_eq!(array.value_offsets(), &[0, 3, 3, 4]); + /// let values = array.values().as_primitive::(); + /// assert_eq!(values, &Int32Array::from(vec![Some(1), Some(2), Some(3), None])); + /// ``` + #[inline] + pub fn append_value(&mut self, i: I) + where + T: Extend>, + I: IntoIterator>, + { + self.extend(std::iter::once(Some(i))) + } + + /// Append a null to this [`GenericListBuilder`] + /// + /// See [`Self::append_value`] for an example use. + #[inline] + pub fn append_null(&mut self) { + self.offsets_builder.append(self.next_offset()); + self.null_buffer_builder.append_null(); + } + + /// Appends an optional value into this [`GenericListBuilder`] + /// + /// If `Some` calls [`Self::append_value`] otherwise calls [`Self::append_null`] + #[inline] + pub fn append_option(&mut self, i: Option) + where + T: Extend>, + I: IntoIterator>, + { + match i { + Some(i) => self.append_value(i), + None => self.append_null(), + } + } + + /// Builds the [`GenericListArray`] and reset this builder. 
+ pub fn finish(&mut self) -> GenericListArray { + let values = self.values_builder.finish(); + let nulls = self.null_buffer_builder.finish(); + + let offsets = self.offsets_builder.finish(); + // Safety: Safe by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + self.offsets_builder.append(OffsetSize::zero()); + + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; + + GenericListArray::new(field, offsets, values, nulls) + } + + /// Builds the [`GenericListArray`] without resetting the builder. + pub fn finish_cloned(&self) -> GenericListArray { + let values = self.values_builder.finish_cloned(); + let nulls = self.null_buffer_builder.finish_cloned(); + + let offsets = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + // Safety: safe by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; + + GenericListArray::new(field, offsets, values, nulls) + } + + /// Returns the current offsets buffer as a slice + pub fn offsets_slice(&self) -> &[OffsetSize] { + self.offsets_builder.as_slice() + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } +} + +impl Extend> for GenericListBuilder +where + O: OffsetSizeTrait, + B: ArrayBuilder + Extend, + V: IntoIterator, +{ + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + match v { + Some(elements) => { + self.values_builder.extend(elements); + self.append(true); + } + None => self.append(false), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::builder::{make_builder, Int32Builder, ListBuilder}; + use crate::cast::AsArray; + use crate::types::Int32Type; + use crate::Int32Array; + use arrow_schema::DataType; 
+ + fn _test_generic_list_array_builder() { + let values_builder = Int32Builder::with_capacity(10); + let mut builder = GenericListBuilder::::new(values_builder); + + // [[0, 1, 2], [3, 4, 5], [6, 7]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_value(3); + builder.values().append_value(4); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); + let list_array = builder.finish(); + + let list_values = list_array.values().as_primitive::(); + assert_eq!(list_values.values(), &[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(list_array.value_offsets(), [0, 3, 6, 8].map(O::usize_as)); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(O::from_usize(6).unwrap(), list_array.value_offsets()[2]); + assert_eq!(O::from_usize(2).unwrap(), list_array.value_length(2)); + for i in 0..3 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + } + + #[test] + fn test_list_array_builder() { + _test_generic_list_array_builder::() + } + + #[test] + fn test_large_list_array_builder() { + _test_generic_list_array_builder::() + } + + fn _test_generic_list_array_builder_nulls() { + let values_builder = Int32Builder::with_capacity(10); + let mut builder = GenericListBuilder::::new(values_builder); + + // [[0, 1, 2], null, [3, null, 5], [6, 7]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.append(false); + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); + + let list_array = builder.finish(); + + 
assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(O::from_usize(3).unwrap(), list_array.value_offsets()[2]); + assert_eq!(O::from_usize(3).unwrap(), list_array.value_length(2)); + } + + #[test] + fn test_list_array_builder_nulls() { + _test_generic_list_array_builder_nulls::() + } + + #[test] + fn test_large_list_array_builder_nulls() { + _test_generic_list_array_builder_nulls::() + } + + #[test] + fn test_list_array_builder_finish() { + let values_builder = Int32Array::builder(5); + let mut builder = ListBuilder::new(values_builder); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + + let mut arr = builder.finish(); + assert_eq!(2, arr.len()); + assert!(builder.is_empty()); + + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); + arr = builder.finish(); + assert_eq!(1, arr.len()); + assert!(builder.is_empty()); + } + + #[test] + fn test_list_array_builder_finish_cloned() { + let values_builder = Int32Array::builder(5); + let mut builder = ListBuilder::new(values_builder); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + + let mut arr = builder.finish_cloned(); + assert_eq!(2, arr.len()); + assert!(!builder.is_empty()); + + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); + arr = builder.finish(); + assert_eq!(3, arr.len()); + assert!(builder.is_empty()); + } + + #[test] + fn test_list_list_array_builder() { + let primitive_builder = Int32Builder::with_capacity(10); + let values_builder = ListBuilder::new(primitive_builder); + let mut builder = ListBuilder::new(values_builder); + + // [[[1, 2], [3, 4]], [[5, 6, 7], null, [8]], null, [[9, 10]]] + builder.values().values().append_value(1); + builder.values().values().append_value(2); + 
builder.values().append(true); + builder.values().values().append_value(3); + builder.values().values().append_value(4); + builder.values().append(true); + builder.append(true); + + builder.values().values().append_value(5); + builder.values().values().append_value(6); + builder.values().values().append_value(7); + builder.values().append(true); + builder.values().append(false); + builder.values().values().append_value(8); + builder.values().append(true); + builder.append(true); + + builder.append(false); + + builder.values().values().append_value(9); + builder.values().values().append_value(10); + builder.values().append(true); + builder.append(true); + + let l1 = builder.finish(); + + assert_eq!(4, l1.len()); + assert_eq!(1, l1.null_count()); + + assert_eq!(l1.value_offsets(), &[0, 2, 5, 5, 6]); + let l2 = l1.values().as_list::(); + + assert_eq!(6, l2.len()); + assert_eq!(1, l2.null_count()); + assert_eq!(l2.value_offsets(), &[0, 2, 4, 7, 7, 8, 10]); + + let i1 = l2.values().as_primitive::(); + assert_eq!(10, i1.len()); + assert_eq!(0, i1.null_count()); + assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } + + #[test] + fn test_extend() { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.extend([ + Some(vec![Some(1), Some(2), Some(7), None]), + Some(vec![]), + Some(vec![Some(4), Some(5)]), + None, + ]); + + let array = builder.finish(); + assert_eq!(array.value_offsets(), [0, 4, 4, 6, 6]); + assert_eq!(array.null_count(), 1); + assert!(array.is_null(3)); + let elements = array.values().as_primitive::(); + assert_eq!(elements.values(), &[1, 2, 7, 0, 4, 5]); + assert_eq!(elements.null_count(), 1); + assert!(elements.is_null(3)); + } + + #[test] + fn test_boxed_primitive_array_builder() { + let values_builder = make_builder(&DataType::Int32, 5); + let mut builder = ListBuilder::new(values_builder); + + builder + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_slice(&[1, 2, 3]); + 
builder.append(true); + + builder + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_slice(&[4, 5, 6]); + builder.append(true); + + let arr = builder.finish(); + assert_eq!(2, arr.len()); + + let elements = arr.values().as_primitive::(); + assert_eq!(elements.values(), &[1, 2, 3, 4, 5, 6]); + } + + #[test] + fn test_boxed_list_list_array_builder() { + // This test is same as `test_list_list_array_builder` but uses boxed builders. + let values_builder = make_builder( + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + 10, + ); + test_boxed_generic_list_generic_list_array_builder::(values_builder); + } + + #[test] + fn test_boxed_large_list_large_list_array_builder() { + // This test is same as `test_list_list_array_builder` but uses boxed builders. + let values_builder = make_builder( + &DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))), + 10, + ); + test_boxed_generic_list_generic_list_array_builder::(values_builder); + } + + fn test_boxed_generic_list_generic_list_array_builder( + values_builder: Box, + ) { + let mut builder: GenericListBuilder> = + GenericListBuilder::>::new(values_builder); + + // [[[1, 2], [3, 4]], [[5, 6, 7], null, [8]], null, [[9, 10]]] + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(1); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(2); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an 
Int32Builder") + .append_value(3); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(4); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder.append(true); + + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(5); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(6); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an (Large)ListBuilder") + .append_value(7); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(false); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(8); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder.append(true); + + builder.append(false); + + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(9); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + 
.as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(10); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder.append(true); + + let l1 = builder.finish(); + + assert_eq!(4, l1.len()); + assert_eq!(1, l1.null_count()); + + assert_eq!(l1.value_offsets(), &[0, 2, 5, 5, 6].map(O::usize_as)); + let l2 = l1.values().as_list::(); + + assert_eq!(6, l2.len()); + assert_eq!(1, l2.null_count()); + assert_eq!(l2.value_offsets(), &[0, 2, 4, 7, 7, 8, 10].map(O::usize_as)); + + let i1 = l2.values().as_primitive::(); + assert_eq!(10, i1.len()); + assert_eq!(0, i1.null_count()); + assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } + + #[test] + fn test_with_field() { + let field = Arc::new(Field::new("bar", DataType::Int32, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), Some(2), Some(3)]); + builder.append_null(); // This is fine as nullability refers to nullability of values + builder.append_value([Some(4)]); + let array = builder.finish(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::List(field.clone())); + + builder.append_value([Some(4), Some(5)]); + let array = builder.finish(); + assert_eq!(array.data_type(), &DataType::List(field)); + assert_eq!(array.len(), 1); + } + + #[test] + #[should_panic(expected = "Non-nullable field of ListArray \\\"item\\\" cannot contain nulls")] + fn test_checks_nullability() { + let field = Arc::new(Field::new("item", DataType::Int32, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), None]); + builder.finish(); + } + + #[test] + #[should_panic(expected = "ListArray expected data type Int64 got Int32")] + fn test_checks_data_type() { + let field = Arc::new(Field::new("item", DataType::Int64, false)); + let mut builder = 
ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1)]); + builder.finish(); + } +} diff --git a/arrow-array/src/builder/map_builder.rs b/arrow-array/src/builder/map_builder.rs new file mode 100644 index 000000000000..1d89d427aae1 --- /dev/null +++ b/arrow-array/src/builder/map_builder.rs @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::{Array, ArrayRef, MapArray, StructArray}; +use arrow_buffer::Buffer; +use arrow_buffer::{NullBuffer, NullBufferBuilder}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType, Field, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`MapArray`] +/// +/// ``` +/// # use arrow_array::builder::{Int32Builder, MapBuilder, StringBuilder}; +/// # use arrow_array::{Int32Array, StringArray}; +/// +/// let string_builder = StringBuilder::new(); +/// let int_builder = Int32Builder::with_capacity(4); +/// +/// // Construct `[{"joe": 1}, {"blogs": 2, "foo": 4}, {}, null]` +/// let mut builder = MapBuilder::new(None, string_builder, int_builder); +/// +/// builder.keys().append_value("joe"); +/// builder.values().append_value(1); +/// builder.append(true).unwrap(); +/// +/// builder.keys().append_value("blogs"); +/// builder.values().append_value(2); +/// builder.keys().append_value("foo"); +/// builder.values().append_value(4); +/// builder.append(true).unwrap(); +/// builder.append(true).unwrap(); +/// builder.append(false).unwrap(); +/// +/// let array = builder.finish(); +/// assert_eq!(array.value_offsets(), &[0, 1, 3, 3, 3]); +/// assert_eq!(array.values().as_ref(), &Int32Array::from(vec![1, 2, 4])); +/// assert_eq!(array.keys().as_ref(), &StringArray::from(vec!["joe", "blogs", "foo"])); +/// +/// ``` +#[derive(Debug)] +pub struct MapBuilder { + offsets_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, + field_names: MapFieldNames, + key_builder: K, + value_builder: V, + value_field: Option, +} + +/// The [`Field`] names for a [`MapArray`] +#[derive(Debug, Clone)] +pub struct MapFieldNames { + /// [`Field`] name for map entries + pub entry: String, + /// [`Field`] name for map key + pub key: String, + /// [`Field`] name for map value + pub value: String, +} + +impl Default for MapFieldNames { + fn default() -> Self { + Self { + entry: 
"entries".to_string(), + key: "keys".to_string(), + value: "values".to_string(), + } + } +} + +impl MapBuilder { + /// Creates a new `MapBuilder` + pub fn new(field_names: Option, key_builder: K, value_builder: V) -> Self { + let capacity = key_builder.len(); + Self::with_capacity(field_names, key_builder, value_builder, capacity) + } + + /// Creates a new `MapBuilder` with capacity + pub fn with_capacity( + field_names: Option, + key_builder: K, + value_builder: V, + capacity: usize, + ) -> Self { + let mut offsets_builder = BufferBuilder::::new(capacity + 1); + offsets_builder.append(0); + Self { + offsets_builder, + null_buffer_builder: NullBufferBuilder::new(capacity), + field_names: field_names.unwrap_or_default(), + key_builder, + value_builder, + value_field: None, + } + } + + /// Override the field passed to [`MapBuilder::new`] + /// + /// By default a nullable field is created with the name `values` + /// + /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the + /// field's data type does not match that of `V` + pub fn with_values_field(self, field: impl Into) -> Self { + Self { + value_field: Some(field.into()), + ..self + } + } + + /// Returns the key array builder of the map + pub fn keys(&mut self) -> &mut K { + &mut self.key_builder + } + + /// Returns the value array builder of the map + pub fn values(&mut self) -> &mut V { + &mut self.value_builder + } + + /// Returns both the key and value array builders of the map + pub fn entries(&mut self) -> (&mut K, &mut V) { + (&mut self.key_builder, &mut self.value_builder) + } + + /// Finish the current map array slot + /// + /// Returns an error if the key and values builders are in an inconsistent state. 
+ #[inline] + pub fn append(&mut self, is_valid: bool) -> Result<(), ArrowError> { + if self.key_builder.len() != self.value_builder.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot append to a map builder when its keys and values have unequal lengths of {} and {}", + self.key_builder.len(), + self.value_builder.len() + ))); + } + self.offsets_builder.append(self.key_builder.len() as i32); + self.null_buffer_builder.append(is_valid); + Ok(()) + } + + /// Builds the [`MapArray`] + pub fn finish(&mut self) -> MapArray { + let len = self.len(); + // Build the keys + let keys_arr = self.key_builder.finish(); + let values_arr = self.value_builder.finish(); + let offset_buffer = self.offsets_builder.finish(); + self.offsets_builder.append(0); + let null_bit_buffer = self.null_buffer_builder.finish(); + + self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len) + } + + /// Builds the [`MapArray`] without resetting the builder. + pub fn finish_cloned(&self) -> MapArray { + let len = self.len(); + // Build the keys + let keys_arr = self.key_builder.finish_cloned(); + let values_arr = self.value_builder.finish_cloned(); + let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + let nulls = self.null_buffer_builder.finish_cloned(); + self.finish_helper(keys_arr, values_arr, offset_buffer, nulls, len) + } + + fn finish_helper( + &self, + keys_arr: Arc, + values_arr: Arc, + offset_buffer: Buffer, + nulls: Option, + len: usize, + ) -> MapArray { + assert!( + keys_arr.null_count() == 0, + "Keys array must have no null values, found {} null value(s)", + keys_arr.null_count() + ); + + let keys_field = Arc::new(Field::new( + self.field_names.key.as_str(), + keys_arr.data_type().clone(), + false, // always non-nullable + )); + let values_field = match &self.value_field { + Some(f) => f.clone(), + None => Arc::new(Field::new( + self.field_names.value.as_str(), + values_arr.data_type().clone(), + true, + )), + }; + + 
let struct_array = + StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]); + + let map_field = Arc::new(Field::new( + self.field_names.entry.as_str(), + struct_array.data_type().clone(), + false, // always non-nullable + )); + let array_data = ArrayData::builder(DataType::Map(map_field, false)) // TODO: support sorted keys + .len(len) + .add_buffer(offset_buffer) + .add_child_data(struct_array.into_data()) + .nulls(nulls); + + let array_data = unsafe { array_data.build_unchecked() }; + + MapArray::from(array_data) + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } +} + +impl ArrayBuilder for MapBuilder { + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn into_box_any(self: Box) -> Box { + self + } +} + +#[cfg(test)] +mod tests { + use crate::builder::{make_builder, Int32Builder, StringBuilder}; + use crate::{Int32Array, StringArray}; + + use super::*; + + #[test] + #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")] + fn test_map_builder_with_null_keys_panics() { + let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); + builder.keys().append_null(); + builder.values().append_value(42); + builder.append(true).unwrap(); + + builder.finish(); + } + + #[test] + fn test_boxed_map_builder() { + let keys_builder = make_builder(&DataType::Utf8, 5); + let values_builder = make_builder(&DataType::Int32, 5); + + let mut builder = MapBuilder::new(None, keys_builder, values_builder); + builder + .keys() + .as_any_mut() + .downcast_mut::() + .expect("should be an 
StringBuilder") + .append_value("1"); + builder + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(42); + builder.append(true).unwrap(); + + let map_array = builder.finish(); + + assert_eq!( + map_array + .keys() + .as_any() + .downcast_ref::() + .expect("should be an StringArray") + .value(0), + "1" + ); + assert_eq!( + map_array + .values() + .as_any() + .downcast_ref::() + .expect("should be an Int32Array") + .value(0), + 42 + ); + } + + #[test] + fn test_with_values_field() { + let value_field = Arc::new(Field::new("bars", DataType::Int32, false)); + let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) + .with_values_field(value_field.clone()); + builder.keys().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + builder.append(false).unwrap(); // This is fine as nullability refers to nullability of values + builder.keys().append_value(3); + builder.values().append_value(4); + builder.append(true).unwrap(); + let map = builder.finish(); + + assert_eq!(map.len(), 3); + assert_eq!( + map.data_type(), + &DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Arc::new(Field::new("keys", DataType::Int32, false)), + value_field.clone() + ] + .into() + ), + false, + )), + false + ) + ); + + builder.keys().append_value(5); + builder.values().append_value(6); + builder.append(true).unwrap(); + let map = builder.finish(); + + assert_eq!(map.len(), 1); + assert_eq!( + map.data_type(), + &DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Arc::new(Field::new("keys", DataType::Int32, false)), + value_field + ] + .into() + ), + false, + )), + false + ) + ); + } +} diff --git a/arrow-array/src/builder/mod.rs b/arrow-array/src/builder/mod.rs new file mode 100644 index 000000000000..dd1a5c3ae722 --- /dev/null +++ b/arrow-array/src/builder/mod.rs @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation 
(ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines push-based APIs for constructing arrays +//! +//! # Basic Usage +//! +//! Builders can be used to build simple, non-nested arrays +//! +//! ``` +//! # use arrow_array::builder::Int32Builder; +//! # use arrow_array::PrimitiveArray; +//! let mut a = Int32Builder::new(); +//! a.append_value(1); +//! a.append_null(); +//! a.append_value(2); +//! let a = a.finish(); +//! +//! assert_eq!(a, PrimitiveArray::from(vec![Some(1), None, Some(2)])); +//! ``` +//! +//! ``` +//! # use arrow_array::builder::StringBuilder; +//! # use arrow_array::{Array, StringArray}; +//! let mut a = StringBuilder::new(); +//! a.append_value("foo"); +//! a.append_value("bar"); +//! a.append_null(); +//! let a = a.finish(); +//! +//! assert_eq!(a, StringArray::from_iter([Some("foo"), Some("bar"), None])); +//! ``` +//! +//! # Nested Usage +//! +//! Builders can also be used to build more complex nested arrays, such as lists +//! +//! ``` +//! # use arrow_array::builder::{Int32Builder, ListBuilder}; +//! # use arrow_array::ListArray; +//! # use arrow_array::types::Int32Type; +//! let mut a = ListBuilder::new(Int32Builder::new()); +//! // [1, 2] +//! a.values().append_value(1); +//! a.values().append_value(2); +//! a.append(true); +//! 
// null +//! a.append(false); +//! // [] +//! a.append(true); +//! // [3, null] +//! a.values().append_value(3); +//! a.values().append_null(); +//! a.append(true); +//! +//! // [[1, 2], null, [], [3, null]] +//! let a = a.finish(); +//! +//! assert_eq!(a, ListArray::from_iter_primitive::([ +//! Some(vec![Some(1), Some(2)]), +//! None, +//! Some(vec![]), +//! Some(vec![Some(3), None])] +//! )) +//! ``` +//! +//! # Custom Builders +//! +//! It is common to have a collection of statically defined Rust types that +//! you want to convert to Arrow arrays. +//! +//! An example of doing so is below +//! +//! ``` +//! # use std::any::Any; +//! # use arrow_array::builder::{ArrayBuilder, Int32Builder, ListBuilder, StringBuilder}; +//! # use arrow_array::{ArrayRef, RecordBatch, StructArray}; +//! # use arrow_schema::{DataType, Field}; +//! # use std::sync::Arc; +//! /// A custom row representation +//! struct MyRow { +//! i32: i32, +//! optional_i32: Option, +//! string: Option, +//! i32_list: Option>>, +//! } +//! +//! /// Converts `Vec` into `StructArray` +//! #[derive(Debug, Default)] +//! struct MyRowBuilder { +//! i32: Int32Builder, +//! string: StringBuilder, +//! i32_list: ListBuilder, +//! } +//! +//! impl MyRowBuilder { +//! fn append(&mut self, row: &MyRow) { +//! self.i32.append_value(row.i32); +//! self.string.append_option(row.string.as_ref()); +//! self.i32_list.append_option(row.i32_list.as_ref().map(|x| x.iter().copied())); +//! } +//! +//! /// Note: returns StructArray to allow nesting within another array if desired +//! fn finish(&mut self) -> StructArray { +//! let i32 = Arc::new(self.i32.finish()) as ArrayRef; +//! let i32_field = Arc::new(Field::new("i32", DataType::Int32, false)); +//! +//! let string = Arc::new(self.string.finish()) as ArrayRef; +//! let string_field = Arc::new(Field::new("i32", DataType::Utf8, false)); +//! +//! let i32_list = Arc::new(self.i32_list.finish()) as ArrayRef; +//! 
let value_field = Arc::new(Field::new("item", DataType::Int32, true)); +//! let i32_list_field = Arc::new(Field::new("i32_list", DataType::List(value_field), true)); +//! +//! StructArray::from(vec![ +//! (i32_field, i32), +//! (string_field, string), +//! (i32_list_field, i32_list), +//! ]) +//! } +//! } +//! +//! impl<'a> Extend<&'a MyRow> for MyRowBuilder { +//! fn extend>(&mut self, iter: T) { +//! iter.into_iter().for_each(|row| self.append(row)); +//! } +//! } +//! +//! /// Converts a slice of [`MyRow`] to a [`RecordBatch`] +//! fn rows_to_batch(rows: &[MyRow]) -> RecordBatch { +//! let mut builder = MyRowBuilder::default(); +//! builder.extend(rows); +//! RecordBatch::from(&builder.finish()) +//! } +//! ``` + +pub use arrow_buffer::BooleanBufferBuilder; + +mod boolean_builder; +pub use boolean_builder::*; +mod buffer_builder; +pub use buffer_builder::*; +mod fixed_size_binary_builder; +pub use fixed_size_binary_builder::*; +mod fixed_size_list_builder; +pub use fixed_size_list_builder::*; +mod generic_bytes_builder; +pub use generic_bytes_builder::*; +mod generic_list_builder; +pub use generic_list_builder::*; +mod map_builder; +pub use map_builder::*; +mod null_builder; +pub use null_builder::*; +mod primitive_builder; +pub use primitive_builder::*; +mod primitive_dictionary_builder; +pub use primitive_dictionary_builder::*; +mod primitive_run_builder; +pub use primitive_run_builder::*; +mod struct_builder; +pub use struct_builder::*; +mod generic_bytes_dictionary_builder; +pub use generic_bytes_dictionary_builder::*; +mod generic_byte_run_builder; +pub use generic_byte_run_builder::*; +mod generic_bytes_view_builder; +pub use generic_bytes_view_builder::*; +mod union_builder; + +pub use union_builder::*; + +use crate::ArrayRef; +use std::any::Any; + +/// Trait for dealing with different array builders at runtime +/// +/// # Example +/// +/// ``` +/// // Create +/// # use arrow_array::{ArrayRef, StringArray}; +/// # use arrow_array::builder::{ArrayBuilder, 
Float64Builder, Int64Builder, StringBuilder}; +/// +/// let mut data_builders: Vec> = vec![ +/// Box::new(Float64Builder::new()), +/// Box::new(Int64Builder::new()), +/// Box::new(StringBuilder::new()), +/// ]; +/// +/// // Fill +/// data_builders[0] +/// .as_any_mut() +/// .downcast_mut::() +/// .unwrap() +/// .append_value(3.14); +/// data_builders[1] +/// .as_any_mut() +/// .downcast_mut::() +/// .unwrap() +/// .append_value(-1); +/// data_builders[2] +/// .as_any_mut() +/// .downcast_mut::() +/// .unwrap() +/// .append_value("🍎"); +/// +/// // Finish +/// let array_refs: Vec = data_builders +/// .iter_mut() +/// .map(|builder| builder.finish()) +/// .collect(); +/// assert_eq!(array_refs[0].len(), 1); +/// assert_eq!(array_refs[1].is_null(0), false); +/// assert_eq!( +/// array_refs[2] +/// .as_any() +/// .downcast_ref::() +/// .unwrap() +/// .value(0), +/// "🍎" +/// ); +/// ``` +pub trait ArrayBuilder: Any + Send + Sync { + /// Returns the number of array slots in the builder + fn len(&self) -> usize; + + /// Returns whether number of array slots is zero + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Builds the array + fn finish(&mut self) -> ArrayRef; + + /// Builds the array without resetting the underlying builder. + fn finish_cloned(&self) -> ArrayRef; + + /// Returns the builder as a non-mutable `Any` reference. + /// + /// This is most useful when one wants to call non-mutable APIs on a specific builder + /// type. In this case, one can first cast this into a `Any`, and then use + /// `downcast_ref` to get a reference on the specific builder. + fn as_any(&self) -> &dyn Any; + + /// Returns the builder as a mutable `Any` reference. + /// + /// This is most useful when one wants to call mutable APIs on a specific builder + /// type. In this case, one can first cast this into a `Any`, and then use + /// `downcast_mut` to get a reference on the specific builder. 
+ fn as_any_mut(&mut self) -> &mut dyn Any; + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box; +} + +impl ArrayBuilder for Box { + fn len(&self) -> usize { + (**self).len() + } + + fn is_empty(&self) -> bool { + (**self).is_empty() + } + + fn finish(&mut self) -> ArrayRef { + (**self).finish() + } + + fn finish_cloned(&self) -> ArrayRef { + (**self).finish_cloned() + } + + fn as_any(&self) -> &dyn Any { + (**self).as_any() + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + (**self).as_any_mut() + } + + fn into_box_any(self: Box) -> Box { + self + } +} + +/// Builder for [`ListArray`](crate::array::ListArray) +pub type ListBuilder = GenericListBuilder; + +/// Builder for [`LargeListArray`](crate::array::LargeListArray) +pub type LargeListBuilder = GenericListBuilder; + +/// Builder for [`BinaryArray`](crate::array::BinaryArray) +/// +/// See examples on [`GenericBinaryBuilder`] +pub type BinaryBuilder = GenericBinaryBuilder; + +/// Builder for [`LargeBinaryArray`](crate::array::LargeBinaryArray) +/// +/// See examples on [`GenericBinaryBuilder`] +pub type LargeBinaryBuilder = GenericBinaryBuilder; + +/// Builder for [`StringArray`](crate::array::StringArray) +/// +/// See examples on [`GenericStringBuilder`] +pub type StringBuilder = GenericStringBuilder; + +/// Builder for [`LargeStringArray`](crate::array::LargeStringArray) +/// +/// See examples on [`GenericStringBuilder`] +pub type LargeStringBuilder = GenericStringBuilder; diff --git a/arrow-array/src/builder/null_builder.rs b/arrow-array/src/builder/null_builder.rs new file mode 100644 index 000000000000..59086dffa907 --- /dev/null +++ b/arrow-array/src/builder/null_builder.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::ArrayBuilder; +use crate::{ArrayRef, NullArray}; +use arrow_data::ArrayData; +use arrow_schema::DataType; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`NullArray`] +/// +/// # Example +/// +/// Create a `NullArray` from a `NullBuilder` +/// +/// ``` +/// +/// # use arrow_array::{Array, NullArray, builder::NullBuilder}; +/// +/// let mut b = NullBuilder::new(); +/// b.append_empty_value(); +/// b.append_null(); +/// b.append_nulls(3); +/// b.append_empty_values(3); +/// let arr = b.finish(); +/// +/// assert_eq!(8, arr.len()); +/// assert_eq!(0, arr.null_count()); +/// ``` +#[derive(Debug)] +pub struct NullBuilder { + len: usize, +} + +impl Default for NullBuilder { + fn default() -> Self { + Self::new() + } +} + +impl NullBuilder { + /// Creates a new null builder + pub fn new() -> Self { + Self { len: 0 } + } + + /// Creates a new null builder with space for `capacity` elements without re-allocating + #[deprecated = "there is no actual notion of capacity in the NullBuilder, so emulating it makes little sense"] + pub fn with_capacity(_capacity: usize) -> Self { + Self::new() + } + + /// Returns the capacity of this builder measured in slots of type `T` + #[deprecated = "there is no actual notion of capacity in the NullBuilder, so emulating it makes little sense"] + pub fn capacity(&self) -> usize { + self.len + } + + /// Appends a null slot into 
the builder + #[inline] + pub fn append_null(&mut self) { + self.len += 1; + } + + /// Appends `n` `null`s into the builder. + #[inline] + pub fn append_nulls(&mut self, n: usize) { + self.len += n; + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_empty_value(&mut self) { + self.append_null(); + } + + /// Appends `n` `null`s into the builder. + #[inline] + pub fn append_empty_values(&mut self, n: usize) { + self.append_nulls(n); + } + + /// Builds the [NullArray] and reset this builder. + pub fn finish(&mut self) -> NullArray { + let len = self.len(); + let builder = ArrayData::new_null(&DataType::Null, len).into_builder(); + + let array_data = unsafe { builder.build_unchecked() }; + NullArray::from(array_data) + } + + /// Builds the [NullArray] without resetting the builder. + pub fn finish_cloned(&self) -> NullArray { + let len = self.len(); + let builder = ArrayData::new_null(&DataType::Null, len).into_builder(); + + let array_data = unsafe { builder.build_unchecked() }; + NullArray::from(array_data) + } +} + +impl ArrayBuilder for NullBuilder { + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.len + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. 
+ fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Array; + + #[test] + fn test_null_array_builder() { + let mut builder = NullArray::builder(10); + builder.append_null(); + builder.append_nulls(4); + builder.append_empty_value(); + builder.append_empty_values(4); + + let arr = builder.finish(); + assert_eq!(10, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert!(arr.is_nullable()); + } + + #[test] + fn test_null_array_builder_finish_cloned() { + let mut builder = NullArray::builder(16); + builder.append_null(); + builder.append_empty_value(); + builder.append_empty_values(3); + let mut array = builder.finish_cloned(); + assert_eq!(5, array.len()); + + builder.append_empty_values(5); + array = builder.finish(); + assert_eq!(10, array.len()); + } +} diff --git a/arrow/src/array/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs similarity index 52% rename from arrow/src/array/builder/primitive_builder.rs rename to arrow-array/src/builder/primitive_builder.rs index 38c8b4471477..39b27bfca896 100644 --- a/arrow/src/array/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -15,21 +15,89 @@ // specific language governing permissions and limitations // under the License. +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::types::*; +use crate::{ArrayRef, PrimitiveArray}; +use arrow_buffer::NullBufferBuilder; +use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; -use crate::array::ArrayData; -use crate::array::ArrayRef; -use crate::array::PrimitiveArray; -use crate::datatypes::ArrowPrimitiveType; - -use super::{ArrayBuilder, BufferBuilder, NullBufferBuilder}; - -/// Array builder for fixed-width primitive types +/// A signed 8-bit integer array builder. 
+pub type Int8Builder = PrimitiveBuilder; +/// A signed 16-bit integer array builder. +pub type Int16Builder = PrimitiveBuilder; +/// A signed 32-bit integer array builder. +pub type Int32Builder = PrimitiveBuilder; +/// A signed 64-bit integer array builder. +pub type Int64Builder = PrimitiveBuilder; +/// An unsigned 8-bit integer array builder. +pub type UInt8Builder = PrimitiveBuilder; +/// An unsigned 16-bit integer array builder. +pub type UInt16Builder = PrimitiveBuilder; +/// An unsigned 32-bit integer array builder. +pub type UInt32Builder = PrimitiveBuilder; +/// An unsigned 64-bit integer array builder. +pub type UInt64Builder = PrimitiveBuilder; +/// A 16-bit floating point array builder. +pub type Float16Builder = PrimitiveBuilder; +/// A 32-bit floating point array builder. +pub type Float32Builder = PrimitiveBuilder; +/// A 64-bit floating point array builder. +pub type Float64Builder = PrimitiveBuilder; + +/// A timestamp second array builder. +pub type TimestampSecondBuilder = PrimitiveBuilder; +/// A timestamp millisecond array builder. +pub type TimestampMillisecondBuilder = PrimitiveBuilder; +/// A timestamp microsecond array builder. +pub type TimestampMicrosecondBuilder = PrimitiveBuilder; +/// A timestamp nanosecond array builder. +pub type TimestampNanosecondBuilder = PrimitiveBuilder; + +/// A 32-bit date array builder. +pub type Date32Builder = PrimitiveBuilder; +/// A 64-bit date array builder. +pub type Date64Builder = PrimitiveBuilder; + +/// A 32-bit elapsed time in seconds array builder. +pub type Time32SecondBuilder = PrimitiveBuilder; +/// A 32-bit elapsed time in milliseconds array builder. +pub type Time32MillisecondBuilder = PrimitiveBuilder; +/// A 64-bit elapsed time in microseconds array builder. +pub type Time64MicrosecondBuilder = PrimitiveBuilder; +/// A 64-bit elapsed time in nanoseconds array builder. +pub type Time64NanosecondBuilder = PrimitiveBuilder; + +/// A “calendar” interval in months array builder.
+pub type IntervalYearMonthBuilder = PrimitiveBuilder; +/// A “calendar” interval in days and milliseconds array builder. +pub type IntervalDayTimeBuilder = PrimitiveBuilder; +/// A “calendar” interval in months, days, and nanoseconds array builder. +pub type IntervalMonthDayNanoBuilder = PrimitiveBuilder; + +/// An elapsed time in seconds array builder. +pub type DurationSecondBuilder = PrimitiveBuilder; +/// An elapsed time in milliseconds array builder. +pub type DurationMillisecondBuilder = PrimitiveBuilder; +/// An elapsed time in microseconds array builder. +pub type DurationMicrosecondBuilder = PrimitiveBuilder; +/// An elapsed time in nanoseconds array builder. +pub type DurationNanosecondBuilder = PrimitiveBuilder; + +/// A decimal 128 array builder +pub type Decimal128Builder = PrimitiveBuilder; +/// A decimal 256 array builder +pub type Decimal256Builder = PrimitiveBuilder; + +/// Builder for [`PrimitiveArray`] #[derive(Debug)] pub struct PrimitiveBuilder { values_builder: BufferBuilder, null_buffer_builder: NullBufferBuilder, + data_type: DataType, } impl ArrayBuilder for PrimitiveBuilder { @@ -53,15 +121,15 @@ impl ArrayBuilder for PrimitiveBuilder { self.values_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.values_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + + /// Builds the array without resetting the builder. 
+ fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } } impl Default for PrimitiveBuilder { @@ -81,9 +149,47 @@ impl PrimitiveBuilder { Self { values_builder: BufferBuilder::::new(capacity), null_buffer_builder: NullBufferBuilder::new(capacity), + data_type: T::DATA_TYPE, + } + } + + /// Creates a new primitive array builder from buffers + pub fn new_from_buffer( + values_buffer: MutableBuffer, + null_buffer: Option, + ) -> Self { + let values_builder = BufferBuilder::::new_from_buffer(values_buffer); + + let null_buffer_builder = null_buffer + .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, values_builder.len())) + .unwrap_or_else(|| NullBufferBuilder::new_with_len(values_builder.len())); + + Self { + values_builder, + null_buffer_builder, + data_type: T::DATA_TYPE, } } + /// By default [`PrimitiveBuilder`] uses [`ArrowPrimitiveType::DATA_TYPE`] as the + /// data type of the generated array. + /// + /// This method allows overriding the data type, to allow specifying timezones + /// for [`DataType::Timestamp`] or precision and scale for [`DataType::Decimal128`] and [`DataType::Decimal256`] + /// + /// # Panics + /// + /// This method panics if `data_type` is not [PrimitiveArray::is_compatible] + pub fn with_data_type(self, data_type: DataType) -> Self { + assert!( + PrimitiveArray::::is_compatible(&data_type), + "incompatible data type for builder, expected {} got {}", + T::DATA_TYPE, + data_type + ); + Self { data_type, ..self } + } + /// Returns the capacity of this builder measured in slots of type `T` pub fn capacity(&self) -> usize { self.values_builder.capacity() @@ -103,6 +209,7 @@ impl PrimitiveBuilder { self.values_builder.advance(1); } + /// Appends `n` no. 
of null's into the builder #[inline] pub fn append_nulls(&mut self, n: usize) { self.null_buffer_builder.append_n_nulls(n); @@ -126,6 +233,10 @@ impl PrimitiveBuilder { } /// Appends values from a slice of type `T` and a validity boolean slice + /// + /// # Panics + /// + /// Panics if `values` and `is_valid` have different lengths #[inline] pub fn append_values(&mut self, values: &[T::Native], is_valid: &[bool]) { assert_eq!( @@ -143,10 +254,7 @@ impl PrimitiveBuilder { /// This requires the iterator be a trusted length. This could instead require /// the iterator implement `TrustedLen` once that is stabilized. #[inline] - pub unsafe fn append_trusted_len_iter( - &mut self, - iter: impl IntoIterator, - ) { + pub unsafe fn append_trusted_len_iter(&mut self, iter: impl IntoIterator) { let iter = iter.into_iter(); let len = iter .size_hint() @@ -160,11 +268,25 @@ impl PrimitiveBuilder { /// Builds the [`PrimitiveArray`] and reset this builder. pub fn finish(&mut self) -> PrimitiveArray { let len = self.len(); - let null_bit_buffer = self.null_buffer_builder.finish(); - let builder = ArrayData::builder(T::DATA_TYPE) + let nulls = self.null_buffer_builder.finish(); + let builder = ArrayData::builder(self.data_type.clone()) .len(len) .add_buffer(self.values_builder.finish()) - .null_bit_buffer(null_bit_buffer); + .nulls(nulls); + + let array_data = unsafe { builder.build_unchecked() }; + PrimitiveArray::::from(array_data) + } + + /// Builds the [`PrimitiveArray`] without resetting the builder. 
+ pub fn finish_cloned(&self) -> PrimitiveArray { + let len = self.len(); + let nulls = self.null_buffer_builder.finish_cloned(); + let values_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); + let builder = ArrayData::builder(self.data_type.clone()) + .len(len) + .add_buffer(values_buffer) + .nulls(nulls); let array_data = unsafe { builder.build_unchecked() }; PrimitiveArray::::from(array_data) @@ -174,19 +296,76 @@ impl PrimitiveBuilder { pub fn values_slice(&self) -> &[T::Native] { self.values_builder.as_slice() } + + /// Returns the current values buffer as a mutable slice + pub fn values_slice_mut(&mut self) -> &mut [T::Native] { + self.values_builder.as_slice_mut() + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } + + /// Returns the current null buffer as a mutable slice + pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> { + self.null_buffer_builder.as_slice_mut() + } + + /// Returns the current values buffer and null buffer as a slice + pub fn slices_mut(&mut self) -> (&mut [T::Native], Option<&mut [u8]>) { + ( + self.values_builder.as_slice_mut(), + self.null_buffer_builder.as_slice_mut(), + ) + } +} + +impl PrimitiveBuilder

{ + /// Sets the precision and scale + pub fn with_precision_and_scale(self, precision: u8, scale: i8) -> Result { + validate_decimal_precision_and_scale::

(precision, scale)?; + Ok(Self { + data_type: P::TYPE_CONSTRUCTOR(precision, scale), + ..self + }) + } +} + +impl PrimitiveBuilder

{ + /// Sets the timezone + pub fn with_timezone(self, timezone: impl Into>) -> Self { + self.with_timezone_opt(Some(timezone.into())) + } + + /// Sets an optional timezone + pub fn with_timezone_opt>>(self, timezone: Option) -> Self { + Self { + data_type: DataType::Timestamp(P::UNIT, timezone.map(Into::into)), + ..self + } + } +} + +impl Extend> for PrimitiveBuilder

{ + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + self.append_option(v) + } + } } #[cfg(test)] mod tests { use super::*; + use arrow_schema::TimeUnit; use crate::array::Array; use crate::array::BooleanArray; use crate::array::Date32Array; use crate::array::Int32Array; - use crate::array::Int32Builder; use crate::array::TimestampSecondArray; - use crate::buffer::Buffer; #[test] fn test_primitive_array_builder_i32() { @@ -282,14 +461,14 @@ mod tests { } let arr = builder.finish(); - assert_eq!(&buf, arr.values()); + assert_eq!(&buf, arr.values().inner()); assert_eq!(10, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); for i in 0..10 { assert!(!arr.is_null(i)); assert!(arr.is_valid(i)); - assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {}", i) + assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {i}") } } @@ -377,4 +556,56 @@ mod tests { assert_eq!(5, arr.len()); assert_eq!(0, builder.len()); } + + #[test] + fn test_primitive_array_builder_finish_cloned() { + let mut builder = Int32Builder::new(); + builder.append_value(23); + builder.append_value(45); + let result = builder.finish_cloned(); + assert_eq!(result, Int32Array::from(vec![23, 45])); + builder.append_value(56); + assert_eq!(builder.finish_cloned(), Int32Array::from(vec![23, 45, 56])); + + builder.append_slice(&[2, 4, 6, 8]); + let mut arr = builder.finish(); + assert_eq!(7, arr.len()); + assert_eq!(arr, Int32Array::from(vec![23, 45, 56, 2, 4, 6, 8])); + assert_eq!(0, builder.len()); + + builder.append_slice(&[1, 3, 5, 7, 9]); + arr = builder.finish(); + assert_eq!(5, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + fn test_primitive_array_builder_with_data_type() { + let mut builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); + builder.append_value(1); + let array = builder.finish(); + assert_eq!(array.precision(), 1); + assert_eq!(array.scale(), 2); + + let data_type = 
DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())); + let mut builder = TimestampNanosecondBuilder::new().with_data_type(data_type.clone()); + builder.append_value(1); + let array = builder.finish(); + assert_eq!(array.data_type(), &data_type); + } + + #[test] + #[should_panic(expected = "incompatible data type for builder, expected Int32 got Int64")] + fn test_invalid_with_data_type() { + Int32Builder::new().with_data_type(DataType::Int64); + } + + #[test] + fn test_extend() { + let mut builder = PrimitiveBuilder::::new(); + builder.extend([1, 2, 3, 5, 2, 4, 4].into_iter().map(Some)); + builder.extend([2, 4, 6, 2].into_iter().map(Some)); + let array = builder.finish(); + assert_eq!(array.values(), &[1, 2, 3, 5, 2, 4, 4, 2, 4, 6, 2]); + } } diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs new file mode 100644 index 000000000000..a764fa4c29c8 --- /dev/null +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -0,0 +1,402 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::builder::{ArrayBuilder, PrimitiveBuilder}; +use crate::types::ArrowDictionaryKeyType; +use crate::{Array, ArrayRef, ArrowPrimitiveType, DictionaryArray}; +use arrow_buffer::{ArrowNativeType, ToByteSlice}; +use arrow_schema::{ArrowError, DataType}; +use std::any::Any; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::Arc; + +/// Wraps a type implementing `ToByteSlice` implementing `Hash` and `Eq` for it +/// +/// This is necessary to handle types such as f32, which don't natively implement these +#[derive(Debug)] +struct Value(T); + +impl std::hash::Hash for Value { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.0.to_byte_slice().eq(other.0.to_byte_slice()) + } +} + +impl Eq for Value {} + +/// Builder for [`DictionaryArray`] of [`PrimitiveArray`](crate::array::PrimitiveArray) +/// +/// # Example: +/// +/// ``` +/// +/// # use arrow_array::builder::PrimitiveDictionaryBuilder; +/// # use arrow_array::types::{UInt32Type, UInt8Type}; +/// # use arrow_array::{Array, UInt32Array, UInt8Array}; +/// +/// let mut builder = PrimitiveDictionaryBuilder::::new(); +/// builder.append(12345678).unwrap(); +/// builder.append_null(); +/// builder.append(22345678).unwrap(); +/// let array = builder.finish(); +/// +/// assert_eq!( +/// array.keys(), +/// &UInt8Array::from(vec![Some(0), None, Some(1)]) +/// ); +/// +/// // Values are polymorphic and so require a downcast. 
+/// let av = array.values(); +/// let ava: &UInt32Array = av.as_any().downcast_ref::().unwrap(); +/// let avs: &[u32] = ava.values(); +/// +/// assert!(!array.is_null(0)); +/// assert!(array.is_null(1)); +/// assert!(!array.is_null(2)); +/// +/// assert_eq!(avs, &[12345678, 22345678]); +/// ``` +#[derive(Debug)] +pub struct PrimitiveDictionaryBuilder +where + K: ArrowPrimitiveType, + V: ArrowPrimitiveType, +{ + keys_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + map: HashMap, usize>, +} + +impl Default for PrimitiveDictionaryBuilder +where + K: ArrowPrimitiveType, + V: ArrowPrimitiveType, +{ + fn default() -> Self { + Self::new() + } +} + +impl PrimitiveDictionaryBuilder +where + K: ArrowPrimitiveType, + V: ArrowPrimitiveType, +{ + /// Creates a new `PrimitiveDictionaryBuilder`. + pub fn new() -> Self { + Self { + keys_builder: PrimitiveBuilder::new(), + values_builder: PrimitiveBuilder::new(), + map: HashMap::new(), + } + } + + /// Creates a new `PrimitiveDictionaryBuilder` from the provided keys and values builders. + /// + /// # Panics + /// + /// This method panics if `keys_builder` or `values_builder` is not empty. + pub fn new_from_empty_builders( + keys_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + ) -> Self { + assert!( + keys_builder.is_empty() && values_builder.is_empty(), + "keys and values builders must be empty" + ); + Self { + keys_builder, + values_builder, + map: HashMap::new(), + } + } + + /// Creates a new `PrimitiveDictionaryBuilder` from existing `PrimitiveBuilder`s of keys and values. + /// + /// # Safety + /// + /// caller must ensure that the passed in builders are valid for DictionaryArray. 
+ pub unsafe fn new_from_builders( + keys_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + ) -> Self { + let keys = keys_builder.values_slice(); + let values = values_builder.values_slice(); + let mut map = HashMap::with_capacity(values.len()); + + keys.iter().zip(values.iter()).for_each(|(key, value)| { + map.insert(Value(*value), K::Native::to_usize(*key).unwrap()); + }); + + Self { + keys_builder, + values_builder, + map, + } + } + + /// Creates a new `PrimitiveDictionaryBuilder` with the provided capacities + /// + /// `keys_capacity`: the number of keys, i.e. length of array to build + /// `values_capacity`: the number of distinct dictionary values, i.e. size of dictionary + pub fn with_capacity(keys_capacity: usize, values_capacity: usize) -> Self { + Self { + keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), + values_builder: PrimitiveBuilder::with_capacity(values_capacity), + map: HashMap::with_capacity(values_capacity), + } + } +} + +impl ArrayBuilder for PrimitiveDictionaryBuilder +where + K: ArrowDictionaryKeyType, + V: ArrowPrimitiveType, +{ + /// Returns the builder as an non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as an mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.keys_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl PrimitiveDictionaryBuilder +where + K: ArrowDictionaryKeyType, + V: ArrowPrimitiveType, +{ + /// Append a primitive value to the array. 
Return an existing index + /// if already present in the values array or a new index if the + /// value is appended to the values array. + #[inline] + pub fn append(&mut self, value: V::Native) -> Result { + let key = match self.map.entry(Value(value)) { + Entry::Vacant(vacant) => { + // Append new value. + let key = self.values_builder.len(); + self.values_builder.append_value(value); + vacant.insert(key); + K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)? + } + Entry::Occupied(o) => K::Native::usize_as(*o.get()), + }; + + self.keys_builder.append_value(key); + Ok(key) + } + + /// Infallibly append a value to this builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + #[inline] + pub fn append_value(&mut self, value: V::Native) { + self.append(value).expect("dictionary key overflow"); + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.keys_builder.append_null() + } + + /// Append an `Option` value into the builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + #[inline] + pub fn append_option(&mut self, value: Option) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + }; + } + + /// Builds the `DictionaryArray` and reset this builder. + pub fn finish(&mut self) -> DictionaryArray { + self.map.clear(); + let values = self.values_builder.finish(); + let keys = self.keys_builder.finish(); + + let data_type = + DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + + /// Builds the `DictionaryArray` without resetting the builder. 
+ pub fn finish_cloned(&self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish_cloned(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + + /// Returns the current dictionary values buffer as a slice + pub fn values_slice(&self) -> &[V::Native] { + self.values_builder.values_slice() + } + + /// Returns the current dictionary values buffer as a mutable slice + pub fn values_slice_mut(&mut self) -> &mut [V::Native] { + self.values_builder.values_slice_mut() + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.keys_builder.validity_slice() + } +} + +impl Extend> + for PrimitiveDictionaryBuilder +{ + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + self.append_option(v) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::UInt32Array; + use crate::array::UInt8Array; + use crate::builder::Decimal128Builder; + use crate::types::{Decimal128Type, Int32Type, UInt32Type, UInt8Type}; + + #[test] + fn test_primitive_dictionary_builder() { + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(3, 2); + builder.append(12345678).unwrap(); + builder.append_null(); + builder.append(22345678).unwrap(); + let array = builder.finish(); + + assert_eq!( + array.keys(), + &UInt8Array::from(vec![Some(0), None, Some(1)]) + ); + + // Values are polymorphic and so require a downcast. 
+ let av = array.values(); + let ava: &UInt32Array = av.as_any().downcast_ref::().unwrap(); + let avs: &[u32] = ava.values(); + + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert!(!array.is_null(2)); + + assert_eq!(avs, &[12345678, 22345678]); + } + + #[test] + fn test_extend() { + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.extend([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some)); + builder.extend([4, 5, 1, 3, 1].into_iter().map(Some)); + let dict = builder.finish(); + assert_eq!( + dict.keys().values(), + &[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 0, 2, 0] + ); + assert_eq!(dict.values().len(), 5); + } + + #[test] + #[should_panic(expected = "DictionaryKeyOverflowError")] + fn test_primitive_dictionary_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(257, 257); + // 256 unique keys. + for i in 0..256 { + builder.append(i + 1000).unwrap(); + } + // Special error if the key overflows (256th entry) + builder.append(1257).unwrap(); + } + + #[test] + fn test_primitive_dictionary_with_builders() { + let keys_builder = PrimitiveBuilder::::new(); + let values_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); + let mut builder = + PrimitiveDictionaryBuilder::::new_from_empty_builders( + keys_builder, + values_builder, + ); + let dict_array = builder.finish(); + assert_eq!(dict_array.value_type(), DataType::Decimal128(1, 2)); + assert_eq!( + dict_array.data_type(), + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal128(1, 2)), + ) + ); + } +} diff --git a/arrow-array/src/builder/primitive_run_builder.rs b/arrow-array/src/builder/primitive_run_builder.rs new file mode 100644 index 000000000000..01a989199b58 --- /dev/null +++ b/arrow-array/src/builder/primitive_run_builder.rs @@ -0,0 +1,311 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use crate::{types::RunEndIndexType, ArrayRef, ArrowPrimitiveType, RunArray}; + +use super::{ArrayBuilder, PrimitiveBuilder}; + +use arrow_buffer::ArrowNativeType; + +/// Builder for [`RunArray`] of [`PrimitiveArray`](crate::array::PrimitiveArray) +/// +/// # Example: +/// +/// ``` +/// +/// # use arrow_array::builder::PrimitiveRunBuilder; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::types::{UInt32Type, Int16Type}; +/// # use arrow_array::{Array, UInt32Array, Int16Array}; +/// +/// let mut builder = +/// PrimitiveRunBuilder::::new(); +/// builder.append_value(1234); +/// builder.append_value(1234); +/// builder.append_value(1234); +/// builder.append_null(); +/// builder.append_value(5678); +/// builder.append_value(5678); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[3, 4, 6]); +/// +/// let av = array.values(); +/// +/// assert!(!av.is_null(0)); +/// assert!(av.is_null(1)); +/// assert!(!av.is_null(2)); +/// +/// // Values are polymorphic and so require a downcast. 
+/// let ava: &UInt32Array = av.as_primitive::(); +/// +/// assert_eq!(ava, &UInt32Array::from(vec![Some(1234), None, Some(5678)])); +/// ``` +#[derive(Debug)] +pub struct PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + run_ends_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + current_value: Option, + current_run_end_index: usize, + prev_run_end_index: usize, +} + +impl Default for PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + fn default() -> Self { + Self::new() + } +} + +impl PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + /// Creates a new `PrimitiveRunBuilder` + pub fn new() -> Self { + Self { + run_ends_builder: PrimitiveBuilder::new(), + values_builder: PrimitiveBuilder::new(), + current_value: None, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } + + /// Creates a new `PrimitiveRunBuilder` with the provided capacity + /// + /// `capacity`: the expected number of run-end encoded values. + pub fn with_capacity(capacity: usize) -> Self { + Self { + run_ends_builder: PrimitiveBuilder::with_capacity(capacity), + values_builder: PrimitiveBuilder::with_capacity(capacity), + current_value: None, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } +} + +impl ArrayBuilder for PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the length of logical array encoded by + /// the eventual runs array. + fn len(&self) -> usize { + self.current_run_end_index + } + + /// Builds the array and reset this builder. 
+ fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + /// Appends optional value to the logical array encoded by the RunArray. + pub fn append_option(&mut self, value: Option) { + if self.current_run_end_index == 0 { + self.current_run_end_index = 1; + self.current_value = value; + return; + } + if self.current_value != value { + self.append_run_end(); + self.current_value = value; + } + + self.current_run_end_index += 1; + } + + /// Appends value to the logical array encoded by the run-ends array. + pub fn append_value(&mut self, value: V::Native) { + self.append_option(Some(value)) + } + + /// Appends null to the logical array encoded by the run-ends array. + pub fn append_null(&mut self) { + self.append_option(None) + } + + /// Creates the RunArray and resets the builder. + /// Panics if RunArray cannot be built. + pub fn finish(&mut self) -> RunArray { + // write the last run end to the array. + self.append_run_end(); + + // reset the run index to zero. + self.current_value = None; + self.current_run_end_index = 0; + + // build the run encoded array by adding run_ends and values array as its children. + let run_ends_array = self.run_ends_builder.finish(); + let values_array = self.values_builder.finish(); + RunArray::::try_new(&run_ends_array, &values_array).unwrap() + } + + /// Creates the RunArray and without resetting the builder. + /// Panics if RunArray cannot be built. 
+ pub fn finish_cloned(&self) -> RunArray { + let mut run_ends_array = self.run_ends_builder.finish_cloned(); + let mut values_array = self.values_builder.finish_cloned(); + + // Add current run if one exists + if self.prev_run_end_index != self.current_run_end_index { + let mut run_end_builder = run_ends_array.into_builder().unwrap(); + let mut values_builder = values_array.into_builder().unwrap(); + self.append_run_end_with_builders(&mut run_end_builder, &mut values_builder); + run_ends_array = run_end_builder.finish(); + values_array = values_builder.finish(); + } + + RunArray::try_new(&run_ends_array, &values_array).unwrap() + } + + // Appends the current run to the array. + fn append_run_end(&mut self) { + // empty array or the function called without appending any value. + if self.prev_run_end_index == self.current_run_end_index { + return; + } + let run_end_index = self.run_end_index_as_native(); + self.run_ends_builder.append_value(run_end_index); + self.values_builder.append_option(self.current_value); + self.prev_run_end_index = self.current_run_end_index; + } + + // Similar to `append_run_end` but on custom builders. + // Used in `finish_cloned` which is not suppose to mutate `self`. 
+ fn append_run_end_with_builders( + &self, + run_ends_builder: &mut PrimitiveBuilder, + values_builder: &mut PrimitiveBuilder, + ) { + let run_end_index = self.run_end_index_as_native(); + run_ends_builder.append_value(run_end_index); + values_builder.append_option(self.current_value); + } + + fn run_end_index_as_native(&self) -> R::Native { + R::Native::from_usize(self.current_run_end_index) + .unwrap_or_else(|| panic!( + "Cannot convert `current_run_end_index` {} from `usize` to native form of arrow datatype {}", + self.current_run_end_index, + R::DATA_TYPE + )) + } +} + +impl Extend> for PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + fn extend>>(&mut self, iter: T) { + for elem in iter { + self.append_option(elem); + } + } +} + +#[cfg(test)] +mod tests { + use crate::builder::PrimitiveRunBuilder; + use crate::cast::AsArray; + use crate::types::{Int16Type, UInt32Type}; + use crate::{Array, UInt32Array}; + + #[test] + fn test_primitive_ree_array_builder() { + let mut builder = PrimitiveRunBuilder::::new(); + builder.append_value(1234); + builder.append_value(1234); + builder.append_value(1234); + builder.append_null(); + builder.append_value(5678); + builder.append_value(5678); + + let array = builder.finish(); + + assert_eq!(array.null_count(), 0); + assert_eq!(array.len(), 6); + + assert_eq!(array.run_ends().values(), &[3, 4, 6]); + + let av = array.values(); + + assert!(!av.is_null(0)); + assert!(av.is_null(1)); + assert!(!av.is_null(2)); + + // Values are polymorphic and so require a downcast. 
+ let ava: &UInt32Array = av.as_primitive::(); + + assert_eq!(ava, &UInt32Array::from(vec![Some(1234), None, Some(5678)])); + } + + #[test] + fn test_extend() { + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend([1, 2, 2, 5, 5, 4, 4].into_iter().map(Some)); + builder.extend([4, 4, 6, 2].into_iter().map(Some)); + let array = builder.finish(); + + assert_eq!(array.len(), 11); + assert_eq!(array.null_count(), 0); + assert_eq!(array.run_ends().values(), &[1, 3, 5, 9, 10, 11]); + assert_eq!( + array.values().as_primitive::().values(), + &[1, 2, 5, 4, 6, 2] + ); + } +} diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs new file mode 100644 index 000000000000..c0e49b939f2c --- /dev/null +++ b/arrow-array/src/builder/struct_builder.rs @@ -0,0 +1,730 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::builder::*; +use crate::StructArray; +use arrow_buffer::NullBufferBuilder; +use arrow_schema::{DataType, Fields, IntervalUnit, SchemaBuilder, TimeUnit}; +use std::sync::Arc; + +/// Builder for [`StructArray`] +/// +/// Note that callers should make sure that methods of all the child field builders are +/// properly called to maintain the consistency of the data structure. +/// +/// +/// Handling arrays with complex layouts, such as `List>>`, in Rust can be challenging due to its strong typing system. +/// To construct a collection builder ([`ListBuilder`], [`LargeListBuilder`], or [`MapBuilder`]) using [`make_builder`], multiple calls are required. This complexity arises from the recursive approach utilized by [`StructBuilder::from_fields`]. +/// +/// Initially, [`StructBuilder::from_fields`] invokes [`make_builder`], which returns a `Box`. To obtain the specific collection builder, one must first use [`StructBuilder::field_builder`] to get a `Collection<[Box]>`. Subsequently, the `values()` result from this operation can be downcast to the desired builder type. +/// +/// For example, when working with [`ListBuilder`], you would first call [`StructBuilder::field_builder::>>`] and then downcast the [`Box`] to the specific [`StructBuilder`] you need. 
+/// +/// For a practical example see the code below: +/// +/// ```rust +/// use arrow_array::builder::{ArrayBuilder, ListBuilder, StringBuilder, StructBuilder}; +/// use arrow_schema::{DataType, Field, Fields}; +/// use std::sync::Arc; +/// +/// // This is an example column that has a List>> layout +/// let mut example_col = ListBuilder::new(StructBuilder::from_fields( +/// vec![Field::new( +/// "value_list", +/// DataType::List(Arc::new(Field::new( +/// "item", +/// DataType::Struct(Fields::from(vec![ +/// Field::new("key", DataType::Utf8, true), +/// Field::new("value", DataType::Utf8, true), +/// ])), //In this example we are trying to get to this builder and insert key/value pairs +/// true, +/// ))), +/// true, +/// )], +/// 0, +/// )); +/// +/// // We can obtain the StructBuilder without issues, because example_col was created with StructBuilder +/// let col_struct_builder: &mut StructBuilder = example_col.values(); +/// +/// // We can't obtain the ListBuilder with the expected generic types, because under the hood +/// // the StructBuilder was returned as a Box and passed as such to the ListBuilder constructor +/// +/// // This panics in runtime, even though we know that the builder is a ListBuilder. +/// // let sb = col_struct_builder +/// // .field_builder::>(0) +/// // .as_mut() +/// // .unwrap(); +/// +/// //To keep in line with Rust's strong typing, we fetch a ListBuilder> from the column StructBuilder first... +/// let mut list_builder_option = +/// col_struct_builder.field_builder::>>(0); +/// +/// let list_builder = list_builder_option.as_mut().unwrap(); +/// +/// // ... 
and then downcast the key/value pair values to a StructBuilder +/// let struct_builder = list_builder +/// .values() +/// .as_any_mut() +/// .downcast_mut::() +/// .unwrap(); +/// +/// // We can now append values to the StructBuilder +/// let key_builder = struct_builder.field_builder::(0).unwrap(); +/// key_builder.append_value("my key"); +/// +/// let value_builder = struct_builder.field_builder::(1).unwrap(); +/// value_builder.append_value("my value"); +/// +/// struct_builder.append(true); +/// list_builder.append(true); +/// col_struct_builder.append(true); +/// example_col.append(true); +/// +/// let array = example_col.finish(); +/// +/// println!("My array: {:?}", array); +/// ``` +/// +pub struct StructBuilder { + fields: Fields, + field_builders: Vec>, + null_buffer_builder: NullBufferBuilder, +} + +impl std::fmt::Debug for StructBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StructBuilder") + .field("fields", &self.fields) + .field("bitmap_builder", &self.null_buffer_builder) + .field("len", &self.len()) + .finish() + } +} + +impl ArrayBuilder for StructBuilder { + /// Returns the number of array slots in the builder. + /// + /// Note that this always return the first child field builder's length, and it is + /// the caller's responsibility to maintain the consistency that all the child field + /// builder should have the equal number of elements. + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + /// Builds the array. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } + + /// Returns the builder as a non-mutable `Any` reference. + /// + /// This is most useful when one wants to call non-mutable APIs on a specific builder + /// type. 
In this case, one can first cast this into a `Any`, and then use + /// `downcast_ref` to get a reference on the specific builder. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + /// + /// This is most useful when one wants to call mutable APIs on a specific builder + /// type. In this case, one can first cast this into a `Any`, and then use + /// `downcast_mut` to get a reference on the specific builder. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } +} + +/// Returns a builder with capacity `capacity` that corresponds to the datatype `DataType` +/// This function is useful to construct arrays from an arbitrary vectors with known/expected +/// schema. +/// +/// See comments on StructBuilder on how to retreive collection builders built by make_builder. +pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { + use crate::builder::*; + match datatype { + DataType::Null => Box::new(NullBuilder::new()), + DataType::Boolean => Box::new(BooleanBuilder::with_capacity(capacity)), + DataType::Int8 => Box::new(Int8Builder::with_capacity(capacity)), + DataType::Int16 => Box::new(Int16Builder::with_capacity(capacity)), + DataType::Int32 => Box::new(Int32Builder::with_capacity(capacity)), + DataType::Int64 => Box::new(Int64Builder::with_capacity(capacity)), + DataType::UInt8 => Box::new(UInt8Builder::with_capacity(capacity)), + DataType::UInt16 => Box::new(UInt16Builder::with_capacity(capacity)), + DataType::UInt32 => Box::new(UInt32Builder::with_capacity(capacity)), + DataType::UInt64 => Box::new(UInt64Builder::with_capacity(capacity)), + DataType::Float16 => Box::new(Float16Builder::with_capacity(capacity)), + DataType::Float32 => Box::new(Float32Builder::with_capacity(capacity)), + DataType::Float64 => Box::new(Float64Builder::with_capacity(capacity)), + DataType::Binary => 
Box::new(BinaryBuilder::with_capacity(capacity, 1024)), + DataType::LargeBinary => Box::new(LargeBinaryBuilder::with_capacity(capacity, 1024)), + DataType::FixedSizeBinary(len) => { + Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *len)) + } + DataType::Decimal128(p, s) => Box::new( + Decimal128Builder::with_capacity(capacity).with_data_type(DataType::Decimal128(*p, *s)), + ), + DataType::Decimal256(p, s) => Box::new( + Decimal256Builder::with_capacity(capacity).with_data_type(DataType::Decimal256(*p, *s)), + ), + DataType::Utf8 => Box::new(StringBuilder::with_capacity(capacity, 1024)), + DataType::LargeUtf8 => Box::new(LargeStringBuilder::with_capacity(capacity, 1024)), + DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)), + DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)), + DataType::Time32(TimeUnit::Second) => { + Box::new(Time32SecondBuilder::with_capacity(capacity)) + } + DataType::Time32(TimeUnit::Millisecond) => { + Box::new(Time32MillisecondBuilder::with_capacity(capacity)) + } + DataType::Time64(TimeUnit::Microsecond) => { + Box::new(Time64MicrosecondBuilder::with_capacity(capacity)) + } + DataType::Time64(TimeUnit::Nanosecond) => { + Box::new(Time64NanosecondBuilder::with_capacity(capacity)) + } + DataType::Timestamp(TimeUnit::Second, tz) => Box::new( + TimestampSecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Second, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Millisecond, tz) => Box::new( + TimestampMillisecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => Box::new( + TimestampMicrosecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Box::new( + TimestampNanosecondBuilder::with_capacity(capacity) + 
.with_data_type(DataType::Timestamp(TimeUnit::Nanosecond, tz.clone())), + ), + DataType::Interval(IntervalUnit::YearMonth) => { + Box::new(IntervalYearMonthBuilder::with_capacity(capacity)) + } + DataType::Interval(IntervalUnit::DayTime) => { + Box::new(IntervalDayTimeBuilder::with_capacity(capacity)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Box::new(IntervalMonthDayNanoBuilder::with_capacity(capacity)) + } + DataType::Duration(TimeUnit::Second) => { + Box::new(DurationSecondBuilder::with_capacity(capacity)) + } + DataType::Duration(TimeUnit::Millisecond) => { + Box::new(DurationMillisecondBuilder::with_capacity(capacity)) + } + DataType::Duration(TimeUnit::Microsecond) => { + Box::new(DurationMicrosecondBuilder::with_capacity(capacity)) + } + DataType::Duration(TimeUnit::Nanosecond) => { + Box::new(DurationNanosecondBuilder::with_capacity(capacity)) + } + DataType::List(field) => { + let builder = make_builder(field.data_type(), capacity); + Box::new(ListBuilder::with_capacity(builder, capacity).with_field(field.clone())) + } + DataType::LargeList(field) => { + let builder = make_builder(field.data_type(), capacity); + Box::new(LargeListBuilder::with_capacity(builder, capacity).with_field(field.clone())) + } + DataType::Map(field, _) => match field.data_type() { + DataType::Struct(fields) => { + let map_field_names = MapFieldNames { + key: fields[0].name().clone(), + value: fields[1].name().clone(), + entry: field.name().clone(), + }; + let key_builder = make_builder(fields[0].data_type(), capacity); + let value_builder = make_builder(fields[1].data_type(), capacity); + Box::new( + MapBuilder::with_capacity( + Some(map_field_names), + key_builder, + value_builder, + capacity, + ) + .with_values_field(fields[1].clone()), + ) + } + t => panic!("The field of Map data type {t:?} should has a child Struct field"), + }, + DataType::Struct(fields) => Box::new(StructBuilder::from_fields(fields.clone(), capacity)), + t => panic!("Data type {t:?} is not 
currently supported"), + } +} + +impl StructBuilder { + /// Creates a new `StructBuilder` + pub fn new(fields: impl Into, field_builders: Vec>) -> Self { + Self { + field_builders, + fields: fields.into(), + null_buffer_builder: NullBufferBuilder::new(0), + } + } + + /// Creates a new `StructBuilder` from [`Fields`] and `capacity` + pub fn from_fields(fields: impl Into, capacity: usize) -> Self { + let fields = fields.into(); + let mut builders = Vec::with_capacity(fields.len()); + for field in &fields { + builders.push(make_builder(field.data_type(), capacity)); + } + Self::new(fields, builders) + } + + /// Returns a mutable reference to the child field builder at index `i`. + /// Result will be `None` if the input type `T` provided doesn't match the actual + /// field builder's type. + pub fn field_builder(&mut self, i: usize) -> Option<&mut T> { + self.field_builders[i].as_any_mut().downcast_mut::() + } + + /// Returns the number of fields for the struct this builder is building. + pub fn num_fields(&self) -> usize { + self.field_builders.len() + } + + /// Appends an element (either null or non-null) to the struct. The actual elements + /// should be appended for each child sub-array in a consistent way. + #[inline] + pub fn append(&mut self, is_valid: bool) { + self.null_buffer_builder.append(is_valid); + } + + /// Appends a null element to the struct. + #[inline] + pub fn append_null(&mut self) { + self.append(false) + } + + /// Builds the `StructArray` and reset this builder. + pub fn finish(&mut self) -> StructArray { + self.validate_content(); + if self.fields.is_empty() { + return StructArray::new_empty_fields(self.len(), self.null_buffer_builder.finish()); + } + + let arrays = self.field_builders.iter_mut().map(|f| f.finish()).collect(); + let nulls = self.null_buffer_builder.finish(); + StructArray::new(self.fields.clone(), arrays, nulls) + } + + /// Builds the `StructArray` without resetting the builder. 
+ pub fn finish_cloned(&self) -> StructArray { + self.validate_content(); + + if self.fields.is_empty() { + return StructArray::new_empty_fields( + self.len(), + self.null_buffer_builder.finish_cloned(), + ); + } + + let arrays = self + .field_builders + .iter() + .map(|f| f.finish_cloned()) + .collect(); + + let nulls = self.null_buffer_builder.finish_cloned(); + + StructArray::new(self.fields.clone(), arrays, nulls) + } + + /// Constructs and validates contents in the builder to ensure that + /// - fields and field_builders are of equal length + /// - the number of items in individual field_builders are equal to self.len() + fn validate_content(&self) { + if self.fields.len() != self.field_builders.len() { + panic!("Number of fields is not equal to the number of field_builders."); + } + self.field_builders.iter().enumerate().for_each(|(idx, x)| { + if x.len() != self.len() { + let builder = SchemaBuilder::from(&self.fields); + let schema = builder.finish(); + + panic!("{}", format!( + "StructBuilder ({:?}) and field_builder with index {} ({:?}) are of unequal lengths: ({} != {}).", + schema, + idx, + self.fields[idx].data_type(), + self.len(), + x.len() + )); + } + }); + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::Buffer; + use arrow_data::ArrayData; + use arrow_schema::Field; + + use crate::array::Array; + + #[test] + fn test_struct_array_builder() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + + let fields = vec![ + Field::new("f1", DataType::Utf8, true), + Field::new("f2", DataType::Int32, true), + ]; + let field_builders = vec![ + Box::new(string_builder) as Box, + Box::new(int_builder) as Box, + ]; + + let mut builder = StructBuilder::new(fields, field_builders); + assert_eq!(2, builder.num_fields()); + + let string_builder = builder + 
.field_builder::(0) + .expect("builder at field 0 should be string builder"); + string_builder.append_value("joe"); + string_builder.append_null(); + string_builder.append_null(); + string_builder.append_value("mark"); + + let int_builder = builder + .field_builder::(1) + .expect("builder at field 1 should be int builder"); + int_builder.append_value(1); + int_builder.append_value(2); + int_builder.append_null(); + int_builder.append_value(4); + + builder.append(true); + builder.append(true); + builder.append_null(); + builder.append(true); + + let struct_data = builder.finish().into_data(); + + assert_eq!(4, struct_data.len()); + assert_eq!(1, struct_data.null_count()); + assert_eq!(&[11_u8], struct_data.nulls().unwrap().validity()); + + let expected_string_data = ArrayData::builder(DataType::Utf8) + .len(4) + .null_bit_buffer(Some(Buffer::from(&[9_u8]))) + .add_buffer(Buffer::from_slice_ref([0, 3, 3, 3, 7])) + .add_buffer(Buffer::from_slice_ref(b"joemark")) + .build() + .unwrap(); + + let expected_int_data = ArrayData::builder(DataType::Int32) + .len(4) + .null_bit_buffer(Some(Buffer::from_slice_ref([11_u8]))) + .add_buffer(Buffer::from_slice_ref([1, 2, 0, 4])) + .build() + .unwrap(); + + assert_eq!(expected_string_data, struct_data.child_data()[0]); + assert_eq!(expected_int_data, struct_data.child_data()[1]); + } + + #[test] + fn test_struct_array_builder_finish() { + let int_builder = Int32Builder::new(); + let bool_builder = BooleanBuilder::new(); + + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![ + Box::new(int_builder) as Box, + Box::new(bool_builder) as Box, + ]; + + let mut builder = StructBuilder::new(fields, field_builders); + builder + .field_builder::(0) + .unwrap() + .append_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + builder + .field_builder::(1) + .unwrap() + .append_slice(&[ + false, true, false, true, false, true, false, true, false, true, + ]); + + // 
Append slot values - all are valid. + for _ in 0..10 { + builder.append(true); + } + + assert_eq!(10, builder.len()); + + let arr = builder.finish(); + + assert_eq!(10, arr.len()); + assert_eq!(0, builder.len()); + + builder + .field_builder::(0) + .unwrap() + .append_slice(&[1, 3, 5, 7, 9]); + builder + .field_builder::(1) + .unwrap() + .append_slice(&[false, true, false, true, false]); + + // Append slot values - all are valid. + for _ in 0..5 { + builder.append(true); + } + + assert_eq!(5, builder.len()); + + let arr = builder.finish(); + + assert_eq!(5, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + fn test_struct_array_builder_finish_cloned() { + let int_builder = Int32Builder::new(); + let bool_builder = BooleanBuilder::new(); + + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![ + Box::new(int_builder) as Box, + Box::new(bool_builder) as Box, + ]; + + let mut builder = StructBuilder::new(fields, field_builders); + builder + .field_builder::(0) + .unwrap() + .append_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + builder + .field_builder::(1) + .unwrap() + .append_slice(&[ + false, true, false, true, false, true, false, true, false, true, + ]); + + // Append slot values - all are valid. + for _ in 0..10 { + builder.append(true); + } + + assert_eq!(10, builder.len()); + + let mut arr = builder.finish_cloned(); + + assert_eq!(10, arr.len()); + assert_eq!(10, builder.len()); + + builder + .field_builder::(0) + .unwrap() + .append_slice(&[1, 3, 5, 7, 9]); + builder + .field_builder::(1) + .unwrap() + .append_slice(&[false, true, false, true, false]); + + // Append slot values - all are valid. 
+ for _ in 0..5 { + builder.append(true); + } + + assert_eq!(15, builder.len()); + + arr = builder.finish(); + + assert_eq!(15, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + fn test_struct_array_builder_from_schema() { + let mut fields = vec![ + Field::new("f1", DataType::Float32, false), + Field::new("f2", DataType::Utf8, false), + ]; + let sub_fields = vec![ + Field::new("g1", DataType::Int32, false), + Field::new("g2", DataType::Boolean, false), + ]; + let struct_type = DataType::Struct(sub_fields.into()); + fields.push(Field::new("f3", struct_type, false)); + + let mut builder = StructBuilder::from_fields(fields, 5); + assert_eq!(3, builder.num_fields()); + assert!(builder.field_builder::(0).is_some()); + assert!(builder.field_builder::(1).is_some()); + assert!(builder.field_builder::(2).is_some()); + } + + #[test] + fn test_datatype_properties() { + let fields = Fields::from(vec![ + Field::new("f1", DataType::Decimal128(1, 2), false), + Field::new( + "f2", + DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), + false, + ), + ]); + let mut builder = StructBuilder::from_fields(fields.clone(), 1); + builder + .field_builder::(0) + .unwrap() + .append_value(1); + builder + .field_builder::(1) + .unwrap() + .append_value(1); + builder.append(true); + let array = builder.finish(); + + assert_eq!(array.data_type(), &DataType::Struct(fields.clone())); + assert_eq!(array.column(0).data_type(), fields[0].data_type()); + assert_eq!(array.column(1).data_type(), fields[1].data_type()); + } + + #[test] + #[should_panic(expected = "Data type Dictionary(Int32, Utf8) is not currently supported")] + fn test_struct_array_builder_from_schema_unsupported_type() { + let fields = vec![ + Field::new("f1", DataType::Int16, false), + Field::new( + "f2", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + ]; + + let _ = StructBuilder::from_fields(fields, 5); + } + + #[test] + fn 
test_struct_array_builder_field_builder_type_mismatch() { + let int_builder = Int32Builder::with_capacity(10); + + let fields = vec![Field::new("f1", DataType::Int32, false)]; + let field_builders = vec![Box::new(int_builder) as Box]; + + let mut builder = StructBuilder::new(fields, field_builders); + assert!(builder.field_builder::(0).is_none()); + } + + #[test] + #[should_panic( + expected = "StructBuilder (Schema { fields: [Field { name: \"f1\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"f2\", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }) and field_builder with index 1 (Boolean) are of unequal lengths: (2 != 1)." + )] + fn test_struct_array_builder_unequal_field_builders_lengths() { + let mut int_builder = Int32Builder::with_capacity(10); + let mut bool_builder = BooleanBuilder::new(); + + int_builder.append_value(1); + int_builder.append_value(2); + bool_builder.append_value(true); + + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![ + Box::new(int_builder) as Box, + Box::new(bool_builder) as Box, + ]; + + let mut builder = StructBuilder::new(fields, field_builders); + builder.append(true); + builder.append(true); + builder.finish(); + } + + #[test] + #[should_panic(expected = "Number of fields is not equal to the number of field_builders.")] + fn test_struct_array_builder_unequal_field_field_builders() { + let int_builder = Int32Builder::with_capacity(10); + + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![Box::new(int_builder) as Box]; + + let mut builder = StructBuilder::new(fields, field_builders); + builder.finish(); + } + + #[test] + #[should_panic( + expected = "Incorrect datatype for StructArray field \\\"timestamp\\\", expected Timestamp(Nanosecond, 
Some(\\\"UTC\\\")) got Timestamp(Nanosecond, None)" + )] + fn test_struct_array_mismatch_builder() { + let fields = vec![Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned().into())), + false, + )]; + + let field_builders: Vec> = + vec![Box::new(TimestampNanosecondBuilder::new())]; + + let mut sa = StructBuilder::new(fields, field_builders); + sa.finish(); + } + + #[test] + fn test_empty() { + let mut builder = StructBuilder::new(Fields::empty(), vec![]); + builder.append(true); + builder.append(false); + + let a1 = builder.finish_cloned(); + let a2 = builder.finish(); + assert_eq!(a1, a2); + assert_eq!(a1.len(), 2); + assert_eq!(a1.null_count(), 1); + assert!(a1.is_valid(0)); + assert!(a1.is_null(1)); + } +} diff --git a/arrow/src/array/builder/union_builder.rs b/arrow-array/src/builder/union_builder.rs similarity index 70% rename from arrow/src/array/builder/union_builder.rs rename to arrow-array/src/builder/union_builder.rs index c0ae76853dd2..e6184f4ac6d2 100644 --- a/arrow/src/array/builder/union_builder.rs +++ b/arrow-array/src/builder/union_builder.rs @@ -15,23 +15,16 @@ // specific language governing permissions and limitations // under the License. 
+use crate::builder::buffer_builder::{Int32BufferBuilder, Int8BufferBuilder}; +use crate::builder::BufferBuilder; +use crate::{make_array, ArrowPrimitiveType, UnionArray}; +use arrow_buffer::NullBufferBuilder; +use arrow_buffer::{ArrowNativeType, Buffer}; +use arrow_data::ArrayDataBuilder; +use arrow_schema::{ArrowError, DataType, Field}; use std::any::Any; -use std::collections::HashMap; - -use crate::array::ArrayDataBuilder; -use crate::array::Int32BufferBuilder; -use crate::array::Int8BufferBuilder; -use crate::array::UnionArray; -use crate::buffer::Buffer; - -use crate::datatypes::DataType; -use crate::datatypes::Field; -use crate::datatypes::{ArrowNativeType, ArrowPrimitiveType}; -use crate::error::{ArrowError, Result}; - -use super::{BufferBuilder, NullBufferBuilder}; - -use crate::array::make_array; +use std::collections::BTreeMap; +use std::sync::Arc; /// `FieldData` is a helper struct to track the state of the fields in the `UnionBuilder`. #[derive(Debug)] @@ -73,11 +66,7 @@ impl FieldDataValues for BufferBuilder { impl FieldData { /// Creates a new `FieldData`. - fn new( - type_id: i8, - data_type: DataType, - capacity: usize, - ) -> Self { + fn new(type_id: i8, data_type: DataType, capacity: usize) -> Self { Self { type_id, data_type, @@ -107,13 +96,13 @@ impl FieldData { } } -/// Builder type for creating a new `UnionArray`. 
+/// Builder for [`UnionArray`] /// /// Example: **Dense Memory Layout** /// /// ``` -/// use arrow::array::UnionBuilder; -/// use arrow::datatypes::{Float64Type, Int32Type}; +/// # use arrow_array::builder::UnionBuilder; +/// # use arrow_array::types::{Float64Type, Int32Type}; /// /// let mut builder = UnionBuilder::new_dense(); /// builder.append::("a", 1).unwrap(); @@ -121,19 +110,19 @@ impl FieldData { /// builder.append::("a", 4).unwrap(); /// let union = builder.build().unwrap(); /// -/// assert_eq!(union.type_id(0), 0_i8); -/// assert_eq!(union.type_id(1), 1_i8); -/// assert_eq!(union.type_id(2), 0_i8); +/// assert_eq!(union.type_id(0), 0); +/// assert_eq!(union.type_id(1), 1); +/// assert_eq!(union.type_id(2), 0); /// -/// assert_eq!(union.value_offset(0), 0_i32); -/// assert_eq!(union.value_offset(1), 0_i32); -/// assert_eq!(union.value_offset(2), 1_i32); +/// assert_eq!(union.value_offset(0), 0); +/// assert_eq!(union.value_offset(1), 0); +/// assert_eq!(union.value_offset(2), 1); /// ``` /// /// Example: **Sparse Memory Layout** /// ``` -/// use arrow::array::UnionBuilder; -/// use arrow::datatypes::{Float64Type, Int32Type}; +/// # use arrow_array::builder::UnionBuilder; +/// # use arrow_array::types::{Float64Type, Int32Type}; /// /// let mut builder = UnionBuilder::new_sparse(); /// builder.append::("a", 1).unwrap(); @@ -141,20 +130,20 @@ impl FieldData { /// builder.append::("a", 4).unwrap(); /// let union = builder.build().unwrap(); /// -/// assert_eq!(union.type_id(0), 0_i8); -/// assert_eq!(union.type_id(1), 1_i8); -/// assert_eq!(union.type_id(2), 0_i8); +/// assert_eq!(union.type_id(0), 0); +/// assert_eq!(union.type_id(1), 1); +/// assert_eq!(union.type_id(2), 0); /// -/// assert_eq!(union.value_offset(0), 0_i32); -/// assert_eq!(union.value_offset(1), 1_i32); -/// assert_eq!(union.value_offset(2), 2_i32); +/// assert_eq!(union.value_offset(0), 0); +/// assert_eq!(union.value_offset(1), 1); +/// assert_eq!(union.value_offset(2), 2); /// ``` 
#[derive(Debug)] pub struct UnionBuilder { /// The current number of slots in the array len: usize, /// Maps field names to `FieldData` instances which track the builders for that field - fields: HashMap, + fields: BTreeMap, /// Builder to keep track of type ids type_id_builder: Int8BufferBuilder, /// Builder to keep track of offsets (`None` for sparse unions) @@ -177,7 +166,7 @@ impl UnionBuilder { pub fn with_capacity_dense(capacity: usize) -> Self { Self { len: 0, - fields: HashMap::default(), + fields: Default::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: Some(Int32BufferBuilder::new(capacity)), initial_capacity: capacity, @@ -188,7 +177,7 @@ impl UnionBuilder { pub fn with_capacity_sparse(capacity: usize) -> Self { Self { len: 0, - fields: HashMap::default(), + fields: Default::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: None, initial_capacity: capacity, @@ -203,7 +192,10 @@ impl UnionBuilder { /// is part of the final array, appending a NULL requires /// specifying which field (child) to use. 
#[inline] - pub fn append_null(&mut self, type_name: &str) -> Result<()> { + pub fn append_null( + &mut self, + type_name: &str, + ) -> Result<(), ArrowError> { self.append_option::(type_name, None) } @@ -213,7 +205,7 @@ impl UnionBuilder { &mut self, type_name: &str, v: T::Native, - ) -> Result<()> { + ) -> Result<(), ArrowError> { self.append_option::(type_name, Some(v)) } @@ -221,13 +213,18 @@ impl UnionBuilder { &mut self, type_name: &str, v: Option, - ) -> Result<()> { + ) -> Result<(), ArrowError> { let type_name = type_name.to_string(); let mut field_data = match self.fields.remove(&type_name) { Some(data) => { if data.data_type != T::DATA_TYPE { - return Err(ArrowError::InvalidArgumentError(format!("Attempt to write col \"{}\" with type {} doesn't match existing type {}", type_name, T::DATA_TYPE, data.data_type))); + return Err(ArrowError::InvalidArgumentError(format!( + "Attempt to write col \"{}\" with type {} doesn't match existing type {}", + type_name, + T::DATA_TYPE, + data.data_type + ))); } data } @@ -278,40 +275,39 @@ impl UnionBuilder { } /// Builds this builder creating a new `UnionArray`. 
- pub fn build(mut self) -> Result { - let type_id_buffer = self.type_id_builder.finish(); - let value_offsets_buffer = self.value_offset_builder.map(|mut b| b.finish()); - let mut children = Vec::new(); - for ( - name, - FieldData { - type_id, - data_type, - mut values_buffer, - slots, - null_buffer_builder: mut bitmap_builder, - }, - ) in self.fields.into_iter() - { - let buffer = values_buffer.finish(); - let arr_data_builder = ArrayDataBuilder::new(data_type.clone()) - .add_buffer(buffer) - .len(slots) - .null_bit_buffer(bitmap_builder.finish()); - - let arr_data_ref = unsafe { arr_data_builder.build_unchecked() }; - let array_ref = make_array(arr_data_ref); - children.push((type_id, (Field::new(&name, data_type, false), array_ref))) - } - - children.sort_by(|a, b| { - a.0.partial_cmp(&b.0) - .expect("This will never be None as type ids are always i8 values.") - }); - let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect(); - - let type_ids: Vec = (0_i8..children.len() as i8).collect(); - - UnionArray::try_new(&type_ids, type_id_buffer, value_offsets_buffer, children) + pub fn build(self) -> Result { + let mut children = Vec::with_capacity(self.fields.len()); + let union_fields = self + .fields + .into_iter() + .map( + |( + name, + FieldData { + type_id, + data_type, + mut values_buffer, + slots, + mut null_buffer_builder, + }, + )| { + let array_ref = make_array(unsafe { + ArrayDataBuilder::new(data_type.clone()) + .add_buffer(values_buffer.finish()) + .len(slots) + .nulls(null_buffer_builder.finish()) + .build_unchecked() + }); + children.push(array_ref); + (type_id, Arc::new(Field::new(name, data_type, false))) + }, + ) + .collect(); + UnionArray::try_new( + union_fields, + self.type_id_builder.into(), + self.value_offset_builder.map(Into::into), + children, + ) } } diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs new file mode 100644 index 000000000000..cda179b78c2e --- /dev/null +++ b/arrow-array/src/cast.rs @@ -0,0 +1,1021 @@ 
+// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines helper functions for downcasting [`dyn Array`](Array) to concrete types + +use crate::array::*; +use crate::types::*; +use arrow_data::ArrayData; + +/// Repeats the provided pattern based on the number of comma separated identifiers +#[doc(hidden)] +#[macro_export] +macro_rules! repeat_pat { + ($e:pat, $v_:expr) => { + $e + }; + ($e:pat, $v_:expr $(, $tail:expr)+) => { + ($e, $crate::repeat_pat!($e $(, $tail)+)) + } +} + +/// Given one or more expressions evaluating to an integer [`DataType`] invokes the provided macro +/// `m` with the corresponding integer [`ArrowPrimitiveType`], followed by any additional arguments +/// +/// ``` +/// # use arrow_array::{downcast_primitive, ArrowPrimitiveType, downcast_integer}; +/// # use arrow_schema::DataType; +/// +/// macro_rules! dictionary_key_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn dictionary_key_size(t: &DataType) -> u8 { +/// match t { +/// DataType::Dictionary(k, _) => downcast_integer! 
{ +/// k.as_ref() => (dictionary_key_size_helper, u8), +/// _ => unreachable!(), +/// }, +/// _ => u8::MAX, +/// } +/// } +/// +/// assert_eq!(dictionary_key_size(&DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))), 4); +/// assert_eq!(dictionary_key_size(&DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8))), 8); +/// assert_eq!(dictionary_key_size(&DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8))), 2); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_integer { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + match ($($data_type),+) { + $crate::repeat_pat!(arrow_schema::DataType::Int8, $($data_type),+) => { + $m!($crate::types::Int8Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int16, $($data_type),+) => { + $m!($crate::types::Int16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int32, $($data_type),+) => { + $m!($crate::types::Int32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int64, $($data_type),+) => { + $m!($crate::types::Int64Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt8, $($data_type),+) => { + $m!($crate::types::UInt8Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt16, $($data_type),+) => { + $m!($crate::types::UInt16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt32, $($data_type),+) => { + $m!($crate::types::UInt32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt64, $($data_type),+) => { + $m!($crate::types::UInt64Type $(, $args)*) + } + $($p => $fallback,)* + } + }; +} + +/// Given one or more expressions evaluating to an integer [`DataType`] invokes the provided macro +/// `m` with the corresponding integer [`RunEndIndexType`], followed by any additional arguments +/// +/// ``` +/// # use std::sync::Arc; +/// # use 
arrow_array::{downcast_primitive, ArrowPrimitiveType, downcast_run_end_index}; +/// # use arrow_schema::{DataType, Field}; +/// +/// macro_rules! run_end_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn run_end_index_size(t: &DataType) -> u8 { +/// match t { +/// DataType::RunEndEncoded(k, _) => downcast_run_end_index! { +/// k.data_type() => (run_end_size_helper, u8), +/// _ => unreachable!(), +/// }, +/// _ => u8::MAX, +/// } +/// } +/// +/// assert_eq!(run_end_index_size(&DataType::RunEndEncoded(Arc::new(Field::new("a", DataType::Int32, false)), Arc::new(Field::new("b", DataType::Utf8, true)))), 4); +/// assert_eq!(run_end_index_size(&DataType::RunEndEncoded(Arc::new(Field::new("a", DataType::Int64, false)), Arc::new(Field::new("b", DataType::Utf8, true)))), 8); +/// assert_eq!(run_end_index_size(&DataType::RunEndEncoded(Arc::new(Field::new("a", DataType::Int16, false)), Arc::new(Field::new("b", DataType::Utf8, true)))), 2); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_run_end_index { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + match ($($data_type),+) { + $crate::repeat_pat!(arrow_schema::DataType::Int16, $($data_type),+) => { + $m!($crate::types::Int16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int32, $($data_type),+) => { + $m!($crate::types::Int32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int64, $($data_type),+) => { + $m!($crate::types::Int64Type $(, $args)*) + } + $($p => $fallback,)* + } + }; +} + +/// Given one or more expressions evaluating to primitive [`DataType`] invokes the provided macro +/// `m` with the corresponding [`ArrowPrimitiveType`], followed by any additional arguments +/// +/// ``` +/// # use arrow_array::{downcast_temporal, ArrowPrimitiveType}; +/// # use arrow_schema::DataType; +/// +/// macro_rules! 
temporal_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn temporal_size(t: &DataType) -> u8 { +/// downcast_temporal! { +/// t => (temporal_size_helper, u8), +/// _ => u8::MAX +/// } +/// } +/// +/// assert_eq!(temporal_size(&DataType::Date32), 4); +/// assert_eq!(temporal_size(&DataType::Date64), 8); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_temporal { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + match ($($data_type),+) { + $crate::repeat_pat!(arrow_schema::DataType::Time32(arrow_schema::TimeUnit::Second), $($data_type),+) => { + $m!($crate::types::Time32SecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Time32(arrow_schema::TimeUnit::Millisecond), $($data_type),+) => { + $m!($crate::types::Time32MillisecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond), $($data_type),+) => { + $m!($crate::types::Time64MicrosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Nanosecond), $($data_type),+) => { + $m!($crate::types::Time64NanosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Date32, $($data_type),+) => { + $m!($crate::types::Date32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Date64, $($data_type),+) => { + $m!($crate::types::Date64Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Second, _), $($data_type),+) => { + $m!($crate::types::TimestampSecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _), $($data_type),+) => { + $m!($crate::types::TimestampMillisecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _), 
$($data_type),+) => { + $m!($crate::types::TimestampMicrosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _), $($data_type),+) => { + $m!($crate::types::TimestampNanosecondType $(, $args)*) + } + $($p => $fallback,)* + } + }; +} + +/// Downcast an [`Array`] to a temporal [`PrimitiveArray`] based on its [`DataType`] +/// accepts a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, downcast_temporal_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_temporal(array: &dyn Array) { +/// downcast_temporal_array!( +/// array => { +/// for v in array { +/// println!("{:?}", v); +/// } +/// } +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_temporal_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal_array!($values => {$e} $($p => $fallback)*) + }; + (($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal_array!($($values),+ => {$e} $($p => $fallback)*) + }; + ($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal_array!(($($values),+) => $e $($p => $fallback)*) + }; + (($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal!{ + $($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e), + $($p => $fallback,)* + } + }; +} + +/// Given one or more expressions evaluating to primitive [`DataType`] invokes the provided macro +/// `m` with the corresponding [`ArrowPrimitiveType`], followed by any additional arguments +/// +/// ``` +/// # use arrow_array::{downcast_primitive, ArrowPrimitiveType}; +/// 
# use arrow_schema::DataType; +/// +/// macro_rules! primitive_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn primitive_size(t: &DataType) -> u8 { +/// downcast_primitive! { +/// t => (primitive_size_helper, u8), +/// _ => u8::MAX +/// } +/// } +/// +/// assert_eq!(primitive_size(&DataType::Int32), 4); +/// assert_eq!(primitive_size(&DataType::Int64), 8); +/// assert_eq!(primitive_size(&DataType::Float16), 2); +/// assert_eq!(primitive_size(&DataType::Decimal128(38, 10)), 16); +/// assert_eq!(primitive_size(&DataType::Decimal256(76, 20)), 32); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_primitive { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_integer! { + $($data_type),+ => ($m $(, $args)*), + $crate::repeat_pat!(arrow_schema::DataType::Float16, $($data_type),+) => { + $m!($crate::types::Float16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Float32, $($data_type),+) => { + $m!($crate::types::Float32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Float64, $($data_type),+) => { + $m!($crate::types::Float64Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Decimal128(_, _), $($data_type),+) => { + $m!($crate::types::Decimal128Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Decimal256(_, _), $($data_type),+) => { + $m!($crate::types::Decimal256Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::YearMonth), $($data_type),+) => { + $m!($crate::types::IntervalYearMonthType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::DayTime), $($data_type),+) => { + $m!($crate::types::IntervalDayTimeType $(, $args)*) + } + 
$crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano), $($data_type),+) => { + $m!($crate::types::IntervalMonthDayNanoType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Second), $($data_type),+) => { + $m!($crate::types::DurationSecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Millisecond), $($data_type),+) => { + $m!($crate::types::DurationMillisecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Microsecond), $($data_type),+) => { + $m!($crate::types::DurationMicrosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Nanosecond), $($data_type),+) => { + $m!($crate::types::DurationNanosecondType $(, $args)*) + } + _ => { + $crate::downcast_temporal! { + $($data_type),+ => ($m $(, $args)*), + $($p => $fallback,)* + } + } + } + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! downcast_primitive_array_helper { + ($t:ty, $($values:ident),+, $e:block) => {{ + $(let $values = $crate::cast::as_primitive_array::<$t>($values);)+ + $e + }}; +} + +/// Downcast an [`Array`] to a [`PrimitiveArray`] based on its [`DataType`] +/// accepts a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, downcast_primitive_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_primitive(array: &dyn Array) { +/// downcast_primitive_array!( +/// array => { +/// for v in array { +/// println!("{:?}", v); +/// } +/// } +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! 
downcast_primitive_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array!($values => {$e} $($p => $fallback)*) + }; + (($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array!($($values),+ => {$e} $($p => $fallback)*) + }; + ($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array!(($($values),+) => $e $($p => $fallback)*) + }; + (($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive!{ + $($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e), + $($p => $fallback,)* + } + }; +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`], to +/// [`PrimitiveArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, Int32Array}; +/// # use arrow_array::cast::as_primitive_array; +/// # use arrow_array::types::Int32Type; +/// +/// let arr: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); +/// +/// // Downcast an `ArrayRef` to Int32Array / PrimitiveArray: +/// let primitive_array: &Int32Array = as_primitive_array(&arr); +/// +/// // Equivalently: +/// let primitive_array = as_primitive_array::(&arr); +/// +/// // This is the equivalent of: +/// let primitive_array = arr +/// .as_any() +/// .downcast_ref::() +/// .unwrap(); +/// ``` + +pub fn as_primitive_array(arr: &dyn Array) -> &PrimitiveArray +where + T: ArrowPrimitiveType, +{ + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to primitive array") +} + +#[macro_export] +#[doc(hidden)] +macro_rules! 
downcast_dictionary_array_helper { + ($t:ty, $($values:ident),+, $e:block) => {{ + $(let $values = $crate::cast::as_dictionary_array::<$t>($values);)+ + $e + }}; +} + +/// Downcast an [`Array`] to a [`DictionaryArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, StringArray, downcast_dictionary_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_strings(array: &dyn Array) { +/// downcast_dictionary_array!( +/// array => match array.values().data_type() { +/// DataType::Utf8 => { +/// for v in array.downcast_dict::().unwrap() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported dictionary value type {}", t), +/// }, +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_dictionary_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_dictionary_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + arrow_schema::DataType::Dictionary(k, _) => { + $crate::downcast_integer! { + k.as_ref() => ($crate::downcast_dictionary_array_helper, $values, $e), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + $($p => $fallback,)* + } + } +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`DictionaryArray`], panic'ing on failure. 
+/// +/// # Example +/// +/// ``` +/// # use arrow_array::{ArrayRef, DictionaryArray}; +/// # use arrow_array::cast::as_dictionary_array; +/// # use arrow_array::types::Int32Type; +/// +/// let arr: DictionaryArray = vec![Some("foo")].into_iter().collect(); +/// let arr: ArrayRef = std::sync::Arc::new(arr); +/// let dict_array: &DictionaryArray = as_dictionary_array::(&arr); +/// ``` +pub fn as_dictionary_array(arr: &dyn Array) -> &DictionaryArray +where + T: ArrowDictionaryKeyType, +{ + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to dictionary array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`RunArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use arrow_array::{ArrayRef, RunArray}; +/// # use arrow_array::cast::as_run_array; +/// # use arrow_array::types::Int32Type; +/// +/// let arr: RunArray = vec![Some("foo")].into_iter().collect(); +/// let arr: ArrayRef = std::sync::Arc::new(arr); +/// let run_array: &RunArray = as_run_array::(&arr); +/// ``` +pub fn as_run_array(arr: &dyn Array) -> &RunArray +where + T: RunEndIndexType, +{ + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to run array") +} + +#[macro_export] +#[doc(hidden)] +macro_rules! 
downcast_run_array_helper { + ($t:ty, $($values:ident),+, $e:block) => {{ + $(let $values = $crate::cast::as_run_array::<$t>($values);)+ + $e + }}; +} + +/// Downcast an [`Array`] to a [`RunArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, StringArray, downcast_run_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_strings(array: &dyn Array) { +/// downcast_run_array!( +/// array => match array.values().data_type() { +/// DataType::Utf8 => { +/// for v in array.downcast::().unwrap() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported run array value type {}", t), +/// }, +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_run_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_run_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + arrow_schema::DataType::RunEndEncoded(k, _) => { + $crate::downcast_run_end_index! { + k.data_type() => ($crate::downcast_run_array_helper, $values, $e), + k => unreachable!("unsupported run end index type: {}", k) + } + } + $($p => $fallback,)* + } + } +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`GenericListArray`], panicking on failure. +pub fn as_generic_list_array(arr: &dyn Array) -> &GenericListArray { + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to list array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`ListArray`], panicking on failure. 
+#[inline] +pub fn as_list_array(arr: &dyn Array) -> &ListArray { + as_generic_list_array::(arr) +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`FixedSizeListArray`], panicking on failure. +#[inline] +pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to fixed size list array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`LargeListArray`], panicking on failure. +#[inline] +pub fn as_large_list_array(arr: &dyn Array) -> &LargeListArray { + as_generic_list_array::(arr) +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`GenericBinaryArray`], panicking on failure. +#[inline] +pub fn as_generic_binary_array(arr: &dyn Array) -> &GenericBinaryArray { + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to binary array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`StringArray`], panicking on failure. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::cast::as_string_array; +/// # use arrow_array::{ArrayRef, StringArray}; +/// +/// let arr: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("foo")])); +/// let string_array = as_string_array(&arr); +/// ``` +pub fn as_string_array(arr: &dyn Array) -> &StringArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to StringArray") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`BooleanArray`], panicking on failure. 
+/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, BooleanArray}; +/// # use arrow_array::cast::as_boolean_array; +/// +/// let arr: ArrayRef = Arc::new(BooleanArray::from_iter(vec![Some(true)])); +/// let boolean_array = as_boolean_array(&arr); +/// ``` +pub fn as_boolean_array(arr: &dyn Array) -> &BooleanArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to BooleanArray") +} + +macro_rules! array_downcast_fn { + ($name: ident, $arrty: ty, $arrty_str:expr) => { + #[doc = "Force downcast of an [`Array`], such as an [`ArrayRef`] to "] + #[doc = $arrty_str] + pub fn $name(arr: &dyn Array) -> &$arrty { + arr.as_any().downcast_ref::<$arrty>().expect(concat!( + "Unable to downcast to typed array through ", + stringify!($name) + )) + } + }; + + // use recursive macro to generate dynamic doc string for a given array type + ($name: ident, $arrty: ty) => { + array_downcast_fn!( + $name, + $arrty, + concat!("[`", stringify!($arrty), "`], panicking on failure.") + ); + }; +} + +array_downcast_fn!(as_largestring_array, LargeStringArray); +array_downcast_fn!(as_null_array, NullArray); +array_downcast_fn!(as_struct_array, StructArray); +array_downcast_fn!(as_union_array, UnionArray); +array_downcast_fn!(as_map_array, MapArray); + +/// Force downcast of an Array, such as an ArrayRef to Decimal128Array, panic’ing on failure. 
+#[deprecated(note = "please use `as_primitive_array::` instead")] +pub fn as_decimal_array(arr: &dyn Array) -> &PrimitiveArray { + as_primitive_array::(arr) +} + +/// Downcasts a `dyn Array` to a concrete type +/// +/// ``` +/// # use arrow_array::{BooleanArray, Int32Array, RecordBatch, StringArray}; +/// # use arrow_array::cast::downcast_array; +/// struct ConcreteBatch { +/// col1: Int32Array, +/// col2: BooleanArray, +/// col3: StringArray, +/// } +/// +/// impl ConcreteBatch { +/// fn new(batch: &RecordBatch) -> Self { +/// Self { +/// col1: downcast_array(batch.column(0).as_ref()), +/// col2: downcast_array(batch.column(1).as_ref()), +/// col3: downcast_array(batch.column(2).as_ref()), +/// } +/// } +/// } +/// ``` +/// +/// # Panics +/// +/// Panics if array is not of the correct data type +pub fn downcast_array(array: &dyn Array) -> T +where + T: From, +{ + T::from(array.to_data()) +} + +mod private { + pub trait Sealed {} +} + +/// An extension trait for `dyn Array` that provides ergonomic downcasting +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, Int32Array}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::types::Int32Type; +/// let col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; +/// assert_eq!(col.as_primitive::().values(), &[1, 2, 3]); +/// ``` +pub trait AsArray: private::Sealed { + /// Downcast this to a [`BooleanArray`] returning `None` if not possible + fn as_boolean_opt(&self) -> Option<&BooleanArray>; + + /// Downcast this to a [`BooleanArray`] panicking if not possible + fn as_boolean(&self) -> &BooleanArray { + self.as_boolean_opt().expect("boolean array") + } + + /// Downcast this to a [`PrimitiveArray`] returning `None` if not possible + fn as_primitive_opt(&self) -> Option<&PrimitiveArray>; + + /// Downcast this to a [`PrimitiveArray`] panicking if not possible + fn as_primitive(&self) -> &PrimitiveArray { + self.as_primitive_opt().expect("primitive array") + } + + /// Downcast this to 
a [`GenericByteArray`] returning `None` if not possible + fn as_bytes_opt(&self) -> Option<&GenericByteArray>; + + /// Downcast this to a [`GenericByteArray`] panicking if not possible + fn as_bytes(&self) -> &GenericByteArray { + self.as_bytes_opt().expect("byte array") + } + + /// Downcast this to a [`GenericStringArray`] returning `None` if not possible + fn as_string_opt(&self) -> Option<&GenericStringArray> { + self.as_bytes_opt() + } + + /// Downcast this to a [`GenericStringArray`] panicking if not possible + fn as_string(&self) -> &GenericStringArray { + self.as_bytes_opt().expect("string array") + } + + /// Downcast this to a [`GenericBinaryArray`] returning `None` if not possible + fn as_binary_opt(&self) -> Option<&GenericBinaryArray> { + self.as_bytes_opt() + } + + /// Downcast this to a [`GenericBinaryArray`] panicking if not possible + fn as_binary(&self) -> &GenericBinaryArray { + self.as_bytes_opt().expect("binary array") + } + + /// Downcast this to a [`StringViewArray`] returning `None` if not possible + fn as_string_view(&self) -> &StringViewArray { + self.as_byte_view_opt().expect("string view array") + } + + /// Downcast this to a [`StringViewArray`] returning `None` if not possible + fn as_string_view_opt(&self) -> Option<&StringViewArray> { + self.as_byte_view_opt() + } + + /// Downcast this to a [`StringViewArray`] returning `None` if not possible + fn as_binary_view(&self) -> &BinaryViewArray { + self.as_byte_view_opt().expect("binary view array") + } + + /// Downcast this to a [`BinaryViewArray`] returning `None` if not possible + fn as_binary_view_opt(&self) -> Option<&BinaryViewArray> { + self.as_byte_view_opt() + } + + /// Downcast this to a [`GenericByteViewArray`] returning `None` if not possible + fn as_byte_view(&self) -> &GenericByteViewArray { + self.as_byte_view_opt().expect("byte view array") + } + + /// Downcast this to a [`GenericByteViewArray`] returning `None` if not possible + fn as_byte_view_opt(&self) -> 
Option<&GenericByteViewArray>; + + /// Downcast this to a [`StructArray`] returning `None` if not possible + fn as_struct_opt(&self) -> Option<&StructArray>; + + /// Downcast this to a [`StructArray`] panicking if not possible + fn as_struct(&self) -> &StructArray { + self.as_struct_opt().expect("struct array") + } + + /// Downcast this to a [`UnionArray`] returning `None` if not possible + fn as_union_opt(&self) -> Option<&UnionArray>; + + /// Downcast this to a [`UnionArray`] panicking if not possible + fn as_union(&self) -> &UnionArray { + self.as_union_opt().expect("union array") + } + + /// Downcast this to a [`GenericListArray`] returning `None` if not possible + fn as_list_opt(&self) -> Option<&GenericListArray>; + + /// Downcast this to a [`GenericListArray`] panicking if not possible + fn as_list(&self) -> &GenericListArray { + self.as_list_opt().expect("list array") + } + + /// Downcast this to a [`FixedSizeBinaryArray`] returning `None` if not possible + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray>; + + /// Downcast this to a [`FixedSizeBinaryArray`] panicking if not possible + fn as_fixed_size_binary(&self) -> &FixedSizeBinaryArray { + self.as_fixed_size_binary_opt() + .expect("fixed size binary array") + } + + /// Downcast this to a [`FixedSizeListArray`] returning `None` if not possible + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray>; + + /// Downcast this to a [`FixedSizeListArray`] panicking if not possible + fn as_fixed_size_list(&self) -> &FixedSizeListArray { + self.as_fixed_size_list_opt() + .expect("fixed size list array") + } + + /// Downcast this to a [`MapArray`] returning `None` if not possible + fn as_map_opt(&self) -> Option<&MapArray>; + + /// Downcast this to a [`MapArray`] panicking if not possible + fn as_map(&self) -> &MapArray { + self.as_map_opt().expect("map array") + } + + /// Downcast this to a [`DictionaryArray`] returning `None` if not possible + fn as_dictionary_opt(&self) -> 
Option<&DictionaryArray>; + + /// Downcast this to a [`DictionaryArray`] panicking if not possible + fn as_dictionary(&self) -> &DictionaryArray { + self.as_dictionary_opt().expect("dictionary array") + } + + /// Downcasts this to a [`AnyDictionaryArray`] returning `None` if not possible + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray>; + + /// Downcasts this to a [`AnyDictionaryArray`] panicking if not possible + fn as_any_dictionary(&self) -> &dyn AnyDictionaryArray { + self.as_any_dictionary_opt().expect("any dictionary array") + } +} + +impl private::Sealed for dyn Array + '_ {} +impl AsArray for dyn Array + '_ { + fn as_boolean_opt(&self) -> Option<&BooleanArray> { + self.as_any().downcast_ref() + } + + fn as_primitive_opt(&self) -> Option<&PrimitiveArray> { + self.as_any().downcast_ref() + } + + fn as_bytes_opt(&self) -> Option<&GenericByteArray> { + self.as_any().downcast_ref() + } + + fn as_byte_view_opt(&self) -> Option<&GenericByteViewArray> { + self.as_any().downcast_ref() + } + + fn as_struct_opt(&self) -> Option<&StructArray> { + self.as_any().downcast_ref() + } + + fn as_union_opt(&self) -> Option<&UnionArray> { + self.as_any().downcast_ref() + } + + fn as_list_opt(&self) -> Option<&GenericListArray> { + self.as_any().downcast_ref() + } + + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_any().downcast_ref() + } + + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { + self.as_any().downcast_ref() + } + + fn as_map_opt(&self) -> Option<&MapArray> { + self.as_any().downcast_ref() + } + + fn as_dictionary_opt(&self) -> Option<&DictionaryArray> { + self.as_any().downcast_ref() + } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + let array = self; + downcast_dictionary_array! 
{ + array => Some(array), + _ => None + } + } +} + +impl private::Sealed for ArrayRef {} +impl AsArray for ArrayRef { + fn as_boolean_opt(&self) -> Option<&BooleanArray> { + self.as_ref().as_boolean_opt() + } + + fn as_primitive_opt(&self) -> Option<&PrimitiveArray> { + self.as_ref().as_primitive_opt() + } + + fn as_bytes_opt(&self) -> Option<&GenericByteArray> { + self.as_ref().as_bytes_opt() + } + + fn as_byte_view_opt(&self) -> Option<&GenericByteViewArray> { + self.as_ref().as_byte_view_opt() + } + + fn as_struct_opt(&self) -> Option<&StructArray> { + self.as_ref().as_struct_opt() + } + + fn as_union_opt(&self) -> Option<&UnionArray> { + self.as_any().downcast_ref() + } + + fn as_list_opt(&self) -> Option<&GenericListArray> { + self.as_ref().as_list_opt() + } + + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_ref().as_fixed_size_binary_opt() + } + + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { + self.as_ref().as_fixed_size_list_opt() + } + + fn as_map_opt(&self) -> Option<&MapArray> { + self.as_any().downcast_ref() + } + + fn as_dictionary_opt(&self) -> Option<&DictionaryArray> { + self.as_ref().as_dictionary_opt() + } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + self.as_ref().as_any_dictionary_opt() + } +} + +#[cfg(test)] +mod tests { + use arrow_buffer::i256; + use std::sync::Arc; + + use super::*; + + #[test] + fn test_as_primitive_array_ref() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert!(!as_primitive_array::(&array).is_empty()); + + // should also work when wrapped in an Arc + let array: ArrayRef = Arc::new(array); + assert!(!as_primitive_array::(&array).is_empty()); + } + + #[test] + fn test_as_string_array_ref() { + let array: StringArray = vec!["foo", "bar"].into_iter().map(Some).collect(); + assert!(!as_string_array(&array).is_empty()); + + // should also work when wrapped in an Arc + let array: ArrayRef = Arc::new(array); + 
assert!(!as_string_array(&array).is_empty()) + } + + #[test] + fn test_decimal128array() { + let a = Decimal128Array::from_iter_values([1, 2, 4, 5]); + assert!(!as_primitive_array::(&a).is_empty()); + } + + #[test] + fn test_decimal256array() { + let a = Decimal256Array::from_iter_values([1, 2, 4, 5].into_iter().map(i256::from_i128)); + assert!(!as_primitive_array::(&a).is_empty()); + } +} diff --git a/arrow-array/src/delta.rs b/arrow-array/src/delta.rs new file mode 100644 index 000000000000..d9aa4aa6de5d --- /dev/null +++ b/arrow-array/src/delta.rs @@ -0,0 +1,285 @@ +// MIT License +// +// Copyright (c) 2020-2022 Oliver Margetts +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Copied from chronoutil crate + +//! Contains utility functions for shifting Date objects. +use chrono::{DateTime, Datelike, Days, Months, TimeZone}; +use std::cmp::Ordering; + +/// Shift a date by the given number of months. 
+pub(crate) fn shift_months(date: D, months: i32) -> D +where + D: Datelike + std::ops::Add + std::ops::Sub, +{ + match months.cmp(&0) { + Ordering::Equal => date, + Ordering::Greater => date + Months::new(months as u32), + Ordering::Less => date - Months::new(months.unsigned_abs()), + } +} + +/// Add the given number of months to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn add_months_datetime( + dt: DateTime, + months: i32, +) -> Option> { + match months.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_add_months(Months::new(months as u32)), + Ordering::Less => dt.checked_sub_months(Months::new(months.unsigned_abs())), + } +} + +/// Add the given number of days to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn add_days_datetime(dt: DateTime, days: i32) -> Option> { + match days.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_add_days(Days::new(days as u64)), + Ordering::Less => dt.checked_sub_days(Days::new(days.unsigned_abs() as u64)), + } +} + +/// Substract the given number of months to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn sub_months_datetime( + dt: DateTime, + months: i32, +) -> Option> { + match months.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_sub_months(Months::new(months as u32)), + Ordering::Less => dt.checked_add_months(Months::new(months.unsigned_abs())), + } +} + +/// Substract the given number of days to the given datetime. +/// +/// Returns `None` when it will result in overflow. 
+pub(crate) fn sub_days_datetime(dt: DateTime, days: i32) -> Option> { + match days.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_sub_days(Days::new(days as u64)), + Ordering::Less => dt.checked_add_days(Days::new(days.unsigned_abs() as u64)), + } +} + +#[cfg(test)] +mod tests { + + use chrono::naive::{NaiveDate, NaiveDateTime, NaiveTime}; + + use super::*; + + #[test] + fn test_shift_months() { + let base = NaiveDate::from_ymd_opt(2020, 1, 31).unwrap(); + + assert_eq!( + shift_months(base, 0), + NaiveDate::from_ymd_opt(2020, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 1), + NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() + ); + assert_eq!( + shift_months(base, 2), + NaiveDate::from_ymd_opt(2020, 3, 31).unwrap() + ); + assert_eq!( + shift_months(base, 3), + NaiveDate::from_ymd_opt(2020, 4, 30).unwrap() + ); + assert_eq!( + shift_months(base, 4), + NaiveDate::from_ymd_opt(2020, 5, 31).unwrap() + ); + assert_eq!( + shift_months(base, 5), + NaiveDate::from_ymd_opt(2020, 6, 30).unwrap() + ); + assert_eq!( + shift_months(base, 6), + NaiveDate::from_ymd_opt(2020, 7, 31).unwrap() + ); + assert_eq!( + shift_months(base, 7), + NaiveDate::from_ymd_opt(2020, 8, 31).unwrap() + ); + assert_eq!( + shift_months(base, 8), + NaiveDate::from_ymd_opt(2020, 9, 30).unwrap() + ); + assert_eq!( + shift_months(base, 9), + NaiveDate::from_ymd_opt(2020, 10, 31).unwrap() + ); + assert_eq!( + shift_months(base, 10), + NaiveDate::from_ymd_opt(2020, 11, 30).unwrap() + ); + assert_eq!( + shift_months(base, 11), + NaiveDate::from_ymd_opt(2020, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, 12), + NaiveDate::from_ymd_opt(2021, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 13), + NaiveDate::from_ymd_opt(2021, 2, 28).unwrap() + ); + + assert_eq!( + shift_months(base, -1), + NaiveDate::from_ymd_opt(2019, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, -2), + NaiveDate::from_ymd_opt(2019, 11, 30).unwrap() + ); + assert_eq!( + 
shift_months(base, -3), + NaiveDate::from_ymd_opt(2019, 10, 31).unwrap() + ); + assert_eq!( + shift_months(base, -4), + NaiveDate::from_ymd_opt(2019, 9, 30).unwrap() + ); + assert_eq!( + shift_months(base, -5), + NaiveDate::from_ymd_opt(2019, 8, 31).unwrap() + ); + assert_eq!( + shift_months(base, -6), + NaiveDate::from_ymd_opt(2019, 7, 31).unwrap() + ); + assert_eq!( + shift_months(base, -7), + NaiveDate::from_ymd_opt(2019, 6, 30).unwrap() + ); + assert_eq!( + shift_months(base, -8), + NaiveDate::from_ymd_opt(2019, 5, 31).unwrap() + ); + assert_eq!( + shift_months(base, -9), + NaiveDate::from_ymd_opt(2019, 4, 30).unwrap() + ); + assert_eq!( + shift_months(base, -10), + NaiveDate::from_ymd_opt(2019, 3, 31).unwrap() + ); + assert_eq!( + shift_months(base, -11), + NaiveDate::from_ymd_opt(2019, 2, 28).unwrap() + ); + assert_eq!( + shift_months(base, -12), + NaiveDate::from_ymd_opt(2019, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, -13), + NaiveDate::from_ymd_opt(2018, 12, 31).unwrap() + ); + + assert_eq!( + shift_months(base, 1265), + NaiveDate::from_ymd_opt(2125, 6, 30).unwrap() + ); + } + + #[test] + fn test_shift_months_with_overflow() { + let base = NaiveDate::from_ymd_opt(2020, 12, 31).unwrap(); + + assert_eq!(shift_months(base, 0), base); + assert_eq!( + shift_months(base, 1), + NaiveDate::from_ymd_opt(2021, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 2), + NaiveDate::from_ymd_opt(2021, 2, 28).unwrap() + ); + assert_eq!( + shift_months(base, 12), + NaiveDate::from_ymd_opt(2021, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, 18), + NaiveDate::from_ymd_opt(2022, 6, 30).unwrap() + ); + + assert_eq!( + shift_months(base, -1), + NaiveDate::from_ymd_opt(2020, 11, 30).unwrap() + ); + assert_eq!( + shift_months(base, -2), + NaiveDate::from_ymd_opt(2020, 10, 31).unwrap() + ); + assert_eq!( + shift_months(base, -10), + NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() + ); + assert_eq!( + shift_months(base, -12), + 
NaiveDate::from_ymd_opt(2019, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, -18), + NaiveDate::from_ymd_opt(2019, 6, 30).unwrap() + ); + } + + #[test] + fn test_shift_months_datetime() { + let date = NaiveDate::from_ymd_opt(2020, 1, 31).unwrap(); + let o_clock = NaiveTime::from_hms_opt(1, 2, 3).unwrap(); + + let base = NaiveDateTime::new(date, o_clock); + + assert_eq!( + shift_months(base, 0).date(), + NaiveDate::from_ymd_opt(2020, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 1).date(), + NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() + ); + assert_eq!( + shift_months(base, 2).date(), + NaiveDate::from_ymd_opt(2020, 3, 31).unwrap() + ); + assert_eq!(shift_months(base, 0).time(), o_clock); + assert_eq!(shift_months(base, 1).time(), o_clock); + assert_eq!(shift_months(base, 2).time(), o_clock); + } +} diff --git a/arrow-array/src/ffi.rs b/arrow-array/src/ffi.rs new file mode 100644 index 000000000000..a28b3f746115 --- /dev/null +++ b/arrow-array/src/ffi.rs @@ -0,0 +1,1694 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). +//! +//! 
Generally, this module is divided in two main interfaces: +//! One interface maps C ABI to native Rust types, i.e. convert c-pointers, c_char, to native rust. +//! This is handled by [FFI_ArrowSchema] and [FFI_ArrowArray]. +//! +//! The second interface maps native Rust types to the Rust-specific implementation of Arrow such as `format` to `Datatype`, +//! `Buffer`, etc. This is handled by `from_ffi` and `to_ffi`. +//! +//! +//! Export to FFI +//! +//! ```rust +//! # use std::sync::Arc; +//! # use arrow_array::{Int32Array, Array, make_array}; +//! # use arrow_data::ArrayData; +//! # use arrow_array::ffi::{to_ffi, from_ffi}; +//! # use arrow_schema::ArrowError; +//! # fn main() -> Result<(), ArrowError> { +//! // create an array natively +//! +//! let array = Int32Array::from(vec![Some(1), None, Some(3)]); +//! let data = array.into_data(); +//! +//! // Export it +//! let (out_array, out_schema) = to_ffi(&data)?; +//! +//! // import it +//! let data = unsafe { from_ffi(out_array, &out_schema) }?; +//! let array = Int32Array::from(data); +//! +//! // verify +//! assert_eq!(array, Int32Array::from(vec![Some(1), None, Some(3)])); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! Import from FFI +//! +//! ``` +//! # use std::ptr::addr_of_mut; +//! # use arrow_array::ffi::{from_ffi, FFI_ArrowArray}; +//! # use arrow_array::{ArrayRef, make_array}; +//! # use arrow_schema::{ArrowError, ffi::FFI_ArrowSchema}; +//! # +//! /// A foreign data container that can export to C Data interface +//! struct ForeignArray {}; +//! +//! impl ForeignArray { +//! /// Export from foreign array representation to C Data interface +//! /// e.g. +//! fn export_to_c(&self, array: *mut FFI_ArrowArray, schema: *mut FFI_ArrowSchema) { +//! // ... +//! } +//! } +//! +//! /// Import an [`ArrayRef`] from a [`ForeignArray`] +//! fn import_array(foreign: &ForeignArray) -> Result { +//! let mut schema = FFI_ArrowSchema::empty(); +//! let mut array = FFI_ArrowArray::empty(); +//! 
foreign.export_to_c(addr_of_mut!(array), addr_of_mut!(schema)); +//! Ok(make_array(unsafe { from_ffi(array, &schema) }?)) +//! } +//! ``` + +/* +# Design: + +Main assumptions: +* A memory region is deallocated according it its own release mechanism. +* Rust shares memory regions between arrays. +* A memory region should be deallocated when no-one is using it. + +The design of this module is as follows: + +`ArrowArray` contains two `Arc`s, one per ABI-compatible `struct`, each containing data +according to the C Data Interface. These Arcs are used for ref counting of the structs +within Rust and lifetime management. + +Each ABI-compatible `struct` knowns how to `drop` itself, calling `release`. + +To import an array, unsafely create an `ArrowArray` from two pointers using [ArrowArray::try_from_raw]. +To export an array, create an `ArrowArray` using [ArrowArray::try_new]. +*/ + +use std::{mem::size_of, ptr::NonNull, sync::Arc}; + +use arrow_buffer::{bit_util, Buffer, MutableBuffer}; +pub use arrow_data::ffi::FFI_ArrowArray; +use arrow_data::{layout, ArrayData}; +pub use arrow_schema::ffi::FFI_ArrowSchema; +use arrow_schema::{ArrowError, DataType, UnionMode}; + +use crate::array::ArrayRef; + +type Result = std::result::Result; + +/// Exports an array to raw pointers of the C Data Interface provided by the consumer. +/// # Safety +/// Assumes that these pointers represent valid C Data Interfaces, both in memory +/// representation and lifetime via the `release` mechanism. +/// +/// This function copies the content of two FFI structs [arrow_data::ffi::FFI_ArrowArray] and +/// [arrow_schema::ffi::FFI_ArrowSchema] in the array to the location pointed by the raw pointers. +/// Usually the raw pointers are provided by the array data consumer. 
+#[deprecated(note = "Use FFI_ArrowArray::new and FFI_ArrowSchema::try_from")] +pub unsafe fn export_array_into_raw( + src: ArrayRef, + out_array: *mut FFI_ArrowArray, + out_schema: *mut FFI_ArrowSchema, +) -> Result<()> { + let data = src.to_data(); + let array = FFI_ArrowArray::new(&data); + let schema = FFI_ArrowSchema::try_from(data.data_type())?; + + std::ptr::write_unaligned(out_array, array); + std::ptr::write_unaligned(out_schema, schema); + + Ok(()) +} + +// returns the number of bits that buffer `i` (in the C data interface) is expected to have. +// This is set by the Arrow specification +fn bit_width(data_type: &DataType, i: usize) -> Result { + if let Some(primitive) = data_type.primitive_width() { + return match i { + 0 => Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" doesn't expect buffer at index 0. Please verify that the C data interface is correctly implemented." + ))), + 1 => Ok(primitive * 8), + i => Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))), + }; + } + + Ok(match (data_type, i) { + (DataType::Boolean, 1) => 1, + (DataType::Boolean, _) => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))) + } + (DataType::FixedSizeBinary(num_bytes), 1) => *num_bytes as usize * u8::BITS as usize, + (DataType::FixedSizeList(f, num_elems), 1) => { + let child_bit_width = bit_width(f.data_type(), 1)?; + child_bit_width * (*num_elems as usize) + }, + (DataType::FixedSizeBinary(_), _) | (DataType::FixedSizeList(_, _), _) => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." 
+ ))) + }, + // Variable-size list and map have one i32 buffer. + // Variable-sized binaries: have two buffers. + // "small": first buffer is i32, second is in bytes + (DataType::Utf8, 1) | (DataType::Binary, 1) | (DataType::List(_), 1) | (DataType::Map(_, _), 1) => i32::BITS as _, + (DataType::Utf8, 2) | (DataType::Binary, 2) => u8::BITS as _, + (DataType::List(_), _) | (DataType::Map(_, _), _) => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))) + } + (DataType::Utf8, _) | (DataType::Binary, _) => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 3 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))) + } + // Variable-sized binaries: have two buffers. + // LargeUtf8: first buffer is i64, second is in bytes + (DataType::LargeUtf8, 1) | (DataType::LargeBinary, 1) | (DataType::LargeList(_), 1) => i64::BITS as _, + (DataType::LargeUtf8, 2) | (DataType::LargeBinary, 2) | (DataType::LargeList(_), 2)=> u8::BITS as _, + (DataType::LargeUtf8, _) | (DataType::LargeBinary, _) | (DataType::LargeList(_), _)=> { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 3 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))) + } + // Variable-sized views: have 3 or more buffers. + // Buffer 1 are the u128 views + // Buffers 2...N-1 are u8 byte buffers + (DataType::Utf8View, 1) | (DataType::BinaryView,1) => u128::BITS as _, + (DataType::Utf8View, _) | (DataType::BinaryView, _) => { + u8::BITS as _ + } + // type ids. UnionArray doesn't have null bitmap so buffer index begins with 0. 
+ (DataType::Union(_, _), 0) => i8::BITS as _, + // Only DenseUnion has 2nd buffer + (DataType::Union(_, UnionMode::Dense), 1) => i32::BITS as _, + (DataType::Union(_, UnionMode::Sparse), _) => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 1 buffer, but requested {i}. Please verify that the C data interface is correctly implemented." + ))) + } + (DataType::Union(_, UnionMode::Dense), _) => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" expects 2 buffer, but requested {i}. Please verify that the C data interface is correctly implemented." + ))) + } + (_, 0) => { + // We don't call this `bit_width` to compute buffer length for null buffer. If any types that don't have null buffer like + // UnionArray, they should be handled above. + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" doesn't expect buffer at index 0. Please verify that the C data interface is correctly implemented." + ))) + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{data_type:?}\" is still not supported in Rust implementation" + ))) + } + }) +} + +/// returns a new buffer corresponding to the index `i` of the FFI array. It may not exist (null pointer). +/// `bits` is the number of bits that the native type of this buffer has. +/// The size of the buffer will be `ceil(self.length * bits, 8)`. +/// # Panic +/// This function panics if `i` is larger or equal to `n_buffers`. 
+/// # Safety +/// This function assumes that `ceil(self.length * bits, 8)` is the size of the buffer +unsafe fn create_buffer( + owner: Arc, + array: &FFI_ArrowArray, + index: usize, + len: usize, +) -> Option { + if array.num_buffers() == 0 { + return None; + } + NonNull::new(array.buffer(index) as _) + .map(|ptr| Buffer::from_custom_allocation(ptr, len, owner)) +} + +/// Export to the C Data Interface +pub fn to_ffi(data: &ArrayData) -> Result<(FFI_ArrowArray, FFI_ArrowSchema)> { + let array = FFI_ArrowArray::new(data); + let schema = FFI_ArrowSchema::try_from(data.data_type())?; + Ok((array, schema)) +} + +/// Import [ArrayData] from the C Data Interface +/// +/// # Safety +/// +/// This struct assumes that the incoming data agrees with the C data interface. +pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result { + let dt = DataType::try_from(schema)?; + let array = Arc::new(array); + let tmp = ImportedArrowArray { + array: &array, + data_type: dt, + owner: &array, + }; + tmp.consume() +} + +/// Import [ArrayData] from the C Data Interface +/// +/// # Safety +/// +/// This struct assumes that the incoming data agrees with the C data interface. 
+pub unsafe fn from_ffi_and_data_type( + array: FFI_ArrowArray, + data_type: DataType, +) -> Result { + let array = Arc::new(array); + let tmp = ImportedArrowArray { + array: &array, + data_type, + owner: &array, + }; + tmp.consume() +} + +#[derive(Debug)] +struct ImportedArrowArray<'a> { + array: &'a FFI_ArrowArray, + data_type: DataType, + owner: &'a Arc, +} + +impl<'a> ImportedArrowArray<'a> { + fn consume(self) -> Result { + let len = self.array.len(); + let offset = self.array.offset(); + let null_count = match &self.data_type { + DataType::Null => 0, + _ => self.array.null_count(), + }; + + let data_layout = layout(&self.data_type); + let buffers = self.buffers(data_layout.can_contain_null_mask, data_layout.variadic)?; + + let null_bit_buffer = if data_layout.can_contain_null_mask { + self.null_bit_buffer() + } else { + None + }; + + let mut child_data = self.consume_children()?; + + if let Some(d) = self.dictionary()? { + // For dictionary type there should only be a single child, so we don't need to worry if + // there are other children added above. + assert!(child_data.is_empty()); + child_data.push(d.consume()?); + } + + // Should FFI be checking validity? 
+ Ok(unsafe { + ArrayData::new_unchecked( + self.data_type, + len, + Some(null_count), + null_bit_buffer, + offset, + buffers, + child_data, + ) + }) + } + + fn consume_children(&self) -> Result> { + match &self.data_type { + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) + | DataType::Map(field, _) => Ok([self.consume_child(0, field.data_type())?].to_vec()), + DataType::Struct(fields) => { + assert!(fields.len() == self.array.num_children()); + fields + .iter() + .enumerate() + .map(|(i, field)| self.consume_child(i, field.data_type())) + .collect::>>() + } + DataType::Union(union_fields, _) => { + assert!(union_fields.len() == self.array.num_children()); + union_fields + .iter() + .enumerate() + .map(|(i, (_, field))| self.consume_child(i, field.data_type())) + .collect::>>() + } + DataType::RunEndEncoded(run_ends_field, values_field) => Ok([ + self.consume_child(0, run_ends_field.data_type())?, + self.consume_child(1, values_field.data_type())?, + ] + .to_vec()), + _ => Ok(Vec::new()), + } + } + + fn consume_child(&self, index: usize, child_type: &DataType) -> Result { + ImportedArrowArray { + array: self.array.child(index), + data_type: child_type.clone(), + owner: self.owner, + } + .consume() + } + + /// returns all buffers, as organized by Rust (i.e. null buffer is skipped if it's present + /// in the spec of the type) + fn buffers(&self, can_contain_null_mask: bool, variadic: bool) -> Result> { + // + 1: skip null buffer + let buffer_begin = can_contain_null_mask as usize; + let buffer_end = self.array.num_buffers() - usize::from(variadic); + + let variadic_buffer_lens = if variadic { + // Each views array has 1 (optional) null buffer, 1 views buffer, 1 lengths buffer. + // Rest are variadic. 
+ let num_variadic_buffers = + self.array.num_buffers() - (2 + usize::from(can_contain_null_mask)); + if num_variadic_buffers == 0 { + &[] + } else { + let lengths = self.array.buffer(self.array.num_buffers() - 1); + // SAFETY: is lengths is non-null, then it must be valid for up to num_variadic_buffers. + unsafe { std::slice::from_raw_parts(lengths.cast::(), num_variadic_buffers) } + } + } else { + &[] + }; + + (buffer_begin..buffer_end) + .map(|index| { + let len = self.buffer_len(index, variadic_buffer_lens, &self.data_type)?; + match unsafe { create_buffer(self.owner.clone(), self.array, index, len) } { + Some(buf) => Ok(buf), + None if len == 0 => { + // Null data buffer, which Rust doesn't allow. So create + // an empty buffer. + Ok(MutableBuffer::new(0).into()) + } + None => Err(ArrowError::CDataInterface(format!( + "The external buffer at position {index} is null." + ))), + } + }) + .collect() + } + + /// Returns the length, in bytes, of the buffer `i` (indexed according to the C data interface) + /// Rust implementation uses fixed-sized buffers, which require knowledge of their `len`. + /// for variable-sized buffers, such as the second buffer of a stringArray, we need + /// to fetch offset buffer's len to build the second buffer. + fn buffer_len( + &self, + i: usize, + variadic_buffer_lengths: &[i64], + dt: &DataType, + ) -> Result { + // Special handling for dictionary type as we only care about the key type in the case. + let data_type = match dt { + DataType::Dictionary(key_data_type, _) => key_data_type.as_ref(), + dt => dt, + }; + + // `ffi::ArrowArray` records array offset, we need to add it back to the + // buffer length to get the actual buffer length. + let length = self.array.len() + self.array.offset(); + + // Inner type is not important for buffer length. 
+ Ok(match (&data_type, i) { + (DataType::Utf8, 1) + | (DataType::LargeUtf8, 1) + | (DataType::Binary, 1) + | (DataType::LargeBinary, 1) + | (DataType::List(_), 1) + | (DataType::LargeList(_), 1) + | (DataType::Map(_, _), 1) => { + // the len of the offset buffer (buffer 1) equals length + 1 + let bits = bit_width(data_type, i)?; + debug_assert_eq!(bits % 8, 0); + (length + 1) * (bits / 8) + } + (DataType::Utf8, 2) | (DataType::Binary, 2) => { + if self.array.is_empty() { + return Ok(0); + } + + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = self.buffer_len(1, variadic_buffer_lengths, dt)?; + // first buffer is the null buffer => add(1) + // we assume that pointer is aligned for `i32`, as Utf8 uses `i32` offsets. + #[allow(clippy::cast_ptr_alignment)] + let offset_buffer = self.array.buffer(1) as *const i32; + // get last offset + (unsafe { *offset_buffer.add(len / size_of::() - 1) }) as usize + } + (DataType::LargeUtf8, 2) | (DataType::LargeBinary, 2) => { + if self.array.is_empty() { + return Ok(0); + } + + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = self.buffer_len(1, variadic_buffer_lengths, dt)?; + // first buffer is the null buffer => add(1) + // we assume that pointer is aligned for `i64`, as Large uses `i64` offsets. + #[allow(clippy::cast_ptr_alignment)] + let offset_buffer = self.array.buffer(1) as *const i64; + // get last offset + (unsafe { *offset_buffer.add(len / size_of::() - 1) }) as usize + } + // View types: these have variadic buffers. + // Buffer 1 is the views buffer, which stores 1 u128 per length of the array. + // Buffers 2..N-1 are the buffers holding the byte data. Their lengths are variable. 
+ // Buffer N is of length (N - 2) and stores i64 containing the lengths of buffers 2..N-1 + (DataType::Utf8View, 1) | (DataType::BinaryView, 1) => { + std::mem::size_of::() * length + } + (DataType::Utf8View, i) | (DataType::BinaryView, i) => { + variadic_buffer_lengths[i - 2] as usize + } + // buffer len of primitive types + _ => { + let bits = bit_width(data_type, i)?; + bit_util::ceil(length * bits, 8) + } + }) + } + + /// returns the null bit buffer. + /// Rust implementation uses a buffer that is not part of the array of buffers. + /// The C Data interface's null buffer is part of the array of buffers. + fn null_bit_buffer(&self) -> Option { + // similar to `self.buffer_len(0)`, but without `Result`. + // `ffi::ArrowArray` records array offset, we need to add it back to the + // buffer length to get the actual buffer length. + let length = self.array.len() + self.array.offset(); + let buffer_len = bit_util::ceil(length, 8); + + unsafe { create_buffer(self.owner.clone(), self.array, 0, buffer_len) } + } + + fn dictionary(&self) -> Result> { + match (self.array.dictionary(), &self.data_type) { + (Some(array), DataType::Dictionary(_, value_type)) => Ok(Some(ImportedArrowArray { + array, + data_type: value_type.as_ref().clone(), + owner: self.owner, + })), + (Some(_), _) => Err(ArrowError::CDataInterface( + "Got dictionary in FFI_ArrowArray for non-dictionary data type".to_string(), + )), + (None, DataType::Dictionary(_, _)) => Err(ArrowError::CDataInterface( + "Missing dictionary in FFI_ArrowArray for dictionary data type".to_string(), + )), + (_, _) => Ok(None), + } + } +} + +#[cfg(test)] +mod tests_to_then_from_ffi { + use std::collections::HashMap; + use std::mem::ManuallyDrop; + + use arrow_buffer::NullBuffer; + use arrow_schema::Field; + + use crate::builder::UnionBuilder; + use crate::cast::AsArray; + use crate::types::{Float64Type, Int32Type, Int8Type}; + use crate::*; + + use super::*; + + #[test] + fn test_round_trip() { + // create an array natively + 
let array = Int32Array::from(vec![1, 2, 3]); + + // export it + let (array, schema) = to_ffi(&array.into_data()).unwrap(); + + // (simulate consumer) import it + let array = Int32Array::from(unsafe { from_ffi(array, &schema) }.unwrap()); + + // verify + assert_eq!(array, Int32Array::from(vec![1, 2, 3])); + } + + #[test] + fn test_import() { + // Model receiving const pointers from an external system + + // Create an array natively + let data = Int32Array::from(vec![1, 2, 3]).into_data(); + let schema = FFI_ArrowSchema::try_from(data.data_type()).unwrap(); + let array = FFI_ArrowArray::new(&data); + + // Use ManuallyDrop to avoid Box:Drop recursing + let schema = Box::new(ManuallyDrop::new(schema)); + let array = Box::new(ManuallyDrop::new(array)); + + let schema_ptr = &**schema as *const _; + let array_ptr = &**array as *const _; + + // We can read them back to memory + // SAFETY: + // Pointers are aligned and valid + let data = + unsafe { from_ffi(std::ptr::read(array_ptr), &std::ptr::read(schema_ptr)).unwrap() }; + + let array = Int32Array::from(data); + assert_eq!(array, Int32Array::from(vec![1, 2, 3])); + } + + #[test] + fn test_round_trip_with_offset() -> Result<()> { + // create an array natively + let array = Int32Array::from(vec![Some(1), Some(2), None, Some(3), None]); + + let array = array.slice(1, 2); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array.as_any().downcast_ref::().unwrap(); + + assert_eq!(array, &Int32Array::from(vec![Some(2), None])); + + // (drop/release) + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_decimal_round_trip() -> Result<()> { + // create an array natively + let original_array = [Some(12345_i128), Some(-12345_i128), None] + .into_iter() + .collect::() + .with_precision_and_scale(6, 2) + .unwrap(); + + // export it + let (array, schema) = 
to_ffi(&original_array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // perform some operation + let array = array.as_any().downcast_ref::().unwrap(); + + // verify + assert_eq!(array, &original_array); + + // (drop/release) + Ok(()) + } + // case with nulls is tested in the docs, through the example on this module. + + fn test_generic_string() -> Result<()> { + // create an array natively + let array = GenericStringArray::::from(vec![Some("a"), None, Some("aaa")]); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // perform some operation + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + // verify + let expected = GenericStringArray::::from(vec![Some("a"), None, Some("aaa")]); + assert_eq!(array, &expected); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_string() -> Result<()> { + test_generic_string::() + } + + #[test] + fn test_large_string() -> Result<()> { + test_generic_string::() + } + + fn test_generic_list() -> Result<()> { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = [0_usize, 3, 6, 8] + .iter() + .map(|i| Offset::from_usize(*i).unwrap()) + .collect::(); + + // Construct a list array from the above two + let list_data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new("item", DataType::Int32, false), + )); + + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + + // create an array natively + let array = 
GenericListArray::::from(list_data.clone()); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // downcast + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + // verify + let expected = GenericListArray::::from(list_data); + assert_eq!(&array.value(0), &expected.value(0)); + assert_eq!(&array.value(1), &expected.value(1)); + assert_eq!(&array.value(2), &expected.value(2)); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_list() -> Result<()> { + test_generic_list::() + } + + #[test] + fn test_large_list() -> Result<()> { + test_generic_list::() + } + + fn test_generic_binary() -> Result<()> { + // create an array natively + let array: Vec> = vec![Some(b"a"), None, Some(b"aaa")]; + let array = GenericBinaryArray::::from(array); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + // verify + let expected: Vec> = vec![Some(b"a"), None, Some(b"aaa")]; + let expected = GenericBinaryArray::::from(expected); + assert_eq!(array, &expected); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_binary() -> Result<()> { + test_generic_binary::() + } + + #[test] + fn test_large_binary() -> Result<()> { + test_generic_binary::() + } + + #[test] + fn test_bool() -> Result<()> { + // create an array natively + let array = BooleanArray::from(vec![None, Some(true), Some(false)]); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array.as_any().downcast_ref::().unwrap(); + + // verify + assert_eq!( + array, + &BooleanArray::from(vec![None, Some(true), 
Some(false)]) + ); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_time32() -> Result<()> { + // create an array natively + let array = Time32MillisecondArray::from(vec![None, Some(1), Some(2)]); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + // verify + assert_eq!( + array, + &Time32MillisecondArray::from(vec![None, Some(1), Some(2)]) + ); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_timestamp() -> Result<()> { + // create an array natively + let array = TimestampMillisecondArray::from(vec![None, Some(1), Some(2)]); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + // verify + assert_eq!( + array, + &TimestampMillisecondArray::from(vec![None, Some(1), Some(2)]) + ); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_fixed_size_binary_array() -> Result<()> { + let values = vec![ + None, + Some(vec![10, 10, 10]), + None, + Some(vec![20, 20, 20]), + Some(vec![30, 30, 30]), + None, + ]; + let array = FixedSizeBinaryArray::try_from_sparse_iter_with_size(values.into_iter(), 3)?; + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + // verify + assert_eq!( + array, + &FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![ + None, + Some(vec![10, 10, 10]), + None, + Some(vec![20, 20, 20]), + Some(vec![30, 30, 30]), + None, + ] + .into_iter(), + 3 + )? 
+ ); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_fixed_size_list_array() -> Result<()> { + // 0000 0100 + let mut validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut validity_bits, 2); + + let v: Vec = (0..9).collect(); + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&v)) + .build()?; + + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("f", DataType::Int32, false)), 3); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .null_bit_buffer(Some(Buffer::from(validity_bits))) + .add_child_data(value_data) + .build()?; + + // export it + let (array, schema) = to_ffi(&list_data)?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array.as_any().downcast_ref::().unwrap(); + + // 0010 0100 + let mut expected_validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut expected_validity_bits, 2); + bit_util::set_bit(&mut expected_validity_bits, 5); + + let mut w = vec![]; + w.extend_from_slice(&v); + + let expected_value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&w)) + .build()?; + + let expected_list_data = ArrayData::builder(list_data_type) + .len(3) + .null_bit_buffer(Some(Buffer::from(expected_validity_bits))) + .add_child_data(expected_value_data) + .build()?; + let expected_array = FixedSizeListArray::from(expected_list_data); + + // verify + assert_eq!(array, &expected_array); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_dictionary() -> Result<()> { + // create an array natively + let values = vec!["a", "aaa", "aaa"]; + let dict_array: DictionaryArray = values.into_iter().collect(); + + // export it + let (array, schema) = to_ffi(&dict_array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let actual = array + .as_any() + 
.downcast_ref::>() + .unwrap(); + + // verify + let new_values = vec!["a", "aaa", "aaa"]; + let expected: DictionaryArray = new_values.into_iter().collect(); + assert_eq!(actual, &expected); + + // (drop/release) + Ok(()) + } + + #[test] + #[allow(deprecated)] + fn test_export_array_into_raw() -> Result<()> { + let array = make_array(Int32Array::from(vec![1, 2, 3]).into_data()); + + // Assume two raw pointers provided by the consumer + let mut out_array = FFI_ArrowArray::empty(); + let mut out_schema = FFI_ArrowSchema::empty(); + + { + let out_array_ptr = std::ptr::addr_of_mut!(out_array); + let out_schema_ptr = std::ptr::addr_of_mut!(out_schema); + unsafe { + export_array_into_raw(array, out_array_ptr, out_schema_ptr)?; + } + } + + // (simulate consumer) import it + let data = unsafe { from_ffi(out_array, &out_schema) }?; + let array = make_array(data); + + // perform some operation + let array = array.as_any().downcast_ref::().unwrap(); + + // verify + assert_eq!(array, &Int32Array::from(vec![1, 2, 3])); + Ok(()) + } + + #[test] + fn test_duration() -> Result<()> { + // create an array natively + let array = DurationSecondArray::from(vec![None, Some(1), Some(2)]); + + // export it + let (array, schema) = to_ffi(&array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + // verify + assert_eq!( + array, + &DurationSecondArray::from(vec![None, Some(1), Some(2)]) + ); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_map_array() -> Result<()> { + let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; + let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]); + + // Construct a buffer for value offsets, for the nested array: + // [[a, b, c], [d, e, f], [g, h]] + let entry_offsets = [0, 3, 6, 8]; + + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, 
&entry_offsets) + .unwrap(); + + // export it + let (array, schema) = to_ffi(&map_array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // perform some operation + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array, &map_array); + + Ok(()) + } + + #[test] + fn test_struct_array() -> Result<()> { + let metadata: HashMap = + [("Hello".to_string(), "World! 😊".to_string())].into(); + let struct_array = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Int32, false).with_metadata(metadata)), + Arc::new(Int32Array::from(vec![2, 4, 6])) as Arc, + )]); + + // export it + let (array, schema) = to_ffi(&struct_array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // perform some operation + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.data_type(), struct_array.data_type()); + assert_eq!(array, &struct_array); + + Ok(()) + } + + #[test] + fn test_union_sparse_array() -> Result<()> { + let mut builder = UnionBuilder::new_sparse(); + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append::("a", 4).unwrap(); + let union = builder.build().unwrap(); + + // export it + let (array, schema) = to_ffi(&union.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + let array = array.as_any().downcast_ref::().unwrap(); + + let expected_type_ids = vec![0_i8, 0, 1, 0]; + + // Check type ids + assert_eq!(*array.type_ids(), expected_type_ids); + for (i, id) in expected_type_ids.iter().enumerate() { + assert_eq!(id, &array.type_id(i)); + } + + // Check offsets, sparse union should only have a single buffer, i.e. 
no offsets + assert!(array.offsets().is_none()); + + for i in 0..array.len() { + let slot = array.value(i); + match i { + 0 => { + let slot = slot.as_primitive::(); + assert!(!slot.is_null(0)); + assert_eq!(slot.len(), 1); + let value = slot.value(0); + assert_eq!(1_i32, value); + } + 1 => assert!(slot.is_null(0)), + 2 => { + let slot = slot.as_primitive::(); + assert!(!slot.is_null(0)); + assert_eq!(slot.len(), 1); + let value = slot.value(0); + assert_eq!(value, 3_f64); + } + 3 => { + let slot = slot.as_primitive::(); + assert!(!slot.is_null(0)); + assert_eq!(slot.len(), 1); + let value = slot.value(0); + assert_eq!(4_i32, value); + } + _ => unreachable!(), + } + } + + Ok(()) + } + + #[test] + fn test_union_dense_array() -> Result<()> { + let mut builder = UnionBuilder::new_dense(); + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append::("a", 4).unwrap(); + let union = builder.build().unwrap(); + + // export it + let (array, schema) = to_ffi(&union.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = UnionArray::from(data); + + let expected_type_ids = vec![0_i8, 0, 1, 0]; + + // Check type ids + assert_eq!(*array.type_ids(), expected_type_ids); + for (i, id) in expected_type_ids.iter().enumerate() { + assert_eq!(id, &array.type_id(i)); + } + + assert!(array.offsets().is_some()); + + for i in 0..array.len() { + let slot = array.value(i); + match i { + 0 => { + let slot = slot.as_primitive::(); + assert!(!slot.is_null(0)); + assert_eq!(slot.len(), 1); + let value = slot.value(0); + assert_eq!(1_i32, value); + } + 1 => assert!(slot.is_null(0)), + 2 => { + let slot = slot.as_primitive::(); + assert!(!slot.is_null(0)); + assert_eq!(slot.len(), 1); + let value = slot.value(0); + assert_eq!(value, 3_f64); + } + 3 => { + let slot = slot.as_primitive::(); + assert!(!slot.is_null(0)); + assert_eq!(slot.len(), 1); + let value = 
slot.value(0); + assert_eq!(4_i32, value); + } + _ => unreachable!(), + } + } + + Ok(()) + } + + #[test] + fn test_run_array() -> Result<()> { + let value_data = + PrimitiveArray::::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + + // Construct a run_ends array: + let run_ends_values = [4_i32, 6, 7, 9, 13, 18, 20, 22]; + let run_ends_data = + PrimitiveArray::::from_iter_values(run_ends_values.iter().copied()); + + // Construct a run ends encoded array from the above two + let ree_array = RunArray::::try_new(&run_ends_data, &value_data).unwrap(); + + // export it + let (array, schema) = to_ffi(&ree_array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // perform some operation + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(array.data_type(), ree_array.data_type()); + assert_eq!(array.run_ends().values(), ree_array.run_ends().values()); + assert_eq!(array.values(), ree_array.values()); + + Ok(()) + } + + #[test] + fn test_nullable_run_array() -> Result<()> { + let nulls = NullBuffer::from(vec![true, false, true, true, false]); + let value_data = + PrimitiveArray::::new(vec![1_i8, 2, 3, 4, 5].into(), Some(nulls)); + + // Construct a run_ends array: + let run_ends_values = [5_i32, 6, 7, 8, 10]; + let run_ends_data = + PrimitiveArray::::from_iter_values(run_ends_values.iter().copied()); + + // Construct a run ends encoded array from the above two + let ree_array = RunArray::::try_new(&run_ends_data, &value_data).unwrap(); + + // export it + let (array, schema) = to_ffi(&ree_array.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // perform some operation + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(array.data_type(), ree_array.data_type()); + assert_eq!(array.run_ends().values(), ree_array.run_ends().values()); + 
assert_eq!(array.values(), ree_array.values()); + + Ok(()) + } +} + +#[cfg(test)] +mod tests_from_ffi { + use std::sync::Arc; + + use arrow_buffer::{bit_util, buffer::Buffer, MutableBuffer, OffsetBuffer}; + use arrow_data::transform::MutableArrayData; + use arrow_data::ArrayData; + use arrow_schema::{DataType, Field}; + + use super::{ImportedArrowArray, Result}; + use crate::builder::GenericByteViewBuilder; + use crate::types::{BinaryViewType, ByteViewType, Int32Type, StringViewType}; + use crate::{ + array::{ + Array, BooleanArray, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, + Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array, + }, + ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + make_array, ArrayRef, GenericByteViewArray, ListArray, + }; + + fn test_round_trip(expected: &ArrayData) -> Result<()> { + // here we export the array + let array = FFI_ArrowArray::new(expected); + let schema = FFI_ArrowSchema::try_from(expected.data_type())?; + + // simulate an external consumer by being the consumer + let result = &unsafe { from_ffi(array, &schema) }?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn test_u32() -> Result<()> { + let array = UInt32Array::from(vec![Some(2), None, Some(1), None]); + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_u64() -> Result<()> { + let array = UInt64Array::from(vec![Some(2), None, Some(1), None]); + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_i64() -> Result<()> { + let array = Int64Array::from(vec![Some(2), None, Some(1), None]); + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_struct() -> Result<()> { + let inner = StructArray::from(vec![ + ( + Arc::new(Field::new("a1", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![true, true, false, false])) as Arc, + ), + ( + Arc::new(Field::new("a2", DataType::UInt32, false)), + Arc::new(UInt32Array::from(vec![1, 2, 3, 
4])), + ), + ]); + + let array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", inner.data_type().clone(), false)), + Arc::new(inner) as Arc, + ), + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, + ), + ( + Arc::new(Field::new("c", DataType::UInt32, false)), + Arc::new(UInt32Array::from(vec![42, 28, 19, 31])), + ), + ]); + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_dictionary() -> Result<()> { + let values = StringArray::from(vec![Some("foo"), Some("bar"), None]); + let keys = Int32Array::from(vec![ + Some(0), + Some(1), + None, + Some(1), + Some(1), + None, + Some(1), + Some(2), + Some(1), + None, + ]); + let array = DictionaryArray::new(keys, Arc::new(values)); + + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_fixed_size_binary() -> Result<()> { + let values = vec![vec![10, 10, 10], vec![20, 20, 20], vec![30, 30, 30]]; + let array = FixedSizeBinaryArray::try_from_iter(values.into_iter())?; + + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_fixed_size_binary_with_nulls() -> Result<()> { + let values = vec![ + None, + Some(vec![10, 10, 10]), + None, + Some(vec![20, 20, 20]), + Some(vec![30, 30, 30]), + None, + ]; + let array = FixedSizeBinaryArray::try_from_sparse_iter_with_size(values.into_iter(), 3)?; + + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_fixed_size_list() -> Result<()> { + let v: Vec = (0..9).collect(); + let value_data = ArrayData::builder(DataType::Int64) + .len(9) + .add_buffer(Buffer::from_slice_ref(v)) + .build()?; + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("f", DataType::Int64, false)), 3); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build()?; + let array = FixedSizeListArray::from(list_data); + + let data = array.into_data(); + 
test_round_trip(&data) + } + + #[test] + fn test_fixed_size_list_with_nulls() -> Result<()> { + // 0100 0110 + let mut validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut validity_bits, 1); + bit_util::set_bit(&mut validity_bits, 2); + bit_util::set_bit(&mut validity_bits, 6); + + let v: Vec = (0..16).collect(); + let value_data = ArrayData::builder(DataType::Int16) + .len(16) + .add_buffer(Buffer::from_slice_ref(v)) + .build()?; + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("f", DataType::Int16, false)), 2); + let list_data = ArrayData::builder(list_data_type) + .len(8) + .null_bit_buffer(Some(Buffer::from(validity_bits))) + .add_child_data(value_data) + .build()?; + let array = FixedSizeListArray::from(list_data); + + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_fixed_size_list_nested() -> Result<()> { + let v: Vec = (0..16).collect(); + let value_data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(Buffer::from_slice_ref(v)) + .build()?; + + let offsets: Vec = vec![0, 2, 4, 6, 8, 10, 12, 14, 16]; + let value_offsets = Buffer::from_slice_ref(offsets); + let inner_list_data_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let inner_list_data = ArrayData::builder(inner_list_data_type.clone()) + .len(8) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build()?; + + // 0000 0100 + let mut validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut validity_bits, 2); + + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("f", inner_list_data_type, false)), 2); + let list_data = ArrayData::builder(list_data_type) + .len(4) + .null_bit_buffer(Some(Buffer::from(validity_bits))) + .add_child_data(inner_list_data) + .build()?; + + let array = FixedSizeListArray::from(list_data); + + let data = array.into_data(); + test_round_trip(&data) + } + + #[test] + fn test_empty_string_with_non_zero_offset() -> Result<()> { + // Simulate an empty 
string array with a non-zero offset from a producer + let data: Buffer = MutableBuffer::new(0).into(); + let offsets = OffsetBuffer::new(vec![123].into()); + let string_array = + unsafe { StringArray::new_unchecked(offsets.clone(), data.clone(), None) }; + + let data = string_array.into_data(); + + let array = FFI_ArrowArray::new(&data); + let schema = FFI_ArrowSchema::try_from(data.data_type())?; + + let dt = DataType::try_from(&schema)?; + let array = Arc::new(array); + let imported_array = ImportedArrowArray { + array: &array, + data_type: dt, + owner: &array, + }; + + let offset_buf_len = imported_array.buffer_len(1, &[], &imported_array.data_type)?; + let data_buf_len = imported_array.buffer_len(2, &[], &imported_array.data_type)?; + + assert_eq!(offset_buf_len, 4); + assert_eq!(data_buf_len, 0); + + test_round_trip(&imported_array.consume()?) + } + + fn roundtrip_string_array(array: StringArray) -> StringArray { + let data = array.into_data(); + + let array = FFI_ArrowArray::new(&data); + let schema = FFI_ArrowSchema::try_from(data.data_type()).unwrap(); + + let array = unsafe { from_ffi(array, &schema) }.unwrap(); + StringArray::from(array) + } + + fn roundtrip_byte_view_array( + array: GenericByteViewArray, + ) -> GenericByteViewArray { + let data = array.into_data(); + + let array = FFI_ArrowArray::new(&data); + let schema = FFI_ArrowSchema::try_from(data.data_type()).unwrap(); + + let array = unsafe { from_ffi(array, &schema) }.unwrap(); + GenericByteViewArray::::from(array) + } + + fn extend_array(array: &dyn Array) -> ArrayRef { + let len = array.len(); + let data = array.to_data(); + + let mut mutable = MutableArrayData::new(vec![&data], false, len); + mutable.extend(0, 0, len); + make_array(mutable.freeze()) + } + + #[test] + fn test_extend_imported_string_slice() { + let mut strings = vec![]; + + for i in 0..1000 { + strings.push(format!("string: {}", i)); + } + + let string_array = StringArray::from(strings); + + let imported = 
roundtrip_string_array(string_array.clone()); + assert_eq!(imported.len(), 1000); + assert_eq!(imported.value(0), "string: 0"); + assert_eq!(imported.value(499), "string: 499"); + + let copied = extend_array(&imported); + assert_eq!( + copied.as_any().downcast_ref::().unwrap(), + &imported + ); + + let slice = string_array.slice(500, 500); + + let imported = roundtrip_string_array(slice); + assert_eq!(imported.len(), 500); + assert_eq!(imported.value(0), "string: 500"); + assert_eq!(imported.value(499), "string: 999"); + + let copied = extend_array(&imported); + assert_eq!( + copied.as_any().downcast_ref::().unwrap(), + &imported + ); + } + + fn roundtrip_list_array(array: ListArray) -> ListArray { + let data = array.into_data(); + + let array = FFI_ArrowArray::new(&data); + let schema = FFI_ArrowSchema::try_from(data.data_type()).unwrap(); + + let array = unsafe { from_ffi(array, &schema) }.unwrap(); + ListArray::from(array) + } + + #[test] + fn test_extend_imported_list_slice() { + let mut data = vec![]; + + for i in 0..1000 { + let mut list = vec![]; + for j in 0..100 { + list.push(Some(i * 1000 + j)); + } + data.push(Some(list)); + } + + let list_array = ListArray::from_iter_primitive::(data); + + let slice = list_array.slice(500, 500); + let imported = roundtrip_list_array(slice.clone()); + assert_eq!(imported.len(), 500); + assert_eq!(&slice, &imported); + + let copied = extend_array(&imported); + assert_eq!( + copied.as_any().downcast_ref::().unwrap(), + &imported + ); + } + + /// Helper trait to allow us to use easily strings as either BinaryViewType::Native or + /// StringViewType::Native scalars. 
+ trait NativeFromStr { + fn from_str(value: &str) -> &Self; + } + + impl NativeFromStr for str { + fn from_str(value: &str) -> &Self { + value + } + } + + impl NativeFromStr for [u8] { + fn from_str(value: &str) -> &Self { + value.as_bytes() + } + } + + #[test] + fn test_round_trip_byte_view() { + fn test_case() + where + T: ByteViewType, + T::Native: NativeFromStr, + { + macro_rules! run_test_case { + ($array:expr) => {{ + // round-trip through C Data Interface + let len = $array.len(); + let imported = roundtrip_byte_view_array($array); + assert_eq!(imported.len(), len); + + let copied = extend_array(&imported); + assert_eq!( + copied + .as_any() + .downcast_ref::>() + .unwrap(), + &imported + ); + }}; + } + + // Empty test case. + let empty = GenericByteViewBuilder::::new().finish(); + run_test_case!(empty); + + // All inlined strings test case. + let mut all_inlined = GenericByteViewBuilder::::new(); + all_inlined.append_value(T::Native::from_str("inlined1")); + all_inlined.append_value(T::Native::from_str("inlined2")); + all_inlined.append_value(T::Native::from_str("inlined3")); + let all_inlined = all_inlined.finish(); + assert_eq!(all_inlined.data_buffers().len(), 0); + run_test_case!(all_inlined); + + // some inlined + non-inlined, 1 variadic buffer. + let mixed_one_variadic = { + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value(T::Native::from_str("inlined")); + let block_id = + builder.append_block(Buffer::from("non-inlined-string-buffer".as_bytes())); + builder.try_append_view(block_id, 0, 25).unwrap(); + builder.finish() + }; + assert_eq!(mixed_one_variadic.data_buffers().len(), 1); + run_test_case!(mixed_one_variadic); + + // inlined + non-inlined, 2 variadic buffers. 
+ let mixed_two_variadic = { + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value(T::Native::from_str("inlined")); + let block_id = + builder.append_block(Buffer::from("non-inlined-string-buffer".as_bytes())); + builder.try_append_view(block_id, 0, 25).unwrap(); + + let block_id = builder + .append_block(Buffer::from("another-non-inlined-string-buffer".as_bytes())); + builder.try_append_view(block_id, 0, 33).unwrap(); + builder.finish() + }; + assert_eq!(mixed_two_variadic.data_buffers().len(), 2); + run_test_case!(mixed_two_variadic); + } + + test_case::(); + test_case::(); + } +} diff --git a/arrow/src/ffi_stream.rs b/arrow-array/src/ffi_stream.rs similarity index 65% rename from arrow/src/ffi_stream.rs rename to arrow-array/src/ffi_stream.rs index 3a85f2ef6421..34f0cd7cfc74 100644 --- a/arrow/src/ffi_stream.rs +++ b/arrow-array/src/ffi_stream.rs @@ -37,44 +37,42 @@ //! let reader = Box::new(FileReader::try_new(file).unwrap()); //! //! // export it -//! let stream = Box::new(FFI_ArrowArrayStream::empty()); -//! let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream; -//! unsafe { export_reader_into_raw(reader, stream_ptr) }; +//! let mut stream = FFI_ArrowArrayStream::empty(); +//! unsafe { export_reader_into_raw(reader, &mut stream) }; //! //! // consumed and used by something else... //! //! // import it -//! let stream_reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() }; +//! let stream_reader = unsafe { ArrowArrayStreamReader::from_raw(&mut stream).unwrap() }; //! let imported_schema = stream_reader.schema(); //! //! let mut produced_batches = vec![]; //! for batch in stream_reader { //! produced_batches.push(batch.unwrap()); //! } -//! -//! // (drop/release) -//! unsafe { -//! Box::from_raw(stream_ptr); -//! } //! Ok(()) //! } //! 
``` +use arrow_schema::DataType; +use std::ffi::CStr; +use std::ptr::addr_of; use std::{ - convert::TryFrom, ffi::CString, os::raw::{c_char, c_int, c_void}, sync::Arc, }; +use arrow_data::ffi::FFI_ArrowArray; +use arrow_schema::{ffi::FFI_ArrowSchema, ArrowError, Schema, SchemaRef}; + use crate::array::Array; use crate::array::StructArray; -use crate::datatypes::{Schema, SchemaRef}; -use crate::error::ArrowError; -use crate::error::Result; -use crate::ffi::*; +use crate::ffi::from_ffi_and_data_type; use crate::record_batch::{RecordBatch, RecordBatchReader}; +type Result = std::result::Result; + const ENOMEM: i32 = 12; const EIO: i32 = 5; const EINVAL: i32 = 22; @@ -85,25 +83,23 @@ const ENOSYS: i32 = 78; /// This was created by bindgen #[repr(C)] #[derive(Debug)] +#[allow(non_camel_case_types)] pub struct FFI_ArrowArrayStream { - pub get_schema: Option< - unsafe extern "C" fn( - arg1: *mut FFI_ArrowArrayStream, - out: *mut FFI_ArrowSchema, - ) -> c_int, - >, - pub get_next: Option< - unsafe extern "C" fn( - arg1: *mut FFI_ArrowArrayStream, - out: *mut FFI_ArrowArray, - ) -> c_int, - >, - pub get_last_error: - Option *const c_char>, - pub release: Option, + /// C function to get schema from the stream + pub get_schema: + Option c_int>, + /// C function to get next array from the stream + pub get_next: Option c_int>, + /// C function to get the error from last operation on the stream + pub get_last_error: Option *const c_char>, + /// C function to release the stream + pub release: Option, + /// Private data used by the stream pub private_data: *mut c_void, } +unsafe impl Send for FFI_ArrowArrayStream {} + // callback used to drop [FFI_ArrowArrayStream] when it is exported. 
unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) { if stream.is_null() { @@ -122,8 +118,8 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) { } struct StreamPrivateData { - batch_reader: Box, - last_error: String, + batch_reader: Box, + last_error: Option, } // The callback used to get array schema @@ -145,8 +141,12 @@ unsafe extern "C" fn get_next( // The callback used to get the error from last operation on the `FFI_ArrowArrayStream` unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char { let mut ffi_stream = ExportedArrayStream { stream }; - let last_error = ffi_stream.get_last_error(); - CString::new(last_error.as_str()).unwrap().into_raw() + // The consumer should not take ownership of this string, we should return + // a const pointer to it. + match ffi_stream.get_last_error() { + Some(err_string) => err_string.as_ptr(), + None => std::ptr::null(), + } } impl Drop for FFI_ArrowArrayStream { @@ -160,10 +160,10 @@ impl Drop for FFI_ArrowArrayStream { impl FFI_ArrowArrayStream { /// Creates a new [`FFI_ArrowArrayStream`]. 
- pub fn new(batch_reader: Box) -> Self { + pub fn new(batch_reader: Box) -> Self { let private_data = Box::new(StreamPrivateData { batch_reader, - last_error: String::new(), + last_error: None, }); Self { @@ -175,6 +175,22 @@ impl FFI_ArrowArrayStream { } } + /// Takes ownership of the pointed to [`FFI_ArrowArrayStream`] + /// + /// This acts to [move] the data out of `raw_stream`, setting the release callback to NULL + /// + /// # Safety + /// + /// * `raw_stream` must be [valid] for reads and writes + /// * `raw_stream` must be properly aligned + /// * `raw_stream` must point to a properly initialized value of [`FFI_ArrowArrayStream`] + /// + /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array + /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety + pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Self { + std::ptr::replace(raw_stream, Self::empty()) + } + /// Creates a new empty [FFI_ArrowArrayStream]. Used to import from the C Stream Interface. 
pub fn empty() -> Self { Self { @@ -197,59 +213,56 @@ impl ExportedArrayStream { } pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 { - let mut private_data = self.get_private_data(); + let private_data = self.get_private_data(); let reader = &private_data.batch_reader; let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref()); match schema { - Ok(mut schema) => unsafe { - std::ptr::copy(&schema as *const FFI_ArrowSchema, out, 1); - schema.release = None; + Ok(schema) => { + unsafe { std::ptr::copy(addr_of!(schema), out, 1) }; + std::mem::forget(schema); 0 - }, + } Err(ref err) => { - private_data.last_error = err.to_string(); + private_data.last_error = Some( + CString::new(err.to_string()).expect("Error string has a null byte in it."), + ); get_error_code(err) } } } pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 { - let mut private_data = self.get_private_data(); + let private_data = self.get_private_data(); let reader = &mut private_data.batch_reader; - let ret_code = match reader.next() { + match reader.next() { None => { // Marks ArrowArray released to indicate reaching the end of stream. 
- unsafe { - (*out).release = None; - } + unsafe { std::ptr::write(out, FFI_ArrowArray::empty()) } 0 } Some(next_batch) => { if let Ok(batch) = next_batch { let struct_array = StructArray::from(batch); - let mut array = FFI_ArrowArray::new(struct_array.data()); + let array = FFI_ArrowArray::new(&struct_array.to_data()); - unsafe { - std::ptr::copy(&array as *const FFI_ArrowArray, out, 1); - array.release = None; - 0 - } + unsafe { std::ptr::write_unaligned(out, array) }; + 0 } else { let err = &next_batch.unwrap_err(); - private_data.last_error = err.to_string(); + private_data.last_error = Some( + CString::new(err.to_string()).expect("Error string has a null byte in it."), + ); get_error_code(err) } } - }; - - ret_code + } } - pub fn get_last_error(&mut self) -> &String { - &self.get_private_data().last_error + pub fn get_last_error(&mut self) -> Option<&CString> { + self.get_private_data().last_error.as_ref() } } @@ -257,38 +270,35 @@ fn get_error_code(err: &ArrowError) -> i32 { match err { ArrowError::NotYetImplemented(_) => ENOSYS, ArrowError::MemoryError(_) => ENOMEM, - ArrowError::IoError(_) => EIO, + ArrowError::IoError(_, _) => EIO, _ => EINVAL, } } /// A `RecordBatchReader` which imports Arrays from `FFI_ArrowArrayStream`. +/// /// Struct used to fetch `RecordBatch` from the C Stream Interface. /// Its main responsibility is to expose `RecordBatchReader` functionality /// that requires [FFI_ArrowArrayStream]. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ArrowArrayStreamReader { - stream: Arc, + stream: FFI_ArrowArrayStream, schema: SchemaRef, } /// Gets schema from a raw pointer of `FFI_ArrowArrayStream`. This is used when constructing /// `ArrowArrayStreamReader` to cache schema. 
fn get_stream_schema(stream_ptr: *mut FFI_ArrowArrayStream) -> Result { - let empty_schema = Arc::new(FFI_ArrowSchema::empty()); - let schema_ptr = Arc::into_raw(empty_schema) as *mut FFI_ArrowSchema; - - let ret_code = unsafe { (*stream_ptr).get_schema.unwrap()(stream_ptr, schema_ptr) }; + let mut schema = FFI_ArrowSchema::empty(); - let ffi_schema = unsafe { Arc::from_raw(schema_ptr) }; + let ret_code = unsafe { (*stream_ptr).get_schema.unwrap()(stream_ptr, &mut schema) }; if ret_code == 0 { - let schema = Schema::try_from(ffi_schema.as_ref()).unwrap(); + let schema = Schema::try_from(&schema)?; Ok(Arc::new(schema)) } else { Err(ArrowError::CDataInterface(format!( - "Cannot get schema from input stream. Error code: {:?}", - ret_code + "Cannot get schema from input stream. Error code: {ret_code:?}" ))) } } @@ -297,21 +307,16 @@ impl ArrowArrayStreamReader { /// Creates a new `ArrowArrayStreamReader` from a `FFI_ArrowArrayStream`. /// This is used to import from the C Stream Interface. #[allow(dead_code)] - pub fn try_new(stream: FFI_ArrowArrayStream) -> Result { + pub fn try_new(mut stream: FFI_ArrowArrayStream) -> Result { if stream.release.is_none() { return Err(ArrowError::CDataInterface( "input stream is already released".to_string(), )); } - let stream_ptr = Arc::into_raw(Arc::new(stream)) as *mut FFI_ArrowArrayStream; + let schema = get_stream_schema(&mut stream)?; - let schema = get_stream_schema(stream_ptr)?; - - Ok(Self { - stream: unsafe { Arc::from_raw(stream_ptr) }, - schema, - }) + Ok(Self { stream, schema }) } /// Creates a new `ArrowArrayStreamReader` from a raw pointer of `FFI_ArrowArrayStream`. @@ -322,29 +327,23 @@ impl ArrowArrayStreamReader { /// the pointer. /// /// # Safety - /// This function dereferences a raw pointer of `FFI_ArrowArrayStream`. 
+ /// + /// See [`FFI_ArrowArrayStream::from_raw`] pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Result { - let stream_data = std::ptr::replace(raw_stream, FFI_ArrowArrayStream::empty()); - - Self::try_new(stream_data) + Self::try_new(FFI_ArrowArrayStream::from_raw(raw_stream)) } /// Get the last error from `ArrowArrayStreamReader` - fn get_stream_last_error(&self) -> Option { - self.stream.get_last_error?; - - let stream_ptr = Arc::as_ptr(&self.stream) as *mut FFI_ArrowArrayStream; - - let error_str = unsafe { - let c_str = self.stream.get_last_error.unwrap()(stream_ptr) as *mut c_char; - CString::from_raw(c_str).into_string() - }; + fn get_stream_last_error(&mut self) -> Option { + let get_last_error = self.stream.get_last_error?; - if let Err(err) = error_str { - Some(err.to_string()) - } else { - Some(error_str.unwrap()) + let error_str = unsafe { get_last_error(&mut self.stream) }; + if error_str.is_null() { + return None; } + + let error_str = unsafe { CStr::from_ptr(error_str) }; + Some(error_str.to_string_lossy().to_string()) } } @@ -352,35 +351,21 @@ impl Iterator for ArrowArrayStreamReader { type Item = Result; fn next(&mut self) -> Option { - let stream_ptr = Arc::as_ptr(&self.stream) as *mut FFI_ArrowArrayStream; - - let empty_array = Arc::new(FFI_ArrowArray::empty()); - let array_ptr = Arc::into_raw(empty_array) as *mut FFI_ArrowArray; + let mut array = FFI_ArrowArray::empty(); - let ret_code = unsafe { self.stream.get_next.unwrap()(stream_ptr, array_ptr) }; + let ret_code = unsafe { self.stream.get_next.unwrap()(&mut self.stream, &mut array) }; if ret_code == 0 { - let ffi_array = unsafe { Arc::from_raw(array_ptr) }; - // The end of stream has been reached - ffi_array.release?; - - let schema_ref = self.schema(); - let schema = FFI_ArrowSchema::try_from(schema_ref.as_ref()).ok()?; - - let data = ArrowArray { - array: ffi_array, - schema: Arc::new(schema), + if array.is_released() { + return None; } - .to_data() - .ok()?; - - let 
record_batch = RecordBatch::from(&StructArray::from(data)); - Some(Ok(record_batch)) + let result = unsafe { + from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone())) + }; + Some(result.map(|data| RecordBatch::from(StructArray::from(data)))) } else { - unsafe { Arc::from_raw(array_ptr) }; - let last_error = self.get_stream_last_error(); let err = ArrowError::CDataInterface(last_error.unwrap()); Some(Err(err)) @@ -399,8 +384,9 @@ impl RecordBatchReader for ArrowArrayStreamReader { /// # Safety /// Assumes that the pointer represents valid C Stream Interfaces, both in memory /// representation and lifetime via the `release` mechanism. +#[deprecated(note = "Use FFI_ArrowArrayStream::new")] pub unsafe fn export_reader_into_raw( - reader: Box, + reader: Box, out_stream: *mut FFI_ArrowArrayStream, ) { let stream = FFI_ArrowArrayStream::new(reader); @@ -412,18 +398,20 @@ pub unsafe fn export_reader_into_raw( mod tests { use super::*; + use arrow_schema::Field; + use crate::array::Int32Array; - use crate::datatypes::{Field, Schema}; + use crate::ffi::from_ffi; struct TestRecordBatchReader { schema: SchemaRef, - iter: Box>>, + iter: Box> + Send>, } impl TestRecordBatchReader { pub fn new( schema: SchemaRef, - iter: Box>>, + iter: Box> + Send>, ) -> Box { Box::new(TestRecordBatchReader { schema, iter }) } @@ -455,52 +443,36 @@ mod tests { let reader = TestRecordBatchReader::new(schema.clone(), iter); // Export a `RecordBatchReader` through `FFI_ArrowArrayStream` - let stream = Arc::new(FFI_ArrowArrayStream::empty()); - let stream_ptr = Arc::into_raw(stream) as *mut FFI_ArrowArrayStream; - - unsafe { export_reader_into_raw(reader, stream_ptr) }; - - let empty_schema = Arc::new(FFI_ArrowSchema::empty()); - let schema_ptr = Arc::into_raw(empty_schema) as *mut FFI_ArrowSchema; + let mut ffi_stream = FFI_ArrowArrayStream::new(reader); // Get schema from `FFI_ArrowArrayStream` - let ret_code = unsafe { get_schema(stream_ptr, schema_ptr) }; + let mut 
ffi_schema = FFI_ArrowSchema::empty(); + let ret_code = unsafe { get_schema(&mut ffi_stream, &mut ffi_schema) }; assert_eq!(ret_code, 0); - let ffi_schema = unsafe { Arc::from_raw(schema_ptr) }; - - let exported_schema = Schema::try_from(ffi_schema.as_ref()).unwrap(); + let exported_schema = Schema::try_from(&ffi_schema).unwrap(); assert_eq!(&exported_schema, schema.as_ref()); // Get array from `FFI_ArrowArrayStream` let mut produced_batches = vec![]; loop { - let empty_array = Arc::new(FFI_ArrowArray::empty()); - let array_ptr = Arc::into_raw(empty_array.clone()) as *mut FFI_ArrowArray; - - let ret_code = unsafe { get_next(stream_ptr, array_ptr) }; + let mut ffi_array = FFI_ArrowArray::empty(); + let ret_code = unsafe { get_next(&mut ffi_stream, &mut ffi_array) }; assert_eq!(ret_code, 0); // The end of stream has been reached - let ffi_array = unsafe { Arc::from_raw(array_ptr) }; - if ffi_array.release.is_none() { + if ffi_array.is_released() { break; } - let array = ArrowArray { - array: ffi_array, - schema: ffi_schema.clone(), - } - .to_data() - .unwrap(); + let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap(); - let record_batch = RecordBatch::from(&StructArray::from(array)); + let record_batch = RecordBatch::from(StructArray::from(array)); produced_batches.push(record_batch); } assert_eq!(produced_batches, vec![batch.clone(), batch]); - unsafe { Arc::from_raw(stream_ptr) }; Ok(()) } @@ -516,10 +488,8 @@ mod tests { let reader = TestRecordBatchReader::new(schema.clone(), iter); // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader` - let stream = Arc::new(FFI_ArrowArrayStream::new(reader)); - let stream_ptr = Arc::into_raw(stream) as *mut FFI_ArrowArrayStream; - let stream_reader = - unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() }; + let stream = FFI_ArrowArrayStream::new(reader); + let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap(); let imported_schema = stream_reader.schema(); 
assert_eq!(imported_schema, schema); @@ -531,7 +501,6 @@ mod tests { assert_eq!(produced_batches, vec![batch.clone(), batch]); - unsafe { Arc::from_raw(stream_ptr) }; Ok(()) } @@ -550,4 +519,31 @@ mod tests { _test_round_trip_import(vec![array.clone(), array.clone(), array]) } + + #[test] + fn test_error_import() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let iter = Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter()); + + let reader = TestRecordBatchReader::new(schema.clone(), iter); + + // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader` + let stream = FFI_ArrowArrayStream::new(reader); + let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap(); + + let imported_schema = stream_reader.schema(); + assert_eq!(imported_schema, schema); + + let mut produced_batches = vec![]; + for batch in stream_reader { + produced_batches.push(batch); + } + + // The results should outlive the lifetime of the stream itself. + assert_eq!(produced_batches.len(), 1); + assert!(produced_batches[0].is_err()); + + Ok(()) + } } diff --git a/arrow/src/array/iterator.rs b/arrow-array/src/iterator.rs similarity index 73% rename from arrow/src/array/iterator.rs rename to arrow-array/src/iterator.rs index 4269e99625b7..3f9cc0d525c1 100644 --- a/arrow/src/array/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -15,20 +15,39 @@ // specific language governing permissions and limitations // under the License. -use crate::array::array::ArrayAccessor; -use crate::array::{DecimalArray, FixedSizeBinaryArray}; -use crate::datatypes::{Decimal128Type, Decimal256Type}; +//! 
Idiomatic iterators for [`Array`](crate::Array) -use super::{ - BooleanArray, GenericBinaryArray, GenericListArray, GenericStringArray, - PrimitiveArray, +use crate::array::{ + ArrayAccessor, BooleanArray, FixedSizeBinaryArray, GenericBinaryArray, GenericListArray, + GenericStringArray, PrimitiveArray, }; - -/// an iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] -// Note: This implementation is based on std's [Vec]s' [IntoIter]. +use crate::{FixedSizeListArray, MapArray}; +use arrow_buffer::NullBuffer; + +/// An iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] +/// +/// # Performance +/// +/// [`ArrayIter`] provides an idiomatic way to iterate over an array, however, this +/// comes at the cost of performance. In particular the interleaved handling of +/// the null mask is often sub-optimal. +/// +/// If performing an infallible operation, it is typically faster to perform the operation +/// on every index of the array, and handle the null mask separately. For [`PrimitiveArray`] +/// this functionality is provided by [`compute::unary`] +/// +/// If performing a fallible operation, it isn't possible to perform the operation independently +/// of the null mask, as this might result in a spurious failure on a null index. 
However, +/// there are more efficient ways to iterate over just the non-null indices, this functionality +/// is provided by [`compute::try_unary`] +/// +/// [`PrimitiveArray`]: crate::PrimitiveArray +/// [`compute::unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.unary.html +/// [`compute::try_unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.try_unary.html #[derive(Debug)] pub struct ArrayIter { array: T, + logical_nulls: Option, current: usize, current_end: usize, } @@ -37,12 +56,22 @@ impl ArrayIter { /// create a new iterator pub fn new(array: T) -> Self { let len = array.len(); + let logical_nulls = array.logical_nulls(); ArrayIter { array, + logical_nulls, current: 0, current_end: len, } } + + #[inline] + fn is_null(&self, idx: usize) -> bool { + self.logical_nulls + .as_ref() + .map(|x| x.is_null(idx)) + .unwrap_or_default() + } } impl Iterator for ArrayIter { @@ -52,7 +81,7 @@ impl Iterator for ArrayIter { fn next(&mut self) -> Option { if self.current == self.current_end { None - } else if self.array.is_null(self.current) { + } else if self.is_null(self.current) { self.current += 1; Some(None) } else { @@ -81,7 +110,7 @@ impl DoubleEndedIterator for ArrayIter { None } else { self.current_end -= 1; - Some(if self.array.is_null(self.current_end) { + Some(if self.is_null(self.current_end) { None } else { // Safety: @@ -100,20 +129,20 @@ impl ExactSizeIterator for ArrayIter {} /// an iterator that returns Some(T) or None, that can be used on any PrimitiveArray pub type PrimitiveIter<'a, T> = ArrayIter<&'a PrimitiveArray>; +/// an iterator that returns Some(T) or None, that can be used on any BooleanArray pub type BooleanIter<'a> = ArrayIter<&'a BooleanArray>; +/// an iterator that returns Some(T) or None, that can be used on any Utf8Array pub type GenericStringIter<'a, T> = ArrayIter<&'a GenericStringArray>; +/// an iterator that returns Some(T) or None, that can be used on any BinaryArray pub type GenericBinaryIter<'a, T> = ArrayIter<&'a 
GenericBinaryArray>; +/// an iterator that returns Some(T) or None, that can be used on any FixedSizeBinaryArray pub type FixedSizeBinaryIter<'a> = ArrayIter<&'a FixedSizeBinaryArray>; +/// an iterator that returns Some(T) or None, that can be used on any FixedSizeListArray +pub type FixedSizeListIter<'a> = ArrayIter<&'a FixedSizeListArray>; +/// an iterator that returns Some(T) or None, that can be used on any ListArray pub type GenericListArrayIter<'a, O> = ArrayIter<&'a GenericListArray>; - -pub type DecimalIter<'a, T> = ArrayIter<&'a DecimalArray>; -/// an iterator that returns `Some(Decimal128)` or `None`, that can be used on a -/// [`super::Decimal128Array`] -pub type Decimal128Iter<'a> = DecimalIter<'a, Decimal128Type>; - -/// an iterator that returns `Some(Decimal256)` or `None`, that can be used on a -/// [`super::Decimal256Array`] -pub type Decimal256Iter<'a> = DecimalIter<'a, Decimal256Type>; +/// an iterator that returns Some(T) or None, that can be used on any MapArray +pub type MapArrayIter<'a> = ArrayIter<&'a MapArray>; #[cfg(test)] mod tests { @@ -158,8 +187,7 @@ mod tests { #[test] fn test_string_array_iter_round_trip() { - let array = - StringArray::from(vec![Some("a"), None, Some("aaa"), None, Some("aaaaa")]); + let array = StringArray::from(vec![Some("a"), None, Some("aaa"), None, Some("aaaaa")]); let array = Arc::new(array) as ArrayRef; let array = array.as_any().downcast_ref::().unwrap(); @@ -182,8 +210,7 @@ mod tests { // check if DoubleEndedIterator is implemented let result: StringArray = array.iter().rev().collect(); - let rev_array = - StringArray::from(vec![Some("aaaaa"), None, Some("aaa"), None, Some("a")]); + let rev_array = StringArray::from(vec![Some("aaaaa"), None, Some("aaa"), None, Some("a")]); assert_eq!(result, rev_array); // check if ExactSizeIterator is implemented let _ = array.iter().rposition(|opt_b| opt_b == Some("a")); diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs new file mode 100644 index 
000000000000..90bc5e31205a --- /dev/null +++ b/arrow-array/src/lib.rs @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The central type in Apache Arrow are arrays, which are a known-length sequence of values +//! all having the same type. This crate provides concrete implementations of each type, as +//! well as an [`Array`] trait that can be used for type-erasure. +//! +//! # Building an Array +//! +//! Most [`Array`] implementations can be constructed directly from iterators or [`Vec`] +//! +//! ``` +//! # use arrow_array::{Int32Array, ListArray, StringArray}; +//! # use arrow_array::types::Int32Type; +//! # +//! Int32Array::from(vec![1, 2]); +//! Int32Array::from(vec![Some(1), None]); +//! Int32Array::from_iter([1, 2, 3, 4]); +//! Int32Array::from_iter([Some(1), Some(2), None, Some(4)]); +//! +//! StringArray::from(vec!["foo", "bar"]); +//! StringArray::from(vec![Some("foo"), None]); +//! StringArray::from_iter([Some("foo"), None]); +//! StringArray::from_iter_values(["foo", "bar"]); +//! +//! ListArray::from_iter_primitive::([ +//! Some(vec![Some(1), None, Some(3)]), +//! None, +//! Some(vec![]) +//! ]); +//! ``` +//! +//! 
Additionally [`ArrayBuilder`](builder::ArrayBuilder) implementations can be +//! used to construct arrays with a push-based interface +//! +//! ``` +//! # use arrow_array::Int16Array; +//! # +//! // Create a new builder with a capacity of 100 +//! let mut builder = Int16Array::builder(100); +//! +//! // Append a single primitive value +//! builder.append_value(1); +//! // Append a null value +//! builder.append_null(); +//! // Append a slice of primitive values +//! builder.append_slice(&[2, 3, 4]); +//! +//! // Build the array +//! let array = builder.finish(); +//! +//! assert_eq!(5, array.len()); +//! assert_eq!(2, array.value(2)); +//! assert_eq!(&array.values()[3..5], &[3, 4]) +//! ``` +//! +//! # Low-level API +//! +//! Internally, arrays consist of one or more shared memory regions backed by a [`Buffer`], +//! the number and meaning of which depend on the array’s data type, as documented in +//! the [Arrow specification]. +//! +//! For example, the type [`Int16Array`] represents an array of 16-bit integers and consists of: +//! +//! * An optional [`NullBuffer`] identifying any null values +//! * A contiguous [`ScalarBuffer`] of values +//! +//! Similarly, the type [`StringArray`] represents an array of UTF-8 strings and consists of: +//! +//! * An optional [`NullBuffer`] identifying any null values +//! * An offsets [`OffsetBuffer`] identifying valid UTF-8 sequences within the values buffer +//! * A values [`Buffer`] of UTF-8 encoded string data +//! +//! Array constructors such as [`PrimitiveArray::try_new`] provide the ability to cheaply +//! construct an array from these parts, with functions such as [`PrimitiveArray::into_parts`] +//! providing the reverse operation. +//! +//! ``` +//! # use arrow_array::{Array, Int32Array, StringArray}; +//! # use arrow_buffer::OffsetBuffer; +//! # +//! // Create a Int32Array from Vec without copying +//! let array = Int32Array::new(vec![1, 2, 3].into(), None); +//! assert_eq!(array.values(), &[1, 2, 3]); +//! 
assert_eq!(array.null_count(), 0); +//! +//! // Create a StringArray from parts +//! let offsets = OffsetBuffer::new(vec![0, 5, 10].into()); +//! let array = StringArray::new(offsets, b"helloworld".into(), None); +//! let values: Vec<_> = array.iter().map(|x| x.unwrap()).collect(); +//! assert_eq!(values, &["hello", "world"]); +//! ``` +//! +//! As [`Buffer`], and its derivatives, can be created from [`Vec`] without copying, this provides +//! an efficient way to not only interoperate with other Rust code, but also implement kernels +//! optimised for the arrow data layout - e.g. by handling buffers instead of values. +//! +//! # Zero-Copy Slicing +//! +//! Given an [`Array`] of arbitrary length, it is possible to create an owned slice of this +//! data. Internally this just increments some ref-counts, and so is incredibly cheap +//! +//! ```rust +//! # use arrow_array::Int32Array; +//! let array = Int32Array::from_iter([1, 2, 3]); +//! +//! // Slice with offset 1 and length 2 +//! let sliced = array.slice(1, 2); +//! assert_eq!(sliced.values(), &[2, 3]); +//! ``` +//! +//! # Downcasting an Array +//! +//! Arrays are often passed around as a dynamically typed [`&dyn Array`] or [`ArrayRef`]. +//! For example, [`RecordBatch`](`crate::RecordBatch`) stores columns as [`ArrayRef`]. +//! +//! Whilst these arrays can be passed directly to the [`compute`], [`csv`], [`json`], etc... APIs, +//! it is often the case that you wish to interact with the concrete arrays directly. +//! +//! This requires downcasting to the concrete type of the array: +//! +//! ``` +//! # use arrow_array::{Array, Float32Array, Int32Array}; +//! +//! // Safely downcast an `Array` to an `Int32Array` and compute the sum +//! // using native i32 values +//! fn sum_int32(array: &dyn Array) -> i32 { +//! let integers: &Int32Array = array.as_any().downcast_ref().unwrap(); +//! integers.iter().map(|val| val.unwrap_or_default()).sum() +//! } +//! +//! 
// Safely downcasts the array to a `Float32Array` and returns a &[f32] view of the data +//! // Note: the values for positions corresponding to nulls will be arbitrary (but still valid f32) +//! fn as_f32_slice(array: &dyn Array) -> &[f32] { +//! array.as_any().downcast_ref::().unwrap().values() +//! } +//! ``` +//! +//! The [`cast::AsArray`] extension trait can make this more ergonomic +//! +//! ``` +//! # use arrow_array::Array; +//! # use arrow_array::cast::{AsArray, as_primitive_array}; +//! # use arrow_array::types::Float32Type; +//! +//! fn as_f32_slice(array: &dyn Array) -> &[f32] { +//! array.as_primitive::().values() +//! } +//! ``` +//! +//! [`ScalarBuffer`]: arrow_buffer::ScalarBuffer +//! [`ScalarBuffer`]: arrow_buffer::ScalarBuffer +//! [`OffsetBuffer`]: arrow_buffer::OffsetBuffer +//! [`NullBuffer`]: arrow_buffer::NullBuffer +//! [Arrow specification]: https://arrow.apache.org/docs/format/Columnar.html +//! [`&dyn Array`]: Array +//! [`NullBuffer`]: arrow_buffer::NullBuffer +//! [`Buffer`]: arrow_buffer::Buffer +//! [`compute`]: https://docs.rs/arrow/latest/arrow/compute/index.html +//! [`json`]: https://docs.rs/arrow/latest/arrow/json/index.html +//! 
[`csv`]: https://docs.rs/arrow/latest/arrow/csv/index.html + +#![deny(rustdoc::broken_intra_doc_links)] +#![warn(missing_docs)] + +pub mod array; +pub use array::*; + +mod record_batch; +pub use record_batch::{ + RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, RecordBatchWriter, +}; + +mod arithmetic; +pub use arithmetic::ArrowNativeTypeOp; + +mod numeric; +pub use numeric::*; + +mod scalar; +pub use scalar::*; + +pub mod builder; +pub mod cast; +mod delta; +#[cfg(feature = "ffi")] +pub mod ffi; +#[cfg(feature = "ffi")] +pub mod ffi_stream; +pub mod iterator; +pub mod run_iterator; +pub mod temporal_conversions; +pub mod timezone; +mod trusted_len; +pub mod types; + +#[cfg(test)] +mod tests { + use crate::builder::*; + + #[test] + fn test_buffer_builder_availability() { + let _builder = Int8BufferBuilder::new(10); + let _builder = Int16BufferBuilder::new(10); + let _builder = Int32BufferBuilder::new(10); + let _builder = Int64BufferBuilder::new(10); + let _builder = UInt16BufferBuilder::new(10); + let _builder = UInt32BufferBuilder::new(10); + let _builder = Float32BufferBuilder::new(10); + let _builder = Float64BufferBuilder::new(10); + let _builder = TimestampSecondBufferBuilder::new(10); + let _builder = TimestampMillisecondBufferBuilder::new(10); + let _builder = TimestampMicrosecondBufferBuilder::new(10); + let _builder = TimestampNanosecondBufferBuilder::new(10); + let _builder = Date32BufferBuilder::new(10); + let _builder = Date64BufferBuilder::new(10); + let _builder = Time32SecondBufferBuilder::new(10); + let _builder = Time32MillisecondBufferBuilder::new(10); + let _builder = Time64MicrosecondBufferBuilder::new(10); + let _builder = Time64NanosecondBufferBuilder::new(10); + let _builder = IntervalYearMonthBufferBuilder::new(10); + let _builder = IntervalDayTimeBufferBuilder::new(10); + let _builder = IntervalMonthDayNanoBufferBuilder::new(10); + let _builder = DurationSecondBufferBuilder::new(10); + let _builder = 
DurationMillisecondBufferBuilder::new(10); + let _builder = DurationMicrosecondBufferBuilder::new(10); + let _builder = DurationNanosecondBufferBuilder::new(10); + } +} diff --git a/arrow/src/csv/mod.rs b/arrow-array/src/numeric.rs similarity index 73% rename from arrow/src/csv/mod.rs rename to arrow-array/src/numeric.rs index ffe82f335801..a3cd7bde5d36 100644 --- a/arrow/src/csv/mod.rs +++ b/arrow-array/src/numeric.rs @@ -15,13 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! Transfer data between the Arrow memory format and CSV (comma-separated values). +use crate::ArrowPrimitiveType; -pub mod reader; -pub mod writer; +/// A subtype of primitive type that represents numeric values. +pub trait ArrowNumericType: ArrowPrimitiveType {} -pub use self::reader::infer_schema_from_files; -pub use self::reader::Reader; -pub use self::reader::ReaderBuilder; -pub use self::writer::Writer; -pub use self::writer::WriterBuilder; +impl ArrowNumericType for T {} diff --git a/arrow/src/record_batch.rs b/arrow-array/src/record_batch.rs similarity index 60% rename from arrow/src/record_batch.rs rename to arrow-array/src/record_batch.rs index 47257b496c1b..c56b1fd308cf 100644 --- a/arrow/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -16,17 +16,50 @@ // under the License. //! A two-dimensional batch of column-oriented data with a defined -//! [schema](crate::datatypes::Schema). +//! [schema](arrow_schema::Schema). +use crate::{new_empty_array, Array, ArrayRef, StructArray}; +use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef}; +use std::ops::Index; use std::sync::Arc; -use crate::array::*; -use crate::compute::kernels::concat::concat; -use crate::datatypes::*; -use crate::error::{ArrowError, Result}; +/// Trait for types that can read `RecordBatch`'s. +/// +/// To create from an iterator, see [RecordBatchIterator]. 
+pub trait RecordBatchReader: Iterator> { + /// Returns the schema of this `RecordBatchReader`. + /// + /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this + /// reader should have the same schema as returned from this method. + fn schema(&self) -> SchemaRef; + + /// Reads the next `RecordBatch`. + #[deprecated( + since = "2.0.0", + note = "This method is deprecated in favour of `next` from the trait Iterator." + )] + fn next_batch(&mut self) -> Result, ArrowError> { + self.next().transpose() + } +} + +impl RecordBatchReader for Box { + fn schema(&self) -> SchemaRef { + self.as_ref().schema() + } +} + +/// Trait for types that can write `RecordBatch`'s. +pub trait RecordBatchWriter { + /// Write a single batch to the writer. + fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>; + + /// Write footer or termination data, then mark the writer as done. + fn close(self) -> Result<(), ArrowError>; +} /// A two-dimensional batch of column-oriented data with a defined -/// [schema](crate::datatypes::Schema). +/// [schema](arrow_schema::Schema). /// /// A `RecordBatch` is a two-dimensional dataset of a number of /// contiguous arrays, each the same length. @@ -35,8 +68,6 @@ use crate::error::{ArrowError, Result}; /// /// Record batches are a convenient unit of work for various /// serialization and computation functions, possibly incremental. -/// See also [CSV reader](crate::csv::Reader) and -/// [JSON reader](crate::json::Reader). 
#[derive(Clone, Debug, PartialEq)] pub struct RecordBatch { schema: SchemaRef, @@ -62,12 +93,10 @@ impl RecordBatch { /// # Example /// /// ``` - /// use std::sync::Arc; - /// use arrow::array::Int32Array; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; /// - /// # fn main() -> arrow::error::Result<()> { /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false) @@ -76,12 +105,10 @@ impl RecordBatch { /// let batch = RecordBatch::try_new( /// Arc::new(schema), /// vec![Arc::new(id_array)] - /// )?; - /// # Ok(()) - /// # } + /// ).unwrap(); /// ``` - pub fn try_new(schema: SchemaRef, columns: Vec) -> Result { - let options = RecordBatchOptions::default(); + pub fn try_new(schema: SchemaRef, columns: Vec) -> Result { + let options = RecordBatchOptions::new(); Self::try_new_impl(schema, columns, &options) } @@ -93,7 +120,7 @@ impl RecordBatch { schema: SchemaRef, columns: Vec, options: &RecordBatchOptions, - ) -> Result { + ) -> Result { Self::try_new_impl(schema, columns, options) } @@ -118,7 +145,7 @@ impl RecordBatch { schema: SchemaRef, columns: Vec, options: &RecordBatchOptions, - ) -> Result { + ) -> Result { // check that number of fields in schema match column length if schema.fields().len() != columns.len() { return Err(ArrowError::InvalidArgumentError(format!( @@ -128,7 +155,6 @@ impl RecordBatch { ))); } - // check that all columns have the same row count let row_count = options .row_count .or_else(|| columns.first().map(|col| col.len())) @@ -147,11 +173,10 @@ impl RecordBatch { } } + // check that all columns have the same row count if columns.iter().any(|c| c.len() != row_count) { let err = match options.row_count { - Some(_) => { - "all columns in a record batch must have the specified row 
count" - } + Some(_) => "all columns in a record batch must have the specified row count", None => "all columns in a record batch must have the same length", }; return Err(ArrowError::InvalidArgumentError(err.to_string())); @@ -160,9 +185,7 @@ impl RecordBatch { // function for comparing column type and field type // return true if 2 types are not matched let type_not_match = if options.match_field_names { - |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { - col_type != field_type - } + |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type } else { |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { !col_type.equals_datatype(field_type) @@ -179,10 +202,7 @@ impl RecordBatch { if let Some((i, (col_type, field_type))) = not_match { return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {:?} but found {:?} at column index {}", - field_type, - col_type, - i))); + "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}"))); } Ok(RecordBatch { @@ -192,13 +212,37 @@ impl RecordBatch { }) } - /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch. + /// Override the schema of this [`RecordBatch`] + /// + /// Returns an error if `schema` is not a superset of the current schema + /// as determined by [`Schema::contains`] + pub fn with_schema(self, schema: SchemaRef) -> Result { + if !schema.contains(self.schema.as_ref()) { + return Err(ArrowError::SchemaError(format!( + "target schema is not superset of current schema target={schema} current={}", + self.schema + ))); + } + + Ok(Self { + schema, + columns: self.columns, + row_count: self.row_count, + }) + } + + /// Returns the [`Schema`] of the record batch. pub fn schema(&self) -> SchemaRef { self.schema.clone() } + /// Returns a reference to the [`Schema`] of the record batch. 
+ pub fn schema_ref(&self) -> &SchemaRef { + &self.schema + } + /// Projects the schema onto the specified columns - pub fn project(&self, indices: &[usize]) -> Result { + pub fn project(&self, indices: &[usize]) -> Result { let projected_schema = self.schema.project(indices)?; let batch_fields = indices .iter() @@ -211,9 +255,16 @@ impl RecordBatch { )) }) }) - .collect::>>()?; - - RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields) + .collect::, _>>()?; + + RecordBatch::try_new_with_options( + SchemaRef::new(projected_schema), + batch_fields, + &RecordBatchOptions { + match_field_names: true, + row_count: Some(self.row_count), + }, + ) } /// Returns the number of columns in the record batch. @@ -221,22 +272,18 @@ impl RecordBatch { /// # Example /// /// ``` - /// use std::sync::Arc; - /// use arrow::array::Int32Array; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; /// - /// # fn main() -> arrow::error::Result<()> { /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false) /// ]); /// - /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)])?; + /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap(); /// /// assert_eq!(batch.num_columns(), 1); - /// # Ok(()) - /// # } /// ``` pub fn num_columns(&self) -> usize { self.columns.len() @@ -247,22 +294,18 @@ impl RecordBatch { /// # Example /// /// ``` - /// use std::sync::Arc; - /// use arrow::array::Int32Array; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; /// - /// # fn main() -> 
arrow::error::Result<()> { /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false) /// ]); /// - /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)])?; + /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap(); /// /// assert_eq!(batch.num_rows(), 5); - /// # Ok(()) - /// # } /// ``` pub fn num_rows(&self) -> usize { self.row_count @@ -277,11 +320,52 @@ impl RecordBatch { &self.columns[index] } + /// Get a reference to a column's array by name. + pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> { + self.schema() + .column_with_name(name) + .map(|(index, _)| &self.columns[index]) + } + /// Get a reference to all columns in the record batch. pub fn columns(&self) -> &[ArrayRef] { &self.columns[..] } + /// Remove column by index and return it. + /// + /// Return the `ArrayRef` if the column is removed. + /// + /// # Panics + /// + /// Panics if `index`` out of bounds. 
+ /// + /// # Example + /// + /// ``` + /// use std::sync::Arc; + /// use arrow_array::{BooleanArray, Int32Array, RecordBatch}; + /// use arrow_schema::{DataType, Field, Schema}; + /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]); + /// let schema = Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// Field::new("bool", DataType::Boolean, false), + /// ]); + /// + /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap(); + /// + /// let removed_column = batch.remove_column(0); + /// assert_eq!(removed_column.as_any().downcast_ref::().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5])); + /// assert_eq!(batch.num_columns(), 1); + /// ``` + pub fn remove_column(&mut self, index: usize) -> ArrayRef { + let mut builder = SchemaBuilder::from(self.schema.as_ref()); + builder.remove(index); + self.schema = Arc::new(builder.finish()); + self.columns.remove(index) + } + /// Return a new RecordBatch where each column is sliced /// according to `offset` and `length` /// @@ -316,10 +400,8 @@ impl RecordBatch { /// /// Example: /// ``` - /// use std::sync::Arc; - /// use arrow::array::{ArrayRef, Int32Array, StringArray}; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; /// /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); @@ -329,7 +411,7 @@ impl RecordBatch { /// ("b", b), /// ]); /// ``` - pub fn try_from_iter(value: I) -> Result + pub fn try_from_iter(value: I) -> Result where I: IntoIterator, F: AsRef, @@ -353,10 +435,8 @@ impl RecordBatch { /// /// Example: /// ``` - /// use std::sync::Arc; - /// use arrow::array::{ArrayRef, Int32Array, StringArray}; - /// use 
arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; /// /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")])); @@ -368,54 +448,32 @@ impl RecordBatch { /// ("b", b, true), /// ]); /// ``` - pub fn try_from_iter_with_nullable(value: I) -> Result + pub fn try_from_iter_with_nullable(value: I) -> Result where I: IntoIterator, F: AsRef, { - // TODO: implement `TryFrom` trait, once - // https://github.com/rust-lang/rust/issues/50133 is no longer an - // issue - let (fields, columns) = value - .into_iter() - .map(|(field_name, array, nullable)| { - let field_name = field_name.as_ref(); - let field = Field::new(field_name, array.data_type().clone(), nullable); - (field, array) - }) - .unzip(); + let iter = value.into_iter(); + let capacity = iter.size_hint().0; + let mut schema = SchemaBuilder::with_capacity(capacity); + let mut columns = Vec::with_capacity(capacity); + + for (field_name, array, nullable) in iter { + let field_name = field_name.as_ref(); + schema.push(Field::new(field_name, array.data_type().clone(), nullable)); + columns.push(array); + } - let schema = Arc::new(Schema::new(fields)); + let schema = Arc::new(schema.finish()); RecordBatch::try_new(schema, columns) } - /// Concatenates `batches` together into a single record batch. - pub fn concat(schema: &SchemaRef, batches: &[Self]) -> Result { - if batches.is_empty() { - return Ok(RecordBatch::new_empty(schema.clone())); - } - if let Some((i, _)) = batches + /// Returns the total number of bytes of memory occupied physically by this batch. 
+ pub fn get_array_memory_size(&self) -> usize { + self.columns() .iter() - .enumerate() - .find(|&(_, batch)| batch.schema() != *schema) - { - return Err(ArrowError::InvalidArgumentError(format!( - "batches[{}] schema is different with argument schema.", - i - ))); - } - let field_num = schema.fields().len(); - let mut arrays = Vec::with_capacity(field_num); - for i in 0..field_num { - let array = concat( - &batches - .iter() - .map(|batch| batch.column(i).as_ref()) - .collect::>(), - )?; - arrays.push(array); - } - Self::try_new(schema.clone(), arrays) + .map(|array| array.get_array_memory_size()) + .sum() } } @@ -430,71 +488,150 @@ pub struct RecordBatchOptions { pub row_count: Option, } -impl Default for RecordBatchOptions { - fn default() -> Self { +impl RecordBatchOptions { + /// Creates a new `RecordBatchOptions` + pub fn new() -> Self { Self { match_field_names: true, row_count: None, } } + /// Sets the row_count of RecordBatchOptions and returns self + pub fn with_row_count(mut self, row_count: Option) -> Self { + self.row_count = row_count; + self + } + /// Sets the match_field_names of RecordBatchOptions and returns self + pub fn with_match_field_names(mut self, match_field_names: bool) -> Self { + self.match_field_names = match_field_names; + self + } +} +impl Default for RecordBatchOptions { + fn default() -> Self { + Self::new() + } +} +impl From for RecordBatch { + fn from(value: StructArray) -> Self { + let row_count = value.len(); + let (fields, columns, nulls) = value.into_parts(); + assert_eq!( + nulls.map(|n| n.null_count()).unwrap_or_default(), + 0, + "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" + ); + + RecordBatch { + schema: Arc::new(Schema::new(fields)), + row_count, + columns, + } + } } impl From<&StructArray> for RecordBatch { - /// Create a record batch from struct array, where each field of - /// the `StructArray` becomes a `Field` in the schema. 
- /// - /// This currently does not flatten and nested struct types fn from(struct_array: &StructArray) -> Self { - if let DataType::Struct(fields) = struct_array.data_type() { - let schema = Schema::new(fields.clone()); - let columns = struct_array.boxed_fields.clone(); - RecordBatch { - schema: Arc::new(schema), - row_count: struct_array.len(), - columns, - } - } else { - unreachable!("unable to get datatype as struct") - } + struct_array.clone().into() } } -impl From for StructArray { - fn from(batch: RecordBatch) -> Self { - batch - .schema - .fields - .iter() - .zip(batch.columns.iter()) - .map(|t| (t.0.clone(), t.1.clone())) - .collect::>() - .into() +impl Index<&str> for RecordBatch { + type Output = ArrayRef; + + /// Get a reference to a column's array by name. + /// + /// # Panics + /// + /// Panics if the name is not in the schema. + fn index(&self, name: &str) -> &Self::Output { + self.column_by_name(name).unwrap() } } -/// Trait for types that can read `RecordBatch`'s. -pub trait RecordBatchReader: Iterator> { - /// Returns the schema of this `RecordBatchReader`. +/// Generic implementation of [RecordBatchReader] that wraps an iterator. 
+/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader}; +/// # +/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); +/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); +/// +/// let record_batch = RecordBatch::try_from_iter(vec![ +/// ("a", a), +/// ("b", b), +/// ]).unwrap(); +/// +/// let batches: Vec = vec![record_batch.clone(), record_batch.clone()]; +/// +/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema()); +/// +/// assert_eq!(reader.schema(), record_batch.schema()); +/// assert_eq!(reader.next().unwrap().unwrap(), record_batch); +/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch); +/// # assert!(reader.next().is_none()); +/// ``` +pub struct RecordBatchIterator +where + I: IntoIterator>, +{ + inner: I::IntoIter, + inner_schema: SchemaRef, +} + +impl RecordBatchIterator +where + I: IntoIterator>, +{ + /// Create a new [RecordBatchIterator]. /// - /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this - /// reader should have the same schema as returned from this method. - fn schema(&self) -> SchemaRef; + /// If `iter` is an infallible iterator, use `.map(Ok)`. + pub fn new(iter: I, schema: SchemaRef) -> Self { + Self { + inner: iter.into_iter(), + inner_schema: schema, + } + } +} - /// Reads the next `RecordBatch`. - #[deprecated( - since = "2.0.0", - note = "This method is deprecated in favour of `next` from the trait Iterator." 
- )] - fn next_batch(&mut self) -> Result> { - self.next().transpose() +impl Iterator for RecordBatchIterator +where + I: IntoIterator>, +{ + type Item = I::Item; + + fn next(&mut self) -> Option { + self.inner.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl RecordBatchReader for RecordBatchIterator +where + I: IntoIterator>, +{ + fn schema(&self) -> SchemaRef { + self.inner_schema.clone() } } #[cfg(test)] mod tests { - use super::*; + use std::collections::HashMap; - use crate::buffer::Buffer; + use super::*; + use crate::{ + BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray, + }; + use arrow_buffer::{Buffer, ToByteSlice}; + use arrow_data::{ArrayData, ArrayDataBuilder}; + use arrow_schema::Fields; #[test] fn create_record_batch() { @@ -507,18 +644,56 @@ mod tests { let b = StringArray::from(vec!["a", "b", "c", "d", "e"]); let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); check_batch(record_batch, 5) } + #[test] + fn create_string_view_record_batch() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8View, false), + ]); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]); + + let record_batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); + + assert_eq!(5, record_batch.num_rows()); + assert_eq!(2, record_batch.num_columns()); + assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type()); + assert_eq!( + &DataType::Utf8View, + record_batch.schema().field(1).data_type() + ); + assert_eq!(5, record_batch.column(0).len()); + assert_eq!(5, record_batch.column(1).len()); + } + + #[test] + fn byte_size_should_not_regress() { + let schema = Schema::new(vec![ + Field::new("a", 
DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = StringArray::from(vec!["a", "b", "c", "d", "e"]); + + let record_batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); + assert_eq!(record_batch.get_array_memory_size(), 364); + } + fn check_batch(record_batch: RecordBatch, num_rows: usize) { assert_eq!(num_rows, record_batch.num_rows()); assert_eq!(2, record_batch.num_columns()); assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type()); assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type()); - assert_eq!(num_rows, record_batch.column(0).data().len()); - assert_eq!(num_rows, record_batch.column(1).data().len()); + assert_eq!(num_rows, record_batch.column(0).len()); + assert_eq!(num_rows, record_batch.column(1).len()); } #[test] @@ -534,8 +709,7 @@ mod tests { let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]); let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); let offset = 2; let length = 5; @@ -559,7 +733,7 @@ mod tests { #[test] #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")] fn create_record_batch_slice_empty_batch() { - let schema = Schema::new(vec![]); + let schema = Schema::empty(); let record_batch = RecordBatch::new_empty(Arc::new(schema)); @@ -584,8 +758,8 @@ mod tests { ])); let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); - let record_batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion"); let expected_schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), @@ -601,11 +775,9 @@ mod tests { let b: ArrayRef = 
Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); // Note there are no nulls in a or b, but we specify that b is nullable - let record_batch = RecordBatch::try_from_iter_with_nullable(vec![ - ("a", a, false), - ("b", b, true), - ]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)]) + .expect("valid conversion"); let expected_schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), @@ -627,34 +799,29 @@ mod tests { #[test] fn create_record_batch_field_name_mismatch() { - let struct_fields = vec![ + let fields = vec![ Field::new("a1", DataType::Int32, false), - Field::new( - "a2", - DataType::List(Box::new(Field::new("item", DataType::Int8, false))), - false, - ), + Field::new_list("a2", Field::new("item", DataType::Int8, false), false), ]; - let struct_type = DataType::Struct(struct_fields); - let schema = Arc::new(Schema::new(vec![Field::new("a", struct_type, true)])); + let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)])); let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); let a2_child = Int8Array::from(vec![1, 2, 3, 4]); - let a2 = ArrayDataBuilder::new(DataType::List(Box::new(Field::new( + let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new( "array", DataType::Int8, false, )))) .add_child_data(a2_child.into_data()) .len(2) - .add_buffer(Buffer::from(vec![0i32, 3, 4].to_byte_slice())) + .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice())) .build() .unwrap(); let a2: ArrayRef = Arc::new(ListArray::from(a2)); - let a = ArrayDataBuilder::new(DataType::Struct(vec![ + let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![ Field::new("aa1", DataType::Int32, false), Field::new("a2", a2.data_type().clone(), false), - ])) + ]))) .add_child_data(a1.into_data()) .add_child_data(a2.into_data()) .len(2) @@ -682,8 +849,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let b = Int32Array::from(vec![1, 
2, 3, 4, 5]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); assert!(batch.is_err()); } @@ -693,11 +859,11 @@ mod tests { let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); let struct_array = StructArray::from(vec![ ( - Field::new("b", DataType::Boolean, false), + Arc::new(Field::new("b", DataType::Boolean, false)), boolean.clone() as ArrayRef, ), ( - Field::new("c", DataType::Int32, false), + Arc::new(Field::new("c", DataType::Int32, false)), int.clone() as ArrayRef, ), ]); @@ -707,84 +873,12 @@ mod tests { assert_eq!(4, batch.num_rows()); assert_eq!( struct_array.data_type(), - &DataType::Struct(batch.schema().fields().to_vec()) + &DataType::Struct(batch.schema().fields().clone()) ); assert_eq!(batch.column(0).as_ref(), boolean.as_ref()); assert_eq!(batch.column(1).as_ref(), int.as_ref()); } - #[test] - fn concat_record_batches() { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2])), - Arc::new(StringArray::from(vec!["a", "b"])), - ], - ) - .unwrap(); - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![3, 4])), - Arc::new(StringArray::from(vec!["c", "d"])), - ], - ) - .unwrap(); - let new_batch = RecordBatch::concat(&schema, &[batch1, batch2]).unwrap(); - assert_eq!(new_batch.schema().as_ref(), schema.as_ref()); - assert_eq!(2, new_batch.num_columns()); - assert_eq!(4, new_batch.num_rows()); - } - - #[test] - fn concat_empty_record_batch() { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let batch = RecordBatch::concat(&schema, &[]).unwrap(); - assert_eq!(batch.schema().as_ref(), schema.as_ref()); - assert_eq!(0, 
batch.num_rows()); - } - - #[test] - fn concat_record_batches_of_different_schemas() { - let schema1 = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let schema2 = Arc::new(Schema::new(vec![ - Field::new("c", DataType::Int32, false), - Field::new("d", DataType::Utf8, false), - ])); - let batch1 = RecordBatch::try_new( - schema1.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2])), - Arc::new(StringArray::from(vec!["a", "b"])), - ], - ) - .unwrap(); - let batch2 = RecordBatch::try_new( - schema2, - vec![ - Arc::new(Int32Array::from(vec![3, 4])), - Arc::new(StringArray::from(vec!["c", "d"])), - ], - ) - .unwrap(); - let error = RecordBatch::concat(&schema1, &[batch1, batch2]).unwrap_err(); - assert_eq!( - error.to_string(), - "Invalid argument error: batches[1] schema is different with argument schema.", - ); - } - #[test] fn record_batch_equality() { let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]); @@ -816,6 +910,22 @@ mod tests { assert_eq!(batch1, batch2); } + /// validates if the record batch can be accessed using `column_name` as index i.e. 
`record_batch["column_name"]` + #[test] + fn record_batch_index_access() { + let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8])); + let schema1 = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Int32, false), + ]); + let record_batch = + RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap(); + + assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref()); + assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref()); + } + #[test] fn record_batch_vals_ne() { let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]); @@ -948,35 +1058,48 @@ mod tests { let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); - let record_batch = RecordBatch::try_from_iter(vec![ - ("a", a.clone()), - ("b", b.clone()), - ("c", c.clone()), - ]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())]) + .expect("valid conversion"); - let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)]) - .expect("valid conversion"); + let expected = + RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion"); assert_eq!(expected, record_batch.project(&[0, 2]).unwrap()); } + #[test] + fn project_empty() { + let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); + + let record_batch = + RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion"); + + let expected = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions { + match_field_names: true, + row_count: Some(3), + }, + ) + .expect("valid conversion"); + + assert_eq!(expected, record_batch.project(&[]).unwrap()); + } + #[test] fn test_no_column_record_batch() { - let schema = Arc::new(Schema::new(vec![])); + let schema = Arc::new(Schema::empty()); let 
err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err(); assert!(err .to_string() .contains("must either specify a row count or at least one column")); - let options = RecordBatchOptions { - row_count: Some(10), - ..Default::default() - }; + let options = RecordBatchOptions::new().with_row_count(Some(10)); - let ok = - RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); assert_eq!(ok.num_rows(), 10); let a = ok.slice(2, 5); @@ -998,4 +1121,98 @@ mod tests { ); assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap())); } + #[test] + fn test_record_batch_options() { + let options = RecordBatchOptions::new() + .with_match_field_names(false) + .with_row_count(Some(20)); + assert!(!options.match_field_names); + assert_eq!(options.row_count.unwrap(), 20) + } + + #[test] + #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")] + fn test_from_struct() { + let s = StructArray::from(ArrayData::new_null( + // Note child is not nullable + &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()), + 2, + )); + let _ = RecordBatch::from(s); + } + + #[test] + fn test_with_schema() { + let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let required_schema = Arc::new(required_schema); + let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let nullable_schema = Arc::new(nullable_schema); + + let batch = RecordBatch::try_new( + required_schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _], + ) + .unwrap(); + + // Can add nullability + let batch = batch.with_schema(nullable_schema.clone()).unwrap(); + + // Cannot remove nullability + batch.clone().with_schema(required_schema).unwrap_err(); + + // Can add metadata + let metadata = vec![("foo".to_string(), 
"bar".to_string())] + .into_iter() + .collect(); + let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata); + let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap(); + + // Cannot remove metadata + batch.with_schema(nullable_schema).unwrap_err(); + } + + #[test] + fn test_boxed_reader() { + // Make sure we can pass a boxed reader to a function generic over + // RecordBatchReader. + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Arc::new(schema); + + let reader = RecordBatchIterator::new(std::iter::empty(), schema); + let reader: Box = Box::new(reader); + + fn get_size(reader: impl RecordBatchReader) -> usize { + reader.size_hint().0 + } + + let size = get_size(reader); + assert_eq!(size, 0); + } + + #[test] + fn test_remove_column_maintains_schema_metadata() { + let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); + let bool_array = BooleanArray::from(vec![true, false, false, true, true]); + + let mut metadata = HashMap::new(); + metadata.insert("foo".to_string(), "bar".to_string()); + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("bool", DataType::Boolean, false), + ]) + .with_metadata(metadata); + + let mut batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(id_array), Arc::new(bool_array)], + ) + .unwrap(); + + let _removed_column = batch.remove_column(0); + assert_eq!(batch.schema().metadata().len(), 1); + assert_eq!( + batch.schema().metadata().get("foo").unwrap().as_str(), + "bar" + ); + } } diff --git a/arrow-array/src/run_iterator.rs b/arrow-array/src/run_iterator.rs new file mode 100644 index 000000000000..2922bf04dd2f --- /dev/null +++ b/arrow-array/src/run_iterator.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Idiomatic iterator for [`RunArray`](crate::RunArray) + +use crate::{array::ArrayAccessor, types::RunEndIndexType, Array, TypedRunArray}; +use arrow_buffer::ArrowNativeType; + +/// The [`RunArrayIter`] provides an idiomatic way to iterate over the run array. +/// It returns Some(T) if there is a value or None if the value is null. +/// +/// The iterator comes with a cost as it has to iterate over three arrays to determine +/// the value to be returned. The run_ends array is used to determine the index of the value. +/// The nulls array is used to determine if the value is null and the values array is used to +/// get the value. +/// +/// Unlike other iterators in this crate, [`RunArrayIter`] does not use [`ArrayAccessor`] +/// because the run array accessor does binary search to access each value which is too slow. +/// The run array iterator can determine the next value in constant time. 
+/// +#[derive(Debug)] +pub struct RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + array: TypedRunArray<'a, R, V>, + current_front_logical: usize, + current_front_physical: usize, + current_back_logical: usize, + current_back_physical: usize, +} + +impl<'a, R, V> RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + /// create a new iterator + pub fn new(array: TypedRunArray<'a, R, V>) -> Self { + let current_front_physical = array.run_array().get_start_physical_index(); + let current_back_physical = array.run_array().get_end_physical_index() + 1; + RunArrayIter { + array, + current_front_logical: array.offset(), + current_front_physical, + current_back_logical: array.offset() + array.len(), + current_back_physical, + } + } +} + +impl<'a, R, V> Iterator for RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = Option<<&'a V as ArrayAccessor>::Item>; + + #[inline] + fn next(&mut self) -> Option { + if self.current_front_logical == self.current_back_logical { + return None; + } + + // If current logical index is greater than current run end index then increment + // the physical index. + let run_ends = self.array.run_ends().values(); + if self.current_front_logical >= run_ends[self.current_front_physical].as_usize() { + // As the run_ends is expected to be strictly increasing, there + // should be at least one logical entry in one physical entry. Because of this + // reason the next value can be accessed by incrementing physical index once. 
+ self.current_front_physical += 1; + } + if self.array.values().is_null(self.current_front_physical) { + self.current_front_logical += 1; + Some(None) + } else { + self.current_front_logical += 1; + // Safety: + // The self.current_front_physical is kept within bounds of self.current_front_logical. + // The self.current_front_logical will not go out of bounds because of the check + // `self.current_front_logical == self.current_back_logical` above. + unsafe { + Some(Some( + self.array + .values() + .value_unchecked(self.current_front_physical), + )) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.current_back_logical - self.current_front_logical, + Some(self.current_back_logical - self.current_front_logical), + ) + } +} + +impl<'a, R, V> DoubleEndedIterator for RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + fn next_back(&mut self) -> Option { + if self.current_back_logical == self.current_front_logical { + return None; + } + + self.current_back_logical -= 1; + + let run_ends = self.array.run_ends().values(); + if self.current_back_physical > 0 + && self.current_back_logical < run_ends[self.current_back_physical - 1].as_usize() + { + // As the run_ends is expected to be strictly increasing, there + // should be at least one logical entry in one physical entry. Because of this + // reason the next value can be accessed by decrementing physical index once. + self.current_back_physical -= 1; + } + Some(if self.array.values().is_null(self.current_back_physical) { + None + } else { + // Safety: + // The check `self.current_back_physical > 0` ensures the value will not underflow. + // Also self.current_back_physical starts at `get_end_physical_index() + 1` and + // decrements based on the bounds of self.current_back_logical. + unsafe { + Some( + self.array + .values() + .value_unchecked(self.current_back_physical), + ) + } + }) + } +} + +/// all arrays have known size.
+impl<'a, R, V> ExactSizeIterator for RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ +} + +#[cfg(test)] +mod tests { + use rand::{seq::SliceRandom, thread_rng, Rng}; + + use crate::{ + array::{Int32Array, StringArray}, + builder::PrimitiveRunBuilder, + types::{Int16Type, Int32Type}, + Array, Int64RunArray, PrimitiveArray, RunArray, + }; + + fn build_input_array(size: usize) -> Vec> { + // The input array is created by shuffling and repeating + // the seed values random number of times. + let mut seed: Vec> = vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ]; + let mut result: Vec> = Vec::with_capacity(size); + let mut ix = 0; + let mut rng = thread_rng(); + // run length can go up to 8. Cap the max run length for smaller arrays to size / 2. + let max_run_length = 8_usize.min(1_usize.max(size / 2)); + while result.len() < size { + // shuffle the seed array if all the values are iterated. + if ix == 0 { + seed.shuffle(&mut rng); + } + // repeat the items between 1 and 8 times. 
Cap the length for smaller sized arrays + let num = max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); + for _ in 0..num { + result.push(seed[ix]); + } + ix += 1; + if ix == seed.len() { + ix = 0 + } + } + result.resize(size, None); + result + } + + #[test] + fn test_primitive_array_iter_round_trip() { + let mut input_vec = vec![ + Some(32), + Some(32), + None, + Some(64), + Some(64), + Some(64), + Some(72), + ]; + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend(input_vec.iter().copied()); + let ree_array = builder.finish(); + let ree_array = ree_array.downcast::().unwrap(); + + let output_vec: Vec> = ree_array.into_iter().collect(); + assert_eq!(input_vec, output_vec); + + let rev_output_vec: Vec> = ree_array.into_iter().rev().collect(); + input_vec.reverse(); + assert_eq!(input_vec, rev_output_vec); + } + + #[test] + fn test_double_ended() { + let input_vec = vec![ + Some(32), + Some(32), + None, + Some(64), + Some(64), + Some(64), + Some(72), + ]; + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend(input_vec); + let ree_array = builder.finish(); + let ree_array = ree_array.downcast::().unwrap(); + + let mut iter = ree_array.into_iter(); + assert_eq!(Some(Some(32)), iter.next()); + assert_eq!(Some(Some(72)), iter.next_back()); + assert_eq!(Some(Some(32)), iter.next()); + assert_eq!(Some(Some(64)), iter.next_back()); + assert_eq!(Some(None), iter.next()); + assert_eq!(Some(Some(64)), iter.next_back()); + assert_eq!(Some(Some(64)), iter.next()); + assert_eq!(None, iter.next_back()); + assert_eq!(None, iter.next()); + } + + #[test] + fn test_run_iterator_comprehensive() { + // Test forward and backward iterator for different array lengths. 
+ let logical_lengths = vec![1_usize, 2, 3, 4, 15, 16, 17, 63, 64, 65]; + + for logical_len in logical_lengths { + let input_array = build_input_array(logical_len); + + let mut run_array_builder = PrimitiveRunBuilder::::new(); + run_array_builder.extend(input_array.iter().copied()); + let run_array = run_array_builder.finish(); + let typed_array = run_array.downcast::().unwrap(); + + // test forward iterator + let mut input_iter = input_array.iter().copied(); + let mut run_array_iter = typed_array.into_iter(); + for _ in 0..logical_len { + assert_eq!(input_iter.next(), run_array_iter.next()); + } + assert_eq!(None, run_array_iter.next()); + + // test reverse iterator + let mut input_iter = input_array.iter().rev().copied(); + let mut run_array_iter = typed_array.into_iter().rev(); + for _ in 0..logical_len { + assert_eq!(input_iter.next(), run_array_iter.next()); + } + assert_eq!(None, run_array_iter.next()); + } + } + + #[test] + fn test_string_array_iter_round_trip() { + let input_vec = vec!["ab", "ab", "ba", "cc", "cc"]; + let input_ree_array: Int64RunArray = input_vec.into_iter().collect(); + let string_ree_array = input_ree_array.downcast::().unwrap(); + + // to and from iter, with a +1 + let result: Vec> = string_ree_array + .into_iter() + .map(|e| { + e.map(|e| { + let mut a = e.to_string(); + a.push('b'); + a + }) + }) + .collect(); + + let result_asref: Vec> = result.iter().map(|f| f.as_deref()).collect(); + + let expected_vec = vec![ + Some("abb"), + Some("abb"), + Some("bab"), + Some("ccb"), + Some("ccb"), + ]; + + assert_eq!(expected_vec, result_asref); + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_sliced_run_array_iterator() { + let total_len = 80; + let input_array = build_input_array(total_len); + + // Encode the input_array to run array + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array.iter().copied()); + let run_array = builder.finish(); + + // test for all slice 
lengths. + for slice_len in 1..=total_len { + // test for offset = 0, slice length = slice_len + let sliced_run_array: RunArray = + run_array.slice(0, slice_len).into_data().into(); + let sliced_typed_run_array = sliced_run_array + .downcast::>() + .unwrap(); + + // Iterate on sliced typed run array + let actual: Vec> = sliced_typed_run_array.into_iter().collect(); + let expected: Vec> = input_array.iter().take(slice_len).copied().collect(); + assert_eq!(expected, actual); + + // test for offset = total_len - slice_len, length = slice_len + let sliced_run_array: RunArray = run_array + .slice(total_len - slice_len, slice_len) + .into_data() + .into(); + let sliced_typed_run_array = sliced_run_array + .downcast::>() + .unwrap(); + + // Iterate on sliced typed run array + let actual: Vec> = sliced_typed_run_array.into_iter().collect(); + let expected: Vec> = input_array + .iter() + .skip(total_len - slice_len) + .copied() + .collect(); + assert_eq!(expected, actual); + } + } +} diff --git a/arrow-array/src/scalar.rs b/arrow-array/src/scalar.rs new file mode 100644 index 000000000000..4bbf668c4a98 --- /dev/null +++ b/arrow-array/src/scalar.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::Array; + +/// A possibly [`Scalar`] [`Array`] +/// +/// This allows optimised binary kernels where one or more arguments are constant +/// +/// ``` +/// # use arrow_array::*; +/// # use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; +/// # use arrow_schema::ArrowError; +/// # +/// fn eq_impl( +/// a: &PrimitiveArray, +/// a_scalar: bool, +/// b: &PrimitiveArray, +/// b_scalar: bool, +/// ) -> BooleanArray { +/// let (array, scalar) = match (a_scalar, b_scalar) { +/// (true, true) | (false, false) => { +/// let len = a.len().min(b.len()); +/// let nulls = NullBuffer::union(a.nulls(), b.nulls()); +/// let buffer = BooleanBuffer::collect_bool(len, |idx| a.value(idx) == b.value(idx)); +/// return BooleanArray::new(buffer, nulls); +/// } +/// (true, false) => (b, (a.null_count() == 0).then(|| a.value(0))), +/// (false, true) => (a, (b.null_count() == 0).then(|| b.value(0))), +/// }; +/// match scalar { +/// Some(v) => { +/// let len = array.len(); +/// let nulls = array.nulls().cloned(); +/// let buffer = BooleanBuffer::collect_bool(len, |idx| array.value(idx) == v); +/// BooleanArray::new(buffer, nulls) +/// } +/// None => BooleanArray::new_null(array.len()), +/// } +/// } +/// +/// pub fn eq(l: &dyn Datum, r: &dyn Datum) -> Result { +/// let (l_array, l_scalar) = l.get(); +/// let (r_array, r_scalar) = r.get(); +/// downcast_primitive_array!( +/// (l_array, r_array) => Ok(eq_impl(l_array, l_scalar, r_array, r_scalar)), +/// (a, b) => Err(ArrowError::NotYetImplemented(format!("{a} == {b}"))), +/// ) +/// } +/// +/// // Comparison of two arrays +/// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); +/// let b = Int32Array::from(vec![1, 2, 4, 7, 3]); +/// let r = eq(&a, &b).unwrap(); +/// let values: Vec<_> = r.values().iter().collect(); +/// assert_eq!(values, &[true, true, false, false, false]); +/// +/// // Comparison of an array and a scalar +/// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); +/// let b = Int32Array::new_scalar(1); +/// let r = 
eq(&a, &b).unwrap(); +/// let values: Vec<_> = r.values().iter().collect(); +/// assert_eq!(values, &[true, false, false, false, false]); +pub trait Datum { + /// Returns the value for this [`Datum`] and a boolean indicating if the value is scalar + fn get(&self) -> (&dyn Array, bool); +} + +impl Datum for T { + fn get(&self) -> (&dyn Array, bool) { + (self, false) + } +} + +impl Datum for dyn Array { + fn get(&self) -> (&dyn Array, bool) { + (self, false) + } +} + +impl Datum for &dyn Array { + fn get(&self) -> (&dyn Array, bool) { + (*self, false) + } +} + +/// A wrapper around a single value [`Array`] that implements +/// [`Datum`] and indicates [compute] kernels should treat this array +/// as a scalar value (a single value). +/// +/// Using a [`Scalar`] is often much more efficient than creating an +/// [`Array`] with the same (repeated) value. +/// +/// See [`Datum`] for more information. +/// +/// # Example +/// +/// ```rust +/// # use arrow_array::{Scalar, Int32Array, ArrayRef}; +/// # fn get_array() -> ArrayRef { std::sync::Arc::new(Int32Array::from(vec![42])) } +/// // Create a (typed) scalar for Int32Array for the value 42 +/// let scalar = Scalar::new(Int32Array::from(vec![42])); +/// +/// // Create a scalar using PrimitiveArray::new_scalar +/// let scalar = Int32Array::new_scalar(42); +/// +/// // create a scalar from an ArrayRef (for dynamic typed Arrays) +/// let array: ArrayRef = get_array(); +/// let scalar = Scalar::new(array); +/// ``` +/// +/// [compute]: https://docs.rs/arrow/latest/arrow/compute/index.html +#[derive(Debug, Copy, Clone)] +pub struct Scalar(T); + +impl Scalar { + /// Create a new [`Scalar`] from an [`Array`] + /// + /// # Panics + /// + /// Panics if `array.len() != 1` + pub fn new(array: T) -> Self { + assert_eq!(array.len(), 1); + Self(array) + } + + /// Returns the inner array + #[inline] + pub fn into_inner(self) -> T { + self.0 + } +} + +impl Datum for Scalar { + fn get(&self) -> (&dyn Array, bool) { + (&self.0, true) + } +}
diff --git a/arrow-array/src/temporal_conversions.rs b/arrow-array/src/temporal_conversions.rs new file mode 100644 index 000000000000..8d238b3a196c --- /dev/null +++ b/arrow-array/src/temporal_conversions.rs @@ -0,0 +1,351 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Conversion methods for dates and times. 
+ +use crate::timezone::Tz; +use crate::ArrowPrimitiveType; +use arrow_schema::{DataType, TimeUnit}; +use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc}; + +/// Number of seconds in a day +pub const SECONDS_IN_DAY: i64 = 86_400; +/// Number of milliseconds in a second +pub const MILLISECONDS: i64 = 1_000; +/// Number of microseconds in a second +pub const MICROSECONDS: i64 = 1_000_000; +/// Number of nanoseconds in a second +pub const NANOSECONDS: i64 = 1_000_000_000; + +/// Number of milliseconds in a day +pub const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; +/// Number of microseconds in a day +pub const MICROSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MICROSECONDS; +/// Number of nanoseconds in a day +pub const NANOSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * NANOSECONDS; +/// Number of days between 0001-01-01 and 1970-01-01 +pub const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime(v: i32) -> Option { + Some(DateTime::from_timestamp(v as i64 * SECONDS_IN_DAY, 0)?.naive_utc()) +} + +/// converts a `i64` representing a `date64` to [`NaiveDateTime`] +#[inline] +pub fn date64_to_datetime(v: i64) -> Option { + let (sec, milli_sec) = split_second(v, MILLISECONDS); + + let datetime = DateTime::from_timestamp( + // extract seconds from milliseconds + sec, + // discard extracted seconds and convert milliseconds to nanoseconds + milli_sec * MICROSECONDS as u32, + )?; + Some(datetime.naive_utc()) +} + +/// converts a `i32` representing a `time32(s)` to [`NaiveTime`] +#[inline] +pub fn time32s_to_time(v: i32) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0) +} + +/// converts a `i32` representing a `time32(ms)` to [`NaiveTime`] +#[inline] +pub fn time32ms_to_time(v: i32) -> Option { + let v = v as i64; + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from milliseconds + (v / 
MILLISECONDS) as u32, + // discard extracted seconds and convert milliseconds to + // nanoseconds + (v % MILLISECONDS * MICROSECONDS) as u32, + ) +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveTime`] +#[inline] +pub fn time64us_to_time(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from microseconds + (v / MICROSECONDS) as u32, + // discard extracted seconds and convert microseconds to + // nanoseconds + (v % MICROSECONDS * MILLISECONDS) as u32, + ) +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`] +#[inline] +pub fn time64ns_to_time(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from nanoseconds + (v / NANOSECONDS) as u32, + // discard extracted seconds + (v % NANOSECONDS) as u32, + ) +} + +/// converts [`NaiveTime`] to a `i32` representing a `time32(s)` +#[inline] +pub fn time_to_time32s(v: NaiveTime) -> i32 { + v.num_seconds_from_midnight() as i32 +} + +/// converts [`NaiveTime`] to a `i32` representing a `time32(ms)` +#[inline] +pub fn time_to_time32ms(v: NaiveTime) -> i32 { + (v.num_seconds_from_midnight() as i64 * MILLISECONDS + + v.nanosecond() as i64 * MILLISECONDS / NANOSECONDS) as i32 +} + +/// converts [`NaiveTime`] to a `i64` representing a `time64(us)` +#[inline] +pub fn time_to_time64us(v: NaiveTime) -> i64 { + v.num_seconds_from_midnight() as i64 * MICROSECONDS + + v.nanosecond() as i64 * MICROSECONDS / NANOSECONDS +} + +/// converts [`NaiveTime`] to a `i64` representing a `time64(ns)` +#[inline] +pub fn time_to_time64ns(v: NaiveTime) -> i64 { + v.num_seconds_from_midnight() as i64 * NANOSECONDS + v.nanosecond() as i64 +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime(v: i64) -> Option { + Some(DateTime::from_timestamp(v, 0)?.naive_utc()) +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn 
timestamp_ms_to_datetime(v: i64) -> Option { + let (sec, milli_sec) = split_second(v, MILLISECONDS); + + let datetime = DateTime::from_timestamp( + // extract seconds from milliseconds + sec, + // discard extracted seconds and convert milliseconds to nanoseconds + milli_sec * MICROSECONDS as u32, + )?; + Some(datetime.naive_utc()) +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime(v: i64) -> Option { + let (sec, micro_sec) = split_second(v, MICROSECONDS); + + let datetime = DateTime::from_timestamp( + // extract seconds from microseconds + sec, + // discard extracted seconds and convert microseconds to nanoseconds + micro_sec * MILLISECONDS as u32, + )?; + Some(datetime.naive_utc()) +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime(v: i64) -> Option { + let (sec, nano_sec) = split_second(v, NANOSECONDS); + + let datetime = DateTime::from_timestamp( + // extract seconds from nanoseconds + sec, // discard extracted seconds + nano_sec, + )?; + Some(datetime.naive_utc()) +} + +#[inline] +pub(crate) fn split_second(v: i64, base: i64) -> (i64, u32) { + (v.div_euclid(base), v.rem_euclid(base) as u32) +} + +/// converts a `i64` representing a `duration(s)` to [`Duration`] +#[inline] +pub fn duration_s_to_duration(v: i64) -> Duration { + Duration::try_seconds(v).unwrap() +} + +/// converts a `i64` representing a `duration(ms)` to [`Duration`] +#[inline] +pub fn duration_ms_to_duration(v: i64) -> Duration { + Duration::try_milliseconds(v).unwrap() +} + +/// converts a `i64` representing a `duration(us)` to [`Duration`] +#[inline] +pub fn duration_us_to_duration(v: i64) -> Duration { + Duration::microseconds(v) +} + +/// converts a `i64` representing a `duration(ns)` to [`Duration`] +#[inline] +pub fn duration_ns_to_duration(v: i64) -> Duration { + Duration::nanoseconds(v) +} + +/// Converts an [`ArrowPrimitiveType`] to 
[`NaiveDateTime`] +pub fn as_datetime(v: i64) -> Option { + match T::DATA_TYPE { + DataType::Date32 => date32_to_datetime(v as i32), + DataType::Date64 => date64_to_datetime(v), + DataType::Time32(_) | DataType::Time64(_) => None, + DataType::Timestamp(unit, _) => match unit { + TimeUnit::Second => timestamp_s_to_datetime(v), + TimeUnit::Millisecond => timestamp_ms_to_datetime(v), + TimeUnit::Microsecond => timestamp_us_to_datetime(v), + TimeUnit::Nanosecond => timestamp_ns_to_datetime(v), + }, + // interval is not yet fully documented [ARROW-3097] + DataType::Interval(_) => None, + _ => None, + } +} + +/// Converts an [`ArrowPrimitiveType`] to [`DateTime`] +pub fn as_datetime_with_timezone(v: i64, tz: Tz) -> Option> { + let naive = as_datetime::(v)?; + Some(Utc.from_utc_datetime(&naive).with_timezone(&tz)) +} + +/// Converts an [`ArrowPrimitiveType`] to [`NaiveDate`] +pub fn as_date(v: i64) -> Option { + as_datetime::(v).map(|datetime| datetime.date()) +} + +/// Converts an [`ArrowPrimitiveType`] to [`NaiveTime`] +pub fn as_time(v: i64) -> Option { + match T::DATA_TYPE { + DataType::Time32(unit) => { + // safe to immediately cast to u32 as `self.value(i)` is positive i32 + let v = v as u32; + match unit { + TimeUnit::Second => time32s_to_time(v as i32), + TimeUnit::Millisecond => time32ms_to_time(v as i32), + _ => None, + } + } + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => time64us_to_time(v), + TimeUnit::Nanosecond => time64ns_to_time(v), + _ => None, + }, + DataType::Timestamp(_, _) => as_datetime::(v).map(|datetime| datetime.time()), + DataType::Date32 | DataType::Date64 => NaiveTime::from_hms_opt(0, 0, 0), + DataType::Interval(_) => None, + _ => None, + } +} + +/// Converts an [`ArrowPrimitiveType`] to [`Duration`] +pub fn as_duration(v: i64) -> Option { + match T::DATA_TYPE { + DataType::Duration(unit) => match unit { + TimeUnit::Second => Some(duration_s_to_duration(v)), + TimeUnit::Millisecond => Some(duration_ms_to_duration(v)), + 
TimeUnit::Microsecond => Some(duration_us_to_duration(v)), + TimeUnit::Nanosecond => Some(duration_ns_to_duration(v)), + }, + _ => None, + } +} + +#[cfg(test)] +mod tests { + use crate::temporal_conversions::{ + date64_to_datetime, split_second, timestamp_ms_to_datetime, timestamp_ns_to_datetime, + timestamp_us_to_datetime, NANOSECONDS, + }; + use chrono::DateTime; + + #[test] + fn negative_input_timestamp_ns_to_datetime() { + assert_eq!( + timestamp_ns_to_datetime(-1), + DateTime::from_timestamp(-1, 999_999_999).map(|x| x.naive_utc()) + ); + + assert_eq!( + timestamp_ns_to_datetime(-1_000_000_001), + DateTime::from_timestamp(-2, 999_999_999).map(|x| x.naive_utc()) + ); + } + + #[test] + fn negative_input_timestamp_us_to_datetime() { + assert_eq!( + timestamp_us_to_datetime(-1), + DateTime::from_timestamp(-1, 999_999_000).map(|x| x.naive_utc()) + ); + + assert_eq!( + timestamp_us_to_datetime(-1_000_001), + DateTime::from_timestamp(-2, 999_999_000).map(|x| x.naive_utc()) + ); + } + + #[test] + fn negative_input_timestamp_ms_to_datetime() { + assert_eq!( + timestamp_ms_to_datetime(-1), + DateTime::from_timestamp(-1, 999_000_000).map(|x| x.naive_utc()) + ); + + assert_eq!( + timestamp_ms_to_datetime(-1_001), + DateTime::from_timestamp(-2, 999_000_000).map(|x| x.naive_utc()) + ); + } + + #[test] + fn negative_input_date64_to_datetime() { + assert_eq!( + date64_to_datetime(-1), + DateTime::from_timestamp(-1, 999_000_000).map(|x| x.naive_utc()) + ); + + assert_eq!( + date64_to_datetime(-1_001), + DateTime::from_timestamp(-2, 999_000_000).map(|x| x.naive_utc()) + ); + } + + #[test] + fn test_split_seconds() { + let (sec, nano_sec) = split_second(100, NANOSECONDS); + assert_eq!(sec, 0); + assert_eq!(nano_sec, 100); + + let (sec, nano_sec) = split_second(123_000_000_456, NANOSECONDS); + assert_eq!(sec, 123); + assert_eq!(nano_sec, 456); + + let (sec, nano_sec) = split_second(-1, NANOSECONDS); + assert_eq!(sec, -1); + assert_eq!(nano_sec, 999_999_999); + + let (sec, 
nano_sec) = split_second(-123_000_000_001, NANOSECONDS); + assert_eq!(sec, -124); + assert_eq!(nano_sec, 999_999_999); + } +} diff --git a/arrow-array/src/timezone.rs b/arrow-array/src/timezone.rs new file mode 100644 index 000000000000..b4df77deb4f5 --- /dev/null +++ b/arrow-array/src/timezone.rs @@ -0,0 +1,339 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Timezone for timestamp arrays + +use arrow_schema::ArrowError; +use chrono::FixedOffset; +pub use private::{Tz, TzOffset}; + +/// Parses a fixed offset of the form "+09:00", "-09" or "+0930" +fn parse_fixed_offset(tz: &str) -> Option { + let bytes = tz.as_bytes(); + + let mut values = match bytes.len() { + // [+-]XX:XX + 6 if bytes[3] == b':' => [bytes[1], bytes[2], bytes[4], bytes[5]], + // [+-]XXXX + 5 => [bytes[1], bytes[2], bytes[3], bytes[4]], + // [+-]XX + 3 => [bytes[1], bytes[2], b'0', b'0'], + _ => return None, + }; + values.iter_mut().for_each(|x| *x = x.wrapping_sub(b'0')); + if values.iter().any(|x| *x > 9) { + return None; + } + let secs = + (values[0] * 10 + values[1]) as i32 * 60 * 60 + (values[2] * 10 + values[3]) as i32 * 60; + + match bytes[0] { + b'+' => FixedOffset::east_opt(secs), + b'-' => FixedOffset::west_opt(secs), + _ => None, + } +} + +#[cfg(feature = "chrono-tz")] +mod private { + use super::*; + use chrono::offset::TimeZone; + use chrono::{LocalResult, NaiveDate, NaiveDateTime, Offset}; + use std::str::FromStr; + + /// An [`Offset`] for [`Tz`] + #[derive(Debug, Copy, Clone)] + pub struct TzOffset { + tz: Tz, + offset: FixedOffset, + } + + impl std::fmt::Display for TzOffset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.offset.fmt(f) + } + } + + impl Offset for TzOffset { + fn fix(&self) -> FixedOffset { + self.offset + } + } + + /// An Arrow [`TimeZone`] + #[derive(Debug, Copy, Clone)] + pub struct Tz(TzInner); + + #[derive(Debug, Copy, Clone)] + enum TzInner { + Timezone(chrono_tz::Tz), + Offset(FixedOffset), + } + + impl FromStr for Tz { + type Err = ArrowError; + + fn from_str(tz: &str) -> Result { + match parse_fixed_offset(tz) { + Some(offset) => Ok(Self(TzInner::Offset(offset))), + None => Ok(Self(TzInner::Timezone(tz.parse().map_err(|e| { + ArrowError::ParseError(format!("Invalid timezone \"{tz}\": {e}")) + })?))), + } + } + } + + macro_rules! 
tz { + ($s:ident, $tz:ident, $b:block) => { + match $s.0 { + TzInner::Timezone($tz) => $b, + TzInner::Offset($tz) => $b, + } + }; + } + + impl TimeZone for Tz { + type Offset = TzOffset; + + fn from_offset(offset: &Self::Offset) -> Self { + offset.tz + } + + fn offset_from_local_date(&self, local: &NaiveDate) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_date(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_local_datetime(&self, local: &NaiveDateTime) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_datetime(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_utc_date(&self, utc: &NaiveDate) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_date(utc).fix(), + } + }) + } + + fn offset_from_utc_datetime(&self, utc: &NaiveDateTime) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_datetime(utc).fix(), + } + }) + } + } + + #[cfg(test)] + mod tests { + use super::*; + use chrono::{Timelike, Utc}; + + #[test] + fn test_with_timezone() { + let vals = [ + Utc.timestamp_millis_opt(37800000).unwrap(), + Utc.timestamp_millis_opt(86339000).unwrap(), + ]; + + assert_eq!(10, vals[0].hour()); + assert_eq!(23, vals[1].hour()); + + let tz: Tz = "America/Los_Angeles".parse().unwrap(); + + assert_eq!(2, vals[0].with_timezone(&tz).hour()); + assert_eq!(15, vals[1].with_timezone(&tz).hour()); + } + + #[test] + fn test_using_chrono_tz_and_utc_naive_date_time() { + let sydney_tz = "Australia/Sydney".to_string(); + let tz: Tz = sydney_tz.parse().unwrap(); + let sydney_offset_without_dst = FixedOffset::east_opt(10 * 60 * 60).unwrap(); + let sydney_offset_with_dst = FixedOffset::east_opt(11 * 60 * 60).unwrap(); + // Daylight savings ends + // When local daylight time was about to reach + // Sunday, 4 April 2021, 3:00:00 am clocks were turned backward 1 hour to + // Sunday, 4 April 2021, 2:00:00 am local 
standard time instead. + + // Daylight savings starts + // When local standard time was about to reach + // Sunday, 3 October 2021, 2:00:00 am clocks were turned forward 1 hour to + // Sunday, 3 October 2021, 3:00:00 am local daylight time instead. + + // Sydney 2021-04-04T02:30:00+11:00 is 2021-04-03T15:30:00Z + let utc_just_before_sydney_dst_ends = NaiveDate::from_ymd_opt(2021, 4, 3) + .unwrap() + .and_hms_nano_opt(15, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_before_sydney_dst_ends) + .fix(), + sydney_offset_with_dst + ); + // Sydney 2021-04-04T02:30:00+10:00 is 2021-04-03T16:30:00Z + let utc_just_after_sydney_dst_ends = NaiveDate::from_ymd_opt(2021, 4, 3) + .unwrap() + .and_hms_nano_opt(16, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_after_sydney_dst_ends) + .fix(), + sydney_offset_without_dst + ); + // Sydney 2021-10-03T01:30:00+10:00 is 2021-10-02T15:30:00Z + let utc_just_before_sydney_dst_starts = NaiveDate::from_ymd_opt(2021, 10, 2) + .unwrap() + .and_hms_nano_opt(15, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_before_sydney_dst_starts) + .fix(), + sydney_offset_without_dst + ); + // Sydney 2021-04-04T03:30:00+11:00 is 2021-10-02T16:30:00Z + let utc_just_after_sydney_dst_starts = NaiveDate::from_ymd_opt(2022, 10, 2) + .unwrap() + .and_hms_nano_opt(16, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_after_sydney_dst_starts) + .fix(), + sydney_offset_with_dst + ); + } + } +} + +#[cfg(not(feature = "chrono-tz"))] +mod private { + use super::*; + use chrono::offset::TimeZone; + use chrono::{LocalResult, NaiveDate, NaiveDateTime, Offset}; + use std::str::FromStr; + + /// An [`Offset`] for [`Tz`] + #[derive(Debug, Copy, Clone)] + pub struct TzOffset(FixedOffset); + + impl std::fmt::Display for TzOffset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl Offset for TzOffset { + fn 
fix(&self) -> FixedOffset { + self.0 + } + } + + /// An Arrow [`TimeZone`] + #[derive(Debug, Copy, Clone)] + pub struct Tz(FixedOffset); + + impl FromStr for Tz { + type Err = ArrowError; + + fn from_str(tz: &str) -> Result { + let offset = parse_fixed_offset(tz).ok_or_else(|| { + ArrowError::ParseError(format!( + "Invalid timezone \"{tz}\": only offset based timezones supported without chrono-tz feature" + )) + })?; + Ok(Self(offset)) + } + } + + impl TimeZone for Tz { + type Offset = TzOffset; + + fn from_offset(offset: &Self::Offset) -> Self { + Self(offset.0) + } + + fn offset_from_local_date(&self, local: &NaiveDate) -> LocalResult { + self.0.offset_from_local_date(local).map(TzOffset) + } + + fn offset_from_local_datetime(&self, local: &NaiveDateTime) -> LocalResult { + self.0.offset_from_local_datetime(local).map(TzOffset) + } + + fn offset_from_utc_date(&self, utc: &NaiveDate) -> Self::Offset { + TzOffset(self.0.offset_from_utc_date(utc).fix()) + } + + fn offset_from_utc_datetime(&self, utc: &NaiveDateTime) -> Self::Offset { + TzOffset(self.0.offset_from_utc_datetime(utc).fix()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{NaiveDate, Offset, TimeZone}; + + #[test] + fn test_with_offset() { + let t = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + + let tz: Tz = "-00:00".parse().unwrap(); + assert_eq!(tz.offset_from_utc_date(&t).fix().local_minus_utc(), 0); + let tz: Tz = "+00:00".parse().unwrap(); + assert_eq!(tz.offset_from_utc_date(&t).fix().local_minus_utc(), 0); + + let tz: Tz = "-10:00".parse().unwrap(); + assert_eq!( + tz.offset_from_utc_date(&t).fix().local_minus_utc(), + -10 * 60 * 60 + ); + let tz: Tz = "+09:00".parse().unwrap(); + assert_eq!( + tz.offset_from_utc_date(&t).fix().local_minus_utc(), + 9 * 60 * 60 + ); + + let tz = "+09".parse::().unwrap(); + assert_eq!( + tz.offset_from_utc_date(&t).fix().local_minus_utc(), + 9 * 60 * 60 + ); + + let tz = "+0900".parse::().unwrap(); + assert_eq!( + 
tz.offset_from_utc_date(&t).fix().local_minus_utc(), + 9 * 60 * 60 + ); + + let err = "+9:00".parse::().unwrap_err().to_string(); + assert!(err.contains("Invalid timezone"), "{}", err); + } +} diff --git a/arrow/src/util/trusted_len.rs b/arrow-array/src/trusted_len.rs similarity index 94% rename from arrow/src/util/trusted_len.rs rename to arrow-array/src/trusted_len.rs index 84a66238b634..781cad38f7e9 100644 --- a/arrow/src/util/trusted_len.rs +++ b/arrow-array/src/trusted_len.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::bit_util; -use crate::{ - buffer::{Buffer, MutableBuffer}, - datatypes::ArrowNativeType, -}; +use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; /// Creates two [`Buffer`]s from an iterator of `Option`. /// The first buffer corresponds to a bitmap buffer, the second one @@ -67,7 +63,7 @@ mod tests { #[test] fn trusted_len_unzip_good() { - let vec = vec![Some(1u32), None]; + let vec = [Some(1u32), None]; let (null, buffer) = unsafe { trusted_len_unzip(vec.iter()) }; assert_eq!(null.as_slice(), &[0b00000001]); assert_eq!(buffer.as_slice(), &[1u8, 0, 0, 0, 0, 0, 0, 0]); diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs new file mode 100644 index 000000000000..92262fc04a57 --- /dev/null +++ b/arrow-array/src/types.rs @@ -0,0 +1,1637 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Zero-sized types used to parameterize generic array implementations + +use crate::delta::{ + add_days_datetime, add_months_datetime, shift_months, sub_days_datetime, sub_months_datetime, +}; +use crate::temporal_conversions::as_datetime_with_timezone; +use crate::timezone::Tz; +use crate::{ArrowNativeTypeOp, OffsetSizeTrait}; +use arrow_buffer::{i256, Buffer, OffsetBuffer}; +use arrow_data::decimal::{ + is_validate_decimal256_precision, is_validate_decimal_precision, validate_decimal256_precision, + validate_decimal_precision, +}; +use arrow_data::{validate_binary_view, validate_string_view}; +use arrow_schema::{ + ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, +}; +use chrono::{Duration, NaiveDate, NaiveDateTime}; +use half::f16; +use std::fmt::Debug; +use std::marker::PhantomData; +use std::ops::{Add, Sub}; + +// re-export types so that they can be used without importing arrow_buffer explicitly +pub use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; + +// BooleanType is special: its bit-width is not the size of the primitive type, and its `index` +// operation assumes bit-packing. +/// A boolean datatype +#[derive(Debug)] +pub struct BooleanType {} + +impl BooleanType { + /// The corresponding Arrow data type + pub const DATA_TYPE: DataType = DataType::Boolean; +} + +/// Trait for [primitive values]. 
+/// +/// This trait bridges the dynamic-typed nature of Arrow +/// (via [`DataType`]) with the static-typed nature of rust types +/// ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`]. +/// +/// [primitive values]: https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout +/// [`ArrowNativeType`]: arrow_buffer::ArrowNativeType +pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { + /// Corresponding Rust native type for the primitive type. + type Native: ArrowNativeTypeOp; + + /// the corresponding Arrow data type of this primitive type. + const DATA_TYPE: DataType; + + /// Returns the byte width of this primitive type. + #[deprecated(note = "Use ArrowNativeType::get_byte_width")] + fn get_byte_width() -> usize { + std::mem::size_of::() + } + + /// Returns a default value of this primitive type. + /// + /// This is useful for aggregate array ops like `sum()`, `mean()`. + fn default_value() -> Self::Native { + Default::default() + } +} + +mod primitive { + pub trait PrimitiveTypeSealed {} +} + +macro_rules! make_type { + ($name:ident, $native_ty:ty, $data_ty:expr, $doc_string: literal) => { + #[derive(Debug)] + #[doc = $doc_string] + pub struct $name {} + + impl ArrowPrimitiveType for $name { + type Native = $native_ty; + const DATA_TYPE: DataType = $data_ty; + } + + impl primitive::PrimitiveTypeSealed for $name {} + }; +} + +make_type!(Int8Type, i8, DataType::Int8, "A signed 8-bit integer type."); +make_type!( + Int16Type, + i16, + DataType::Int16, + "Signed 16-bit integer type." +); +make_type!( + Int32Type, + i32, + DataType::Int32, + "Signed 32-bit integer type." +); +make_type!( + Int64Type, + i64, + DataType::Int64, + "Signed 64-bit integer type." +); +make_type!( + UInt8Type, + u8, + DataType::UInt8, + "Unsigned 8-bit integer type." +); +make_type!( + UInt16Type, + u16, + DataType::UInt16, + "Unsigned 16-bit integer type." 
+); +make_type!( + UInt32Type, + u32, + DataType::UInt32, + "Unsigned 32-bit integer type." +); +make_type!( + UInt64Type, + u64, + DataType::UInt64, + "Unsigned 64-bit integer type." +); +make_type!( + Float16Type, + f16, + DataType::Float16, + "16-bit floating point number type." +); +make_type!( + Float32Type, + f32, + DataType::Float32, + "32-bit floating point number type." +); +make_type!( + Float64Type, + f64, + DataType::Float64, + "64-bit floating point number type." +); +make_type!( + TimestampSecondType, + i64, + DataType::Timestamp(TimeUnit::Second, None), + "Timestamp second type with an optional timezone." +); +make_type!( + TimestampMillisecondType, + i64, + DataType::Timestamp(TimeUnit::Millisecond, None), + "Timestamp millisecond type with an optional timezone." +); +make_type!( + TimestampMicrosecondType, + i64, + DataType::Timestamp(TimeUnit::Microsecond, None), + "Timestamp microsecond type with an optional timezone." +); +make_type!( + TimestampNanosecondType, + i64, + DataType::Timestamp(TimeUnit::Nanosecond, None), + "Timestamp nanosecond type with an optional timezone." +); +make_type!( + Date32Type, + i32, + DataType::Date32, + "32-bit date type: the elapsed time since UNIX epoch in days (32 bits)." +); +make_type!( + Date64Type, + i64, + DataType::Date64, + "64-bit date type: the elapsed time since UNIX epoch in milliseconds (64 bits). \ + Values must be divisible by `86_400_000`. \ + See [`DataType::Date64`] for more details." +); +make_type!( + Time32SecondType, + i32, + DataType::Time32(TimeUnit::Second), + "32-bit time type: the elapsed time since midnight in seconds." +); +make_type!( + Time32MillisecondType, + i32, + DataType::Time32(TimeUnit::Millisecond), + "32-bit time type: the elapsed time since midnight in milliseconds." +); +make_type!( + Time64MicrosecondType, + i64, + DataType::Time64(TimeUnit::Microsecond), + "64-bit time type: the elapsed time since midnight in microseconds." 
+); +make_type!( + Time64NanosecondType, + i64, + DataType::Time64(TimeUnit::Nanosecond), + "64-bit time type: the elapsed time since midnight in nanoseconds." +); +make_type!( + IntervalYearMonthType, + i32, + DataType::Interval(IntervalUnit::YearMonth), + "32-bit “calendar” interval type: the number of whole months." +); +make_type!( + IntervalDayTimeType, + IntervalDayTime, + DataType::Interval(IntervalUnit::DayTime), + "“Calendar” interval type: days and milliseconds. See [`IntervalDayTime`] for more details." +); +make_type!( + IntervalMonthDayNanoType, + IntervalMonthDayNano, + DataType::Interval(IntervalUnit::MonthDayNano), + r"“Calendar” interval type: months, days, and nanoseconds. See [`IntervalMonthDayNano`] for more details." +); +make_type!( + DurationSecondType, + i64, + DataType::Duration(TimeUnit::Second), + "Elapsed time type: seconds." +); +make_type!( + DurationMillisecondType, + i64, + DataType::Duration(TimeUnit::Millisecond), + "Elapsed time type: milliseconds." +); +make_type!( + DurationMicrosecondType, + i64, + DataType::Duration(TimeUnit::Microsecond), + "Elapsed time type: microseconds." +); +make_type!( + DurationNanosecondType, + i64, + DataType::Duration(TimeUnit::Nanosecond), + "Elapsed time type: nanoseconds." +); + +/// A subtype of primitive type that represents legal dictionary keys. +/// See +pub trait ArrowDictionaryKeyType: ArrowPrimitiveType {} + +impl ArrowDictionaryKeyType for Int8Type {} + +impl ArrowDictionaryKeyType for Int16Type {} + +impl ArrowDictionaryKeyType for Int32Type {} + +impl ArrowDictionaryKeyType for Int64Type {} + +impl ArrowDictionaryKeyType for UInt8Type {} + +impl ArrowDictionaryKeyType for UInt16Type {} + +impl ArrowDictionaryKeyType for UInt32Type {} + +impl ArrowDictionaryKeyType for UInt64Type {} + +/// A subtype of primitive type that is used as run-ends index +/// in `RunArray`. 
+/// See +pub trait RunEndIndexType: ArrowPrimitiveType {} + +impl RunEndIndexType for Int16Type {} + +impl RunEndIndexType for Int32Type {} + +impl RunEndIndexType for Int64Type {} + +/// A subtype of primitive type that represents temporal values. +pub trait ArrowTemporalType: ArrowPrimitiveType {} + +impl ArrowTemporalType for TimestampSecondType {} +impl ArrowTemporalType for TimestampMillisecondType {} +impl ArrowTemporalType for TimestampMicrosecondType {} +impl ArrowTemporalType for TimestampNanosecondType {} +impl ArrowTemporalType for Date32Type {} +impl ArrowTemporalType for Date64Type {} +impl ArrowTemporalType for Time32SecondType {} +impl ArrowTemporalType for Time32MillisecondType {} +impl ArrowTemporalType for Time64MicrosecondType {} +impl ArrowTemporalType for Time64NanosecondType {} +// impl ArrowTemporalType for IntervalYearMonthType {} +// impl ArrowTemporalType for IntervalDayTimeType {} +// impl ArrowTemporalType for IntervalMonthDayNanoType {} +impl ArrowTemporalType for DurationSecondType {} +impl ArrowTemporalType for DurationMillisecondType {} +impl ArrowTemporalType for DurationMicrosecondType {} +impl ArrowTemporalType for DurationNanosecondType {} + +/// A timestamp type allows us to create array builders that take a timestamp. +pub trait ArrowTimestampType: ArrowTemporalType { + /// The [`TimeUnit`] of this timestamp. + const UNIT: TimeUnit; + + /// Returns the `TimeUnit` of this timestamp. 
+ #[deprecated(note = "Use Self::UNIT")] + fn get_time_unit() -> TimeUnit { + Self::UNIT + } + + /// Creates a ArrowTimestampType::Native from the provided [`NaiveDateTime`] + /// + /// See [`DataType::Timestamp`] for more information on timezone handling + fn make_value(naive: NaiveDateTime) -> Option; +} + +impl ArrowTimestampType for TimestampSecondType { + const UNIT: TimeUnit = TimeUnit::Second; + + fn make_value(naive: NaiveDateTime) -> Option { + Some(naive.and_utc().timestamp()) + } +} +impl ArrowTimestampType for TimestampMillisecondType { + const UNIT: TimeUnit = TimeUnit::Millisecond; + + fn make_value(naive: NaiveDateTime) -> Option { + let utc = naive.and_utc(); + let millis = utc.timestamp().checked_mul(1_000)?; + millis.checked_add(utc.timestamp_subsec_millis() as i64) + } +} +impl ArrowTimestampType for TimestampMicrosecondType { + const UNIT: TimeUnit = TimeUnit::Microsecond; + + fn make_value(naive: NaiveDateTime) -> Option { + let utc = naive.and_utc(); + let micros = utc.timestamp().checked_mul(1_000_000)?; + micros.checked_add(utc.timestamp_subsec_micros() as i64) + } +} +impl ArrowTimestampType for TimestampNanosecondType { + const UNIT: TimeUnit = TimeUnit::Nanosecond; + + fn make_value(naive: NaiveDateTime) -> Option { + let utc = naive.and_utc(); + let nanos = utc.timestamp().checked_mul(1_000_000_000)?; + nanos.checked_add(utc.timestamp_subsec_nanos() as i64) + } +} + +fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let months = IntervalYearMonthType::to_months(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_months_datetime(res, months)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_days_datetime(res, days)?; + let res = 
res.checked_add_signed(Duration::try_milliseconds(ms as i64)?)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_months_datetime(res, months)?; + let res = add_days_datetime(res, days)?; + let res = res.checked_add_signed(Duration::nanoseconds(nanos))?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let months = IntervalYearMonthType::to_months(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_months_datetime(res, months)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_days_datetime(res, days)?; + let res = res.checked_sub_signed(Duration::try_milliseconds(ms as i64)?)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_months_datetime(res, months)?; + let res = sub_days_datetime(res, days)?; + let res = res.checked_sub_signed(Duration::nanoseconds(nanos))?; + let res = res.naive_utc(); + T::make_value(res) +} + +impl TimestampSecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampSecondType. + /// + /// Returns `None` when it will result in overflow. 
+ /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given IntervalDayTimeType to an arrow TimestampSecondType. + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. 
+ /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl TimestampMicrosecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given IntervalDayTimeType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to 
interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl TimestampMillisecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given 
IntervalDayTimeType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub 
fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl TimestampNanosecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given IntervalDayTimeType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampNanosecondType 
+ /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl IntervalYearMonthType { + /// Creates a IntervalYearMonthType::Native + /// + /// # Arguments + /// + /// * `years` - The number of years (+/-) represented in this interval + /// * `months` - The number of months (+/-) represented in this interval + #[inline] + pub fn make_value( + years: i32, + months: i32, + ) -> ::Native { + years * 12 + months + } + + /// Turns a IntervalYearMonthType type into an i32 of months. + /// + /// This operation is technically a no-op, it is included for comprehensiveness. 
+ /// + /// # Arguments + /// + /// * `i` - The IntervalYearMonthType::Native to convert + #[inline] + pub fn to_months(i: ::Native) -> i32 { + i + } +} + +impl IntervalDayTimeType { + /// Creates a IntervalDayTimeType::Native + /// + /// # Arguments + /// + /// * `days` - The number of days (+/-) represented in this interval + /// * `millis` - The number of milliseconds (+/-) represented in this interval + #[inline] + pub fn make_value(days: i32, milliseconds: i32) -> IntervalDayTime { + IntervalDayTime { days, milliseconds } + } + + /// Turns a IntervalDayTimeType into a tuple of (days, milliseconds) + /// + /// # Arguments + /// + /// * `i` - The IntervalDayTimeType to convert + #[inline] + pub fn to_parts(i: IntervalDayTime) -> (i32, i32) { + (i.days, i.milliseconds) + } +} + +impl IntervalMonthDayNanoType { + /// Creates a IntervalMonthDayNanoType::Native + /// + /// # Arguments + /// + /// * `months` - The number of months (+/-) represented in this interval + /// * `days` - The number of days (+/-) represented in this interval + /// * `nanos` - The number of nanoseconds (+/-) represented in this interval + #[inline] + pub fn make_value(months: i32, days: i32, nanoseconds: i64) -> IntervalMonthDayNano { + IntervalMonthDayNano { + months, + days, + nanoseconds, + } + } + + /// Turns a IntervalMonthDayNanoType into a tuple of (months, days, nanos) + /// + /// # Arguments + /// + /// * `i` - The IntervalMonthDayNanoType to convert + #[inline] + pub fn to_parts(i: IntervalMonthDayNano) -> (i32, i32, i64) { + (i.months, i.days, i.nanoseconds) + } +} + +impl Date32Type { + /// Converts an arrow Date32Type into a chrono::NaiveDate + /// + /// # Arguments + /// + /// * `i` - The Date32Type to convert + pub fn to_naive_date(i: ::Native) -> NaiveDate { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + epoch.add(Duration::try_days(i as i64).unwrap()) + } + + /// Converts a chrono::NaiveDate into an arrow Date32Type + /// + /// # Arguments + /// + /// * `d` - 
The NaiveDate to convert + pub fn from_naive_date(d: NaiveDate) -> ::Native { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + d.sub(epoch).num_days() as ::Native + } + + /// Adds the given IntervalYearMonthType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = Date32Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(delta); + let posterior = shift_months(prior, months); + Date32Type::from_naive_date(posterior) + } + + /// Adds the given IntervalDayTimeType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = res.add(Duration::try_days(days as i64).unwrap()); + let res = res.add(Duration::try_milliseconds(ms as i64).unwrap()); + Date32Type::from_naive_date(res) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = shift_months(res, months); + let res = res.add(Duration::try_days(days as i64).unwrap()); + let res = res.add(Duration::nanoseconds(nanos)); + Date32Type::from_naive_date(res) + } + + /// Subtract the given IntervalYearMonthType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + 
pub fn subtract_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = Date32Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(-delta); + let posterior = shift_months(prior, months); + Date32Type::from_naive_date(posterior) + } + + /// Subtract the given IntervalDayTimeType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = res.sub(Duration::try_days(days as i64).unwrap()); + let res = res.sub(Duration::try_milliseconds(ms as i64).unwrap()); + Date32Type::from_naive_date(res) + } + + /// Subtract the given IntervalMonthDayNanoType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = shift_months(res, -months); + let res = res.sub(Duration::try_days(days as i64).unwrap()); + let res = res.sub(Duration::nanoseconds(nanos)); + Date32Type::from_naive_date(res) + } +} + +impl Date64Type { + /// Converts an arrow Date64Type into a chrono::NaiveDate + /// + /// # Arguments + /// + /// * `i` - The Date64Type to convert + pub fn to_naive_date(i: ::Native) -> NaiveDate { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + epoch.add(Duration::try_milliseconds(i).unwrap()) + } + + /// Converts a chrono::NaiveDate into an arrow Date64Type + /// + /// # Arguments + /// + /// * `d` - The NaiveDate to convert + pub fn from_naive_date(d: NaiveDate) -> ::Native { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 
1).unwrap(); + d.sub(epoch).num_milliseconds() as ::Native + } + + /// Adds the given IntervalYearMonthType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = Date64Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(delta); + let posterior = shift_months(prior, months); + Date64Type::from_naive_date(posterior) + } + + /// Adds the given IntervalDayTimeType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = res.add(Duration::try_days(days as i64).unwrap()); + let res = res.add(Duration::try_milliseconds(ms as i64).unwrap()); + Date64Type::from_naive_date(res) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = shift_months(res, months); + let res = res.add(Duration::try_days(days as i64).unwrap()); + let res = res.add(Duration::nanoseconds(nanos)); + Date64Type::from_naive_date(res) + } + + /// Subtract the given IntervalYearMonthType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = 
Date64Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(-delta); + let posterior = shift_months(prior, months); + Date64Type::from_naive_date(posterior) + } + + /// Subtract the given IntervalDayTimeType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = res.sub(Duration::try_days(days as i64).unwrap()); + let res = res.sub(Duration::try_milliseconds(ms as i64).unwrap()); + Date64Type::from_naive_date(res) + } + + /// Subtract the given IntervalMonthDayNanoType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = shift_months(res, -months); + let res = res.sub(Duration::try_days(days as i64).unwrap()); + let res = res.sub(Duration::nanoseconds(nanos)); + Date64Type::from_naive_date(res) + } +} + +/// Crate private types for Decimal Arrays +/// +/// Not intended to be used outside this crate +mod decimal { + use super::*; + + pub trait DecimalTypeSealed {} + impl DecimalTypeSealed for Decimal128Type {} + impl DecimalTypeSealed for Decimal256Type {} +} + +/// A trait over the decimal types, used by [`PrimitiveArray`] to provide a generic +/// implementation across the various decimal types +/// +/// Implemented by [`Decimal128Type`] and [`Decimal256Type`] for [`Decimal128Array`] +/// and [`Decimal256Array`] respectively +/// +/// [`PrimitiveArray`]: crate::array::PrimitiveArray +/// [`Decimal128Array`]: 
crate::array::Decimal128Array +/// [`Decimal256Array`]: crate::array::Decimal256Array +pub trait DecimalType: + 'static + Send + Sync + ArrowPrimitiveType + decimal::DecimalTypeSealed +{ + /// Width of the type + const BYTE_LENGTH: usize; + /// Maximum number of significant digits + const MAX_PRECISION: u8; + /// Maximum no of digits after the decimal point (note the scale can be negative) + const MAX_SCALE: i8; + /// fn to create its [`DataType`] + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType; + /// Default values for [`DataType`] + const DEFAULT_TYPE: DataType; + + /// "Decimal128" or "Decimal256", for use in error messages + const PREFIX: &'static str; + + /// Formats the decimal value with the provided precision and scale + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String; + + /// Validates that `value` contains no more than `precision` decimal digits + fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>; + + /// Determines whether `value` contains no more than `precision` decimal digits + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool; +} + +/// Validate that `precision` and `scale` are valid for `T` +/// +/// Returns an Error if: +/// - `precision` is zero +/// - `precision` is larger than `T:MAX_PRECISION` +/// - `scale` is larger than `T::MAX_SCALE` +/// - `scale` is > `precision` +pub fn validate_decimal_precision_and_scale( + precision: u8, + scale: i8, +) -> Result<(), ArrowError> { + if precision == 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "precision cannot be 0, has to be between [1, {}]", + T::MAX_PRECISION + ))); + } + if precision > T::MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, + T::MAX_PRECISION + ))); + } + if scale > T::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, + T::MAX_SCALE + ))); 
+ } + if scale > 0 && scale as u8 > precision { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {scale} is greater than precision {precision}" + ))); + } + + Ok(()) +} + +/// The decimal type for a Decimal128Array +#[derive(Debug)] +pub struct Decimal128Type {} + +impl DecimalType for Decimal128Type { + const BYTE_LENGTH: usize = 16; + const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION; + const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128; + const DEFAULT_TYPE: DataType = + DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal128"; + + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) + } + + fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> { + validate_decimal_precision(num, precision) + } + + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { + is_validate_decimal_precision(value, precision) + } +} + +impl ArrowPrimitiveType for Decimal128Type { + type Native = i128; + + const DATA_TYPE: DataType = ::DEFAULT_TYPE; +} + +impl primitive::PrimitiveTypeSealed for Decimal128Type {} + +/// The decimal type for a Decimal256Array +#[derive(Debug)] +pub struct Decimal256Type {} + +impl DecimalType for Decimal256Type { + const BYTE_LENGTH: usize = 32; + const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION; + const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256; + const DEFAULT_TYPE: DataType = + DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal256"; + + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) + } + + fn validate_decimal_precision(num: i256, precision: u8) -> 
Result<(), ArrowError> { + validate_decimal256_precision(num, precision) + } + + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { + is_validate_decimal256_precision(value, precision) + } +} + +impl ArrowPrimitiveType for Decimal256Type { + type Native = i256; + + const DATA_TYPE: DataType = ::DEFAULT_TYPE; +} + +impl primitive::PrimitiveTypeSealed for Decimal256Type {} + +fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String { + let (sign, rest) = match value_str.strip_prefix('-') { + Some(stripped) => ("-", stripped), + None => ("", value_str), + }; + let bound = precision.min(rest.len()) + sign.len(); + let value_str = &value_str[0..bound]; + + if scale == 0 { + value_str.to_string() + } else if scale < 0 { + let padding = value_str.len() + scale.unsigned_abs() as usize; + format!("{value_str:0 scale as usize { + // Decimal separator is in the middle of the string + let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize); + format!("{whole}.{decimal}") + } else { + // String has to be padded + format!("{}0.{:0>width$}", sign, rest, width = scale as usize) + } +} + +/// Crate private types for Byte Arrays +/// +/// Not intended to be used outside this crate +pub(crate) mod bytes { + use super::*; + + pub trait ByteArrayTypeSealed {} + impl ByteArrayTypeSealed for GenericStringType {} + impl ByteArrayTypeSealed for GenericBinaryType {} + + pub trait ByteArrayNativeType: std::fmt::Debug + Send + Sync { + fn from_bytes_checked(b: &[u8]) -> Option<&Self>; + + /// # Safety + /// + /// `b` must be a valid byte sequence for `Self` + unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self; + } + + impl ByteArrayNativeType for [u8] { + #[inline] + fn from_bytes_checked(b: &[u8]) -> Option<&Self> { + Some(b) + } + + #[inline] + unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self { + b + } + } + + impl ByteArrayNativeType for str { + #[inline] + fn from_bytes_checked(b: &[u8]) -> Option<&Self> { + 
std::str::from_utf8(b).ok() + } + + #[inline] + unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self { + std::str::from_utf8_unchecked(b) + } + } +} + +/// A trait over the variable-size byte array types +/// +/// See [Variable Size Binary Layout](https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout) +pub trait ByteArrayType: 'static + Send + Sync + bytes::ByteArrayTypeSealed { + /// Type of offset i.e i32/i64 + type Offset: OffsetSizeTrait; + /// Type for representing its equivalent rust type i.e + /// Utf8Array will have native type has &str + /// BinaryArray will have type as [u8] + type Native: bytes::ByteArrayNativeType + AsRef + AsRef<[u8]> + ?Sized; + + /// "Binary" or "String", for use in error messages + const PREFIX: &'static str; + + /// Datatype of array elements + const DATA_TYPE: DataType; + + /// Verifies that every consecutive pair of `offsets` denotes a valid slice of `values` + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError>; +} + +/// [`ByteArrayType`] for string arrays +pub struct GenericStringType { + phantom: PhantomData, +} + +impl ByteArrayType for GenericStringType { + type Offset = O; + type Native = str; + const PREFIX: &'static str = "String"; + + const DATA_TYPE: DataType = if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError> { + // Verify that the slice as a whole is valid UTF-8 + let validated = std::str::from_utf8(values).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Encountered non UTF-8 data: {e}")) + })?; + + // Verify each offset is at a valid character boundary in this UTF-8 array + for offset in offsets.iter() { + let o = offset.as_usize(); + if !validated.is_char_boundary(o) { + if o < validated.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Split UTF-8 codepoint at offset {o}" + ))); + } + return Err(ArrowError::InvalidArgumentError(format!( 
+ "Offset of {o} exceeds length of values {}", + validated.len() + ))); + } + } + Ok(()) + } +} + +/// An arrow utf8 array with i32 offsets +pub type Utf8Type = GenericStringType; +/// An arrow utf8 array with i64 offsets +pub type LargeUtf8Type = GenericStringType; + +/// [`ByteArrayType`] for binary arrays +pub struct GenericBinaryType { + phantom: PhantomData, +} + +impl ByteArrayType for GenericBinaryType { + type Offset = O; + type Native = [u8]; + const PREFIX: &'static str = "Binary"; + + const DATA_TYPE: DataType = if O::IS_LARGE { + DataType::LargeBinary + } else { + DataType::Binary + }; + + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError> { + // offsets are guaranteed to be monotonically increasing and non-empty + let max_offset = offsets.last().unwrap().as_usize(); + if values.len() < max_offset { + return Err(ArrowError::InvalidArgumentError(format!( + "Maximum offset of {max_offset} is larger than values of length {}", + values.len() + ))); + } + Ok(()) + } +} + +/// An arrow binary array with i32 offsets +pub type BinaryType = GenericBinaryType; +/// An arrow binary array with i64 offsets +pub type LargeBinaryType = GenericBinaryType; + +mod byte_view { + use crate::types::{BinaryViewType, StringViewType}; + + pub trait Sealed: Send + Sync {} + impl Sealed for StringViewType {} + impl Sealed for BinaryViewType {} +} + +/// A trait over the variable length bytes view array types +pub trait ByteViewType: byte_view::Sealed + 'static + PartialEq + Send + Sync { + /// If element in array is utf8 encoded string. 
+ const IS_UTF8: bool; + + /// Datatype of array elements + const DATA_TYPE: DataType = if Self::IS_UTF8 { + DataType::Utf8View + } else { + DataType::BinaryView + }; + + /// "Binary" or "String", for use in displayed or error messages + const PREFIX: &'static str; + + /// Type for representing its equivalent rust type i.e + /// Utf8Array will have native type has &str + /// BinaryArray will have type as [u8] + type Native: bytes::ByteArrayNativeType + AsRef + AsRef<[u8]> + ?Sized; + + /// Type for owned corresponding to `Native` + type Owned: Debug + Clone + Sync + Send + AsRef; + + /// Verifies that the provided buffers are valid for this array type + fn validate(views: &[u128], buffers: &[Buffer]) -> Result<(), ArrowError>; +} + +/// [`ByteViewType`] for string arrays +#[derive(PartialEq)] +pub struct StringViewType {} + +impl ByteViewType for StringViewType { + const IS_UTF8: bool = true; + const PREFIX: &'static str = "String"; + + type Native = str; + type Owned = String; + + fn validate(views: &[u128], buffers: &[Buffer]) -> Result<(), ArrowError> { + validate_string_view(views, buffers) + } +} + +/// [`BinaryViewType`] for string arrays +#[derive(PartialEq)] +pub struct BinaryViewType {} + +impl ByteViewType for BinaryViewType { + const IS_UTF8: bool = false; + const PREFIX: &'static str = "Binary"; + type Native = [u8]; + type Owned = Vec; + + fn validate(views: &[u128], buffers: &[Buffer]) -> Result<(), ArrowError> { + validate_binary_view(views, buffers) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_data::{layout, BufferSpec}; + + #[test] + fn month_day_nano_should_roundtrip() { + let value = IntervalMonthDayNanoType::make_value(1, 2, 3); + assert_eq!(IntervalMonthDayNanoType::to_parts(value), (1, 2, 3)); + } + + #[test] + fn month_day_nano_should_roundtrip_neg() { + let value = IntervalMonthDayNanoType::make_value(-1, -2, -3); + assert_eq!(IntervalMonthDayNanoType::to_parts(value), (-1, -2, -3)); + } + + #[test] + fn 
day_time_should_roundtrip() { + let value = IntervalDayTimeType::make_value(1, 2); + assert_eq!(IntervalDayTimeType::to_parts(value), (1, 2)); + } + + #[test] + fn day_time_should_roundtrip_neg() { + let value = IntervalDayTimeType::make_value(-1, -2); + assert_eq!(IntervalDayTimeType::to_parts(value), (-1, -2)); + } + + #[test] + fn year_month_should_roundtrip() { + let value = IntervalYearMonthType::make_value(1, 2); + assert_eq!(IntervalYearMonthType::to_months(value), 14); + } + + #[test] + fn year_month_should_roundtrip_neg() { + let value = IntervalYearMonthType::make_value(-1, -2); + assert_eq!(IntervalYearMonthType::to_months(value), -14); + } + + fn test_layout() { + let layout = layout(&T::DATA_TYPE); + + assert_eq!(layout.buffers.len(), 1); + + let spec = &layout.buffers[0]; + assert_eq!( + spec, + &BufferSpec::FixedWidth { + byte_width: std::mem::size_of::(), + alignment: std::mem::align_of::(), + } + ); + } + + #[test] + fn test_layouts() { + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + } +} diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml new file mode 100644 index 000000000000..d2436f0c15de --- /dev/null +++ b/arrow-avro/Cargo.toml @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-avro" +version = { workspace = true } +description = "Support for parsing Avro format into the Arrow format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_avro" +path = "src/lib.rs" +bench = false + +[features] +default = ["deflate", "snappy", "zstd"] +deflate = ["flate2"] +snappy = ["snap", "crc"] + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +serde_json = { version = "1.0", default-features = false, features = ["std"] } +serde = { version = "1.0.188", features = ["derive"] } +flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } +snap = { version = "1.0", default-features = false, optional = true } +zstd = { version = "0.13", default-features = false, optional = true } +crc = { version = "3.0", optional = true } + + +[dev-dependencies] + diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs new file mode 100644 index 000000000000..1e2acd99d828 --- /dev/null +++ b/arrow-avro/src/codec.rs @@ -0,0 +1,315 @@ +// Licensed to the 
Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName}; +use arrow_schema::{ + ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, +}; +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +/// Avro types are not nullable, with nullability instead encoded as a union +/// where one of the variants is the null type. 
+/// +/// To accommodate this we special case two-variant unions where one of the +/// variants is the null type, and use this to derive arrow's notion of nullability +#[derive(Debug, Copy, Clone)] +enum Nulls { + /// The nulls are encoded as the first union variant + NullFirst, + /// The nulls are encoded as the second union variant + NullSecond, +} + +/// An Avro datatype mapped to the arrow data model +#[derive(Debug, Clone)] +pub struct AvroDataType { + nulls: Option, + metadata: HashMap, + codec: Codec, +} + +impl AvroDataType { + /// Returns an arrow [`Field`] with the given name + pub fn field_with_name(&self, name: &str) -> Field { + let d = self.codec.data_type(); + Field::new(name, d, self.nulls.is_some()).with_metadata(self.metadata.clone()) + } +} + +/// A named [`AvroDataType`] +#[derive(Debug, Clone)] +pub struct AvroField { + name: String, + data_type: AvroDataType, +} + +impl AvroField { + /// Returns the arrow [`Field`] + pub fn field(&self) -> Field { + self.data_type.field_with_name(&self.name) + } + + /// Returns the [`Codec`] + pub fn codec(&self) -> &Codec { + &self.data_type.codec + } +} + +impl<'a> TryFrom<&Schema<'a>> for AvroField { + type Error = ArrowError; + + fn try_from(schema: &Schema<'a>) -> Result { + match schema { + Schema::Complex(ComplexType::Record(r)) => { + let mut resolver = Resolver::default(); + let data_type = make_data_type(schema, None, &mut resolver)?; + Ok(AvroField { + data_type, + name: r.name.to_string(), + }) + } + _ => Err(ArrowError::ParseError(format!( + "Expected record got {schema:?}" + ))), + } + } +} + +/// An Avro encoding +/// +/// +#[derive(Debug, Clone)] +pub enum Codec { + Null, + Boolean, + Int32, + Int64, + Float32, + Float64, + Binary, + Utf8, + Date32, + TimeMillis, + TimeMicros, + /// TimestampMillis(is_utc) + TimestampMillis(bool), + /// TimestampMicros(is_utc) + TimestampMicros(bool), + Fixed(i32), + List(Arc), + Struct(Arc<[AvroField]>), + Duration, +} + +impl Codec { + fn data_type(&self) -> 
DataType { + match self { + Self::Null => DataType::Null, + Self::Boolean => DataType::Boolean, + Self::Int32 => DataType::Int32, + Self::Int64 => DataType::Int64, + Self::Float32 => DataType::Float32, + Self::Float64 => DataType::Float64, + Self::Binary => DataType::Binary, + Self::Utf8 => DataType::Utf8, + Self::Date32 => DataType::Date32, + Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond), + Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond), + Self::TimestampMillis(is_utc) => { + DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) + } + Self::TimestampMicros(is_utc) => { + DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) + } + Self::Duration => DataType::Interval(IntervalUnit::MonthDayNano), + Self::Fixed(size) => DataType::FixedSizeBinary(*size), + Self::List(f) => DataType::List(Arc::new(f.field_with_name("item"))), + Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + } + } +} + +impl From for Codec { + fn from(value: PrimitiveType) -> Self { + match value { + PrimitiveType::Null => Self::Null, + PrimitiveType::Boolean => Self::Boolean, + PrimitiveType::Int => Self::Int32, + PrimitiveType::Long => Self::Int64, + PrimitiveType::Float => Self::Float32, + PrimitiveType::Double => Self::Float64, + PrimitiveType::Bytes => Self::Binary, + PrimitiveType::String => Self::Utf8, + } + } +} + +/// Resolves Avro type names to [`AvroDataType`] +/// +/// See +#[derive(Debug, Default)] +struct Resolver<'a> { + map: HashMap<(&'a str, &'a str), AvroDataType>, +} + +impl<'a> Resolver<'a> { + fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) { + self.map.insert((name, namespace.unwrap_or("")), schema); + } + + fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result { + let (namespace, name) = name + .rsplit_once('.') + .unwrap_or_else(|| (namespace.unwrap_or(""), name)); + + self.map + .get(&(namespace, name)) + .ok_or_else(|| 
ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}"))) + .cloned() + } +} + +/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` +/// +/// `name`: is name used to refer to `schema` in its parent +/// `namespace`: an optional qualifier used as part of a type hierarchy +/// +/// See [`Resolver`] for more information +fn make_data_type<'a>( + schema: &Schema<'a>, + namespace: Option<&'a str>, + resolver: &mut Resolver<'a>, +) -> Result { + match schema { + Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { + nulls: None, + metadata: Default::default(), + codec: (*p).into(), + }), + Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), + Schema::Union(f) => { + // Special case the common case of nullable primitives + let null = f + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); + match (f.len() == 2, null) { + (true, Some(0)) => { + let mut field = make_data_type(&f[1], namespace, resolver)?; + field.nulls = Some(Nulls::NullFirst); + Ok(field) + } + (true, Some(1)) => { + let mut field = make_data_type(&f[0], namespace, resolver)?; + field.nulls = Some(Nulls::NullSecond); + Ok(field) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Union of {f:?} not currently supported" + ))), + } + } + Schema::Complex(c) => match c { + ComplexType::Record(r) => { + let namespace = r.namespace.or(namespace); + let fields = r + .fields + .iter() + .map(|field| { + Ok(AvroField { + name: field.name.to_string(), + data_type: make_data_type(&field.r#type, namespace, resolver)?, + }) + }) + .collect::>()?; + + let field = AvroDataType { + nulls: None, + codec: Codec::Struct(fields), + metadata: r.attributes.field_metadata(), + }; + resolver.register(r.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Array(a) => { + let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; + Ok(AvroDataType { + nulls: None, + 
metadata: a.attributes.field_metadata(), + codec: Codec::List(Arc::new(field)), + }) + } + ComplexType::Fixed(f) => { + let size = f.size.try_into().map_err(|e| { + ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) + })?; + + let field = AvroDataType { + nulls: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Fixed(size), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( + "Enum of {e:?} not currently supported" + ))), + ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!( + "Map of {m:?} not currently supported" + ))), + }, + Schema::Type(t) => { + let mut field = + make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; + + // https://avro.apache.org/docs/1.11.1/specification/#logical-types + match (t.attributes.logical_type, &mut field.codec) { + (Some("decimal"), c @ Codec::Fixed(_)) => { + return Err(ArrowError::NotYetImplemented( + "Decimals are not currently supported".to_string(), + )) + } + (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, + (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, + (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, + (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), + (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), + (Some("local-timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(false) + } + (Some("local-timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(false) + } + (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Duration, + (Some(logical), _) => { + // Insert unrecognized logical type into metadata map + field.metadata.insert("logicalType".into(), logical.into()); + } + (None, _) => {} + } + + if !t.attributes.additional.is_empty() { + for (k, v) in &t.attributes.additional { + 
field.metadata.insert(k.to_string(), v.to_string()); + } + } + Ok(field) + } + } +} diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs new file mode 100644 index 000000000000..c5c7a6dabc33 --- /dev/null +++ b/arrow-avro/src/compression.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_schema::ArrowError; +use flate2::read; +use std::io; +use std::io::Read; + +/// The metadata key used for storing the JSON encoded [`CompressionCodec`] +pub const CODEC_METADATA_KEY: &str = "avro.codec"; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum CompressionCodec { + Deflate, + Snappy, + ZStandard, +} + +impl CompressionCodec { + pub(crate) fn decompress(&self, block: &[u8]) -> Result, ArrowError> { + match self { + #[cfg(feature = "deflate")] + CompressionCodec::Deflate => { + let mut decoder = read::DeflateDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "deflate"))] + CompressionCodec::Deflate => Err(ArrowError::ParseError( + "Deflate codec requires deflate feature".to_string(), + )), + #[cfg(feature = "snappy")] + CompressionCodec::Snappy => { + // Each compressed block is followed by the 4-byte, big-endian CRC32 + // checksum of the uncompressed data in the block. + let crc = &block[block.len() - 4..]; + let block = &block[..block.len() - 4]; + + let mut decoder = snap::raw::Decoder::new(); + let decoded = decoder + .decompress_vec(block) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + + let checksum = crc::Crc::::new(&crc::CRC_32_ISO_HDLC).checksum(&decoded); + if checksum != u32::from_be_bytes(crc.try_into().unwrap()) { + return Err(ArrowError::ParseError("Snappy CRC mismatch".to_string())); + } + Ok(decoded) + } + #[cfg(not(feature = "snappy"))] + CompressionCodec::Snappy => Err(ArrowError::ParseError( + "Snappy codec requires snappy feature".to_string(), + )), + + #[cfg(feature = "zstd")] + CompressionCodec::ZStandard => { + let mut decoder = zstd::Decoder::new(block)?; + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "zstd"))] + CompressionCodec::ZStandard => Err(ArrowError::ParseError( + "ZStandard codec requires zstd feature".to_string(), + )), + } + } +} diff --git a/arrow-avro/src/lib.rs 
b/arrow-avro/src/lib.rs new file mode 100644 index 000000000000..d01d681b7af0 --- /dev/null +++ b/arrow-avro/src/lib.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Convert data to / from the [Apache Arrow] memory format and [Apache Avro] +//! +//! [Apache Arrow]: https://arrow.apache.org +//! [Apache Avro]: https://avro.apache.org/ + +#![warn(missing_docs)] +#![allow(unused)] // Temporary + +pub mod reader; +mod schema; + +mod compression; + +mod codec; + +#[cfg(test)] +mod test_util { + pub fn arrow_test_data(path: &str) -> String { + match std::env::var("ARROW_TEST_DATA") { + Ok(dir) => format!("{dir}/{path}"), + Err(_) => format!("../testing/data/{path}"), + } + } +} diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs new file mode 100644 index 000000000000..479f0ef90909 --- /dev/null +++ b/arrow-avro/src/reader/block.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder for [`Block`] + +use crate::reader::vlq::VLQDecoder; +use arrow_schema::ArrowError; + +/// A file data block +/// +/// +#[derive(Debug, Default)] +pub struct Block { + /// The number of objects in this block + pub count: usize, + /// The serialized objects within this block + pub data: Vec, + /// The sync marker + pub sync: [u8; 16], +} + +/// A decoder for [`Block`] +#[derive(Debug)] +pub struct BlockDecoder { + state: BlockDecoderState, + in_progress: Block, + vlq_decoder: VLQDecoder, + bytes_remaining: usize, +} + +#[derive(Debug)] +enum BlockDecoderState { + Count, + Size, + Data, + Sync, + Finished, +} + +impl Default for BlockDecoder { + fn default() -> Self { + Self { + state: BlockDecoderState::Count, + in_progress: Default::default(), + vlq_decoder: Default::default(), + bytes_remaining: 0, + } + } +} + +impl BlockDecoder { + /// Parse [`Block`] from `buf`, returning the number of bytes read + /// + /// This method can be called multiple times with consecutive chunks of data, allowing + /// integration with chunked IO systems like [`BufRead::fill_buf`] + /// + /// All errors should be considered fatal, and decoding aborted + /// + /// Once an entire [`Block`] has been decoded this method will not read any further + /// input bytes, until [`Self::flush`] is called. 
Afterwards [`Self::decode`] + /// can then be used again to read the next block, if any + /// + /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf + pub fn decode(&mut self, mut buf: &[u8]) -> Result { + let max_read = buf.len(); + while !buf.is_empty() { + match self.state { + BlockDecoderState::Count => { + if let Some(c) = self.vlq_decoder.long(&mut buf) { + self.in_progress.count = c.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Block count cannot be negative, got {c}" + )) + })?; + + self.state = BlockDecoderState::Size; + } + } + BlockDecoderState::Size => { + if let Some(c) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = c.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Block size cannot be negative, got {c}" + )) + })?; + + self.in_progress.data.reserve(self.bytes_remaining); + self.state = BlockDecoderState::Data; + } + } + BlockDecoderState::Data => { + let to_read = self.bytes_remaining.min(buf.len()); + self.in_progress.data.extend_from_slice(&buf[..to_read]); + buf = &buf[to_read..]; + self.bytes_remaining -= to_read; + if self.bytes_remaining == 0 { + self.bytes_remaining = 16; + self.state = BlockDecoderState::Sync; + } + } + BlockDecoderState::Sync => { + let to_decode = buf.len().min(self.bytes_remaining); + let write = &mut self.in_progress.sync[16 - to_decode..]; + write[..to_decode].copy_from_slice(&buf[..to_decode]); + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = BlockDecoderState::Finished; + } + } + BlockDecoderState::Finished => return Ok(max_read - buf.len()), + } + } + Ok(max_read) + } + + /// Flush this decoder returning the parsed [`Block`] if any + pub fn flush(&mut self) -> Option { + match self.state { + BlockDecoderState::Finished => { + self.state = BlockDecoderState::Count; + Some(std::mem::take(&mut self.in_progress)) + } + _ => None, + } + } +} diff --git a/arrow-avro/src/reader/header.rs 
b/arrow-avro/src/reader/header.rs new file mode 100644 index 000000000000..19d48d1f89a1 --- /dev/null +++ b/arrow-avro/src/reader/header.rs @@ -0,0 +1,345 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder for [`Header`] + +use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; +use crate::reader::vlq::VLQDecoder; +use crate::schema::Schema; +use arrow_schema::ArrowError; + +#[derive(Debug)] +enum HeaderDecoderState { + /// Decoding the [`MAGIC`] prefix + Magic, + /// Decoding a block count + BlockCount, + /// Decoding a block byte length + BlockLen, + /// Decoding a key length + KeyLen, + /// Decoding a key string + Key, + /// Decoding a value length + ValueLen, + /// Decoding a value payload + Value, + /// Decoding sync marker + Sync, + /// Finished decoding + Finished, +} + +/// A decoded header for an [Object Container File](https://avro.apache.org/docs/1.11.1/specification/#object-container-files) +#[derive(Debug, Clone)] +pub struct Header { + meta_offsets: Vec, + meta_buf: Vec, + sync: [u8; 16], +} + +impl Header { + /// Returns an iterator over the meta keys in this header + pub fn metadata(&self) -> impl Iterator { + let mut last = 0; + self.meta_offsets.chunks_exact(2).map(move |w| { 
+ let start = last; + last = w[1]; + (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]]) + }) + } + + /// Returns the value for a given metadata key if present + pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> { + self.metadata() + .find_map(|(k, v)| (k == key.as_ref()).then_some(v)) + } + + /// Returns the sync token for this file + pub fn sync(&self) -> [u8; 16] { + self.sync + } + + /// Returns the [`CompressionCodec`] if any + pub fn compression(&self) -> Result, ArrowError> { + let v = self.get(CODEC_METADATA_KEY); + + match v { + None | Some(b"null") => Ok(None), + Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)), + Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)), + Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)), + Some(v) => Err(ArrowError::ParseError(format!( + "Unrecognized compression codec \'{}\'", + String::from_utf8_lossy(v) + ))), + } + } +} + +/// A decoder for [`Header`] +/// +/// The avro file format does not encode the length of the header, and so it +/// is necessary to provide a push-based decoder that can be used with streams +#[derive(Debug)] +pub struct HeaderDecoder { + state: HeaderDecoderState, + vlq_decoder: VLQDecoder, + + /// The end offsets of strings in `meta_buf` + meta_offsets: Vec, + /// The raw binary data of the metadata map + meta_buf: Vec, + + /// The decoded sync marker + sync_marker: [u8; 16], + + /// The number of remaining tuples in the current block + tuples_remaining: usize, + /// The number of bytes remaining in the current string/bytes payload + bytes_remaining: usize, +} + +impl Default for HeaderDecoder { + fn default() -> Self { + Self { + state: HeaderDecoderState::Magic, + meta_offsets: vec![], + meta_buf: vec![], + sync_marker: [0; 16], + vlq_decoder: Default::default(), + tuples_remaining: 0, + bytes_remaining: MAGIC.len(), + } + } +} + +const MAGIC: &[u8; 4] = b"Obj\x01"; + +impl HeaderDecoder { + /// Parse [`Header`] from `buf`, returning the number of bytes read + /// 
+ /// This method can be called multiple times with consecutive chunks of data, allowing + /// integration with chunked IO systems like [`BufRead::fill_buf`] + /// + /// All errors should be considered fatal, and decoding aborted + /// + /// Once the entire [`Header`] has been decoded this method will not read any further + /// input bytes, and the header can be obtained with [`Self::flush`] + /// + /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf + pub fn decode(&mut self, mut buf: &[u8]) -> Result { + let max_read = buf.len(); + while !buf.is_empty() { + match self.state { + HeaderDecoderState::Magic => { + let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..]; + let to_decode = buf.len().min(remaining.len()); + if !buf.starts_with(&remaining[..to_decode]) { + return Err(ArrowError::ParseError("Incorrect avro magic".to_string())); + } + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = HeaderDecoderState::BlockCount; + } + } + HeaderDecoderState::BlockCount => { + if let Some(block_count) = self.vlq_decoder.long(&mut buf) { + match block_count.try_into() { + Ok(0) => { + self.state = HeaderDecoderState::Sync; + self.bytes_remaining = 16; + } + Ok(remaining) => { + self.tuples_remaining = remaining; + self.state = HeaderDecoderState::KeyLen; + } + Err(_) => { + self.tuples_remaining = block_count.unsigned_abs() as _; + self.state = HeaderDecoderState::BlockLen; + } + } + } + } + HeaderDecoderState::BlockLen => { + if self.vlq_decoder.long(&mut buf).is_some() { + self.state = HeaderDecoderState::KeyLen + } + } + HeaderDecoderState::Key => { + let to_read = self.bytes_remaining.min(buf.len()); + self.meta_buf.extend_from_slice(&buf[..to_read]); + self.bytes_remaining -= to_read; + buf = &buf[to_read..]; + if self.bytes_remaining == 0 { + self.meta_offsets.push(self.meta_buf.len()); + self.state = HeaderDecoderState::ValueLen; + } + } + HeaderDecoderState::Value => { + let to_read = 
self.bytes_remaining.min(buf.len()); + self.meta_buf.extend_from_slice(&buf[..to_read]); + self.bytes_remaining -= to_read; + buf = &buf[to_read..]; + if self.bytes_remaining == 0 { + self.meta_offsets.push(self.meta_buf.len()); + + self.tuples_remaining -= 1; + match self.tuples_remaining { + 0 => self.state = HeaderDecoderState::BlockCount, + _ => self.state = HeaderDecoderState::KeyLen, + } + } + } + HeaderDecoderState::KeyLen => { + if let Some(len) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = len as _; + self.state = HeaderDecoderState::Key; + } + } + HeaderDecoderState::ValueLen => { + if let Some(len) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = len as _; + self.state = HeaderDecoderState::Value; + } + } + HeaderDecoderState::Sync => { + let to_decode = buf.len().min(self.bytes_remaining); + let write = &mut self.sync_marker[16 - to_decode..]; + write[..to_decode].copy_from_slice(&buf[..to_decode]); + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = HeaderDecoderState::Finished; + } + } + HeaderDecoderState::Finished => return Ok(max_read - buf.len()), + } + } + Ok(max_read) + } + + /// Flush this decoder returning the parsed [`Header`] if any + pub fn flush(&mut self) -> Option

{ + match self.state { + HeaderDecoderState::Finished => { + self.state = HeaderDecoderState::Magic; + Some(Header { + meta_offsets: std::mem::take(&mut self.meta_offsets), + meta_buf: std::mem::take(&mut self.meta_buf), + sync: self.sync_marker, + }) + } + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::codec::{AvroDataType, AvroField}; + use crate::reader::read_header; + use crate::schema::SCHEMA_METADATA_KEY; + use crate::test_util::arrow_test_data; + use arrow_schema::{DataType, Field, Fields, TimeUnit}; + use std::fs::File; + use std::io::{BufRead, BufReader}; + + #[test] + fn test_header_decode() { + let mut decoder = HeaderDecoder::default(); + for m in MAGIC { + decoder.decode(std::slice::from_ref(m)).unwrap(); + } + + let mut decoder = HeaderDecoder::default(); + assert_eq!(decoder.decode(MAGIC).unwrap(), 4); + + let mut decoder = HeaderDecoder::default(); + decoder.decode(b"Ob").unwrap(); + let err = decoder.decode(b"s").unwrap_err().to_string(); + assert_eq!(err, "Parser error: Incorrect avro magic"); + } + + fn decode_file(file: &str) -> Header { + let file = File::open(file).unwrap(); + read_header(BufReader::with_capacity(100, file)).unwrap() + } + + #[test] + fn test_header() { + let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro")); + let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); + let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#; + 
assert_eq!(schema_json, expected); + let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + let field = AvroField::try_from(&schema).unwrap(); + + assert_eq!( + field.field(), + Field::new( + "topLevelRecord", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, true), + Field::new("bool_col", DataType::Boolean, true), + Field::new("tinyint_col", DataType::Int32, true), + Field::new("smallint_col", DataType::Int32, true), + Field::new("int_col", DataType::Int32, true), + Field::new("bigint_col", DataType::Int64, true), + Field::new("float_col", DataType::Float32, true), + Field::new("double_col", DataType::Float64, true), + Field::new("date_string_col", DataType::Binary, true), + Field::new("string_col", DataType::Binary, true), + Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + true + ), + ])), + false + ) + ); + + assert_eq!( + u128::from_le_bytes(header.sync()), + 226966037233754408753420635932530907102 + ); + + let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro")); + + let meta: Vec<_> = header + .metadata() + .map(|(k, _)| std::str::from_utf8(k).unwrap()) + .collect(); + + assert_eq!( + meta, + &["avro.schema", "org.apache.spark.version", "avro.codec"] + ); + + let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); + let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#; + assert_eq!(schema_json, expected); + let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + assert_eq!( + u128::from_le_bytes(header.sync()), + 325166208089902833952788552656412487328 + ); + } +} diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs new file mode 100644 index 000000000000..0151db7f855a --- /dev/null +++ b/arrow-avro/src/reader/mod.rs @@ -0,0 
+1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Read Avro data to Arrow + +use crate::reader::block::{Block, BlockDecoder}; +use crate::reader::header::{Header, HeaderDecoder}; +use arrow_schema::ArrowError; +use std::io::BufRead; + +mod header; + +mod block; + +mod vlq; + +/// Read a [`Header`] from the provided [`BufRead`] +fn read_header(mut reader: R) -> Result { + let mut decoder = HeaderDecoder::default(); + loop { + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; + } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; + } + } + + decoder + .flush() + .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) +} + +/// Return an iterator of [`Block`] from the provided [`BufRead`] +fn read_blocks(mut reader: R) -> impl Iterator> { + let mut decoder = BlockDecoder::default(); + + let mut try_next = move || { + loop { + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; + } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; + } + } + Ok(decoder.flush()) + }; + std::iter::from_fn(move || try_next().transpose()) +} + +#[cfg(test)] 
+mod test { + use crate::compression::CompressionCodec; + use crate::reader::{read_blocks, read_header}; + use crate::test_util::arrow_test_data; + use std::fs::File; + use std::io::BufReader; + + #[test] + fn test_mux() { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_nulls_plain.avro", + ]; + + for file in files { + println!("file: {file}"); + let file = File::open(arrow_test_data(file)).unwrap(); + let mut reader = BufReader::new(file); + let header = read_header(&mut reader).unwrap(); + let compression = header.compression().unwrap(); + println!("compression: {compression:?}"); + for result in read_blocks(reader) { + let block = result.unwrap(); + assert_eq!(block.sync, header.sync()); + if let Some(c) = compression { + c.decompress(&block.data).unwrap(); + } + } + } + } +} diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs new file mode 100644 index 000000000000..80f1c60eec7d --- /dev/null +++ b/arrow-avro/src/reader/vlq.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +/// Decoder for zig-zag encoded variable length (VLW) integers +/// +/// See also: +/// +/// +#[derive(Debug, Default)] +pub struct VLQDecoder { + /// Scratch space for decoding VLQ integers + in_progress: u64, + shift: u32, +} + +impl VLQDecoder { + /// Decode a signed long from `buf` + pub fn long(&mut self, buf: &mut &[u8]) -> Option { + while let Some(byte) = buf.first().copied() { + *buf = &buf[1..]; + self.in_progress |= ((byte & 0x7F) as u64) << self.shift; + self.shift += 7; + if byte & 0x80 == 0 { + let val = self.in_progress; + self.in_progress = 0; + self.shift = 0; + return Some((val >> 1) as i64 ^ -((val & 1) as i64)); + } + } + None + } +} diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs new file mode 100644 index 000000000000..6707f8137c9b --- /dev/null +++ b/arrow-avro/src/schema.rs @@ -0,0 +1,512 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// The metadata key used for storing the JSON encoded [`Schema`] +pub const SCHEMA_METADATA_KEY: &str = "avro.schema"; + +/// Either a [`PrimitiveType`] or a reference to a previously defined named type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum TypeName<'a> { + Primitive(PrimitiveType), + Ref(&'a str), +} + +/// A primitive type +/// +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum PrimitiveType { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, +} + +/// Additional attributes within a [`Schema`] +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Attributes<'a> { + /// A logical type name + /// + /// + #[serde(default)] + pub logical_type: Option<&'a str>, + + /// Additional JSON attributes + #[serde(flatten)] + pub additional: HashMap<&'a str, serde_json::Value>, +} + +impl<'a> Attributes<'a> { + /// Returns the field metadata for this [`Attributes`] + pub(crate) fn field_metadata(&self) -> HashMap { + self.additional + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() + } +} + +/// A type definition that is not a variant of [`ComplexType`] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Type<'a> { + #[serde(borrow)] + pub r#type: TypeName<'a>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// An Avro schema +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Schema<'a> { + #[serde(borrow)] + TypeName(TypeName<'a>), + #[serde(borrow)] + Union(Vec>), + #[serde(borrow)] + Complex(ComplexType<'a>), + #[serde(borrow)] + Type(Type<'a>), +} + +/// A complex type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, 
Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ComplexType<'a> { + #[serde(borrow)] + Record(Record<'a>), + #[serde(borrow)] + Enum(Enum<'a>), + #[serde(borrow)] + Array(Array<'a>), + #[serde(borrow)] + Map(Map<'a>), + #[serde(borrow)] + Fixed(Fixed<'a>), +} + +/// A record +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Record<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + #[serde(borrow)] + pub fields: Vec>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A field within a [`Record`] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Field<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow)] + pub r#type: Schema<'a>, + #[serde(borrow, default)] + pub default: Option<&'a str>, +} + +/// An enumeration +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Enum<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + #[serde(borrow)] + pub symbols: Vec<&'a str>, + #[serde(borrow, default)] + pub default: Option<&'a str>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// An array +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Array<'a> { + #[serde(borrow)] + pub items: Box>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A map +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Map<'a> { + #[serde(borrow)] + pub values: Box>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A fixed length binary array +/// +/// 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Fixed<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + pub size: usize, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::codec::{AvroDataType, AvroField}; + use arrow_schema::{DataType, Fields, TimeUnit}; + use serde_json::json; + + #[test] + fn test_deserialize() { + let t: Schema = serde_json::from_str("\"string\"").unwrap(); + assert_eq!( + t, + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)) + ); + + let t: Schema = serde_json::from_str("[\"int\", \"null\"]").unwrap(); + assert_eq!( + t, + Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]) + ); + + let t: Type = serde_json::from_str( + r#"{ + "type":"long", + "logicalType":"timestamp-micros" + }"#, + ) + .unwrap(); + + let timestamp = Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type: Some("timestamp-micros"), + additional: Default::default(), + }, + }; + + assert_eq!(t, timestamp); + + let t: ComplexType = serde_json::from_str( + r#"{ + "type":"fixed", + "name":"fixed", + "namespace":"topLevelRecord.value", + "size":11, + "logicalType":"decimal", + "precision":25, + "scale":2 + }"#, + ) + .unwrap(); + + let decimal = ComplexType::Fixed(Fixed { + name: "fixed", + namespace: Some("topLevelRecord.value"), + aliases: vec![], + size: 11, + attributes: Attributes { + logical_type: Some("decimal"), + additional: vec![("precision", json!(25)), ("scale", json!(2))] + .into_iter() + .collect(), + }, + }); + + assert_eq!(t, decimal); + + let schema: Schema = serde_json::from_str( + r#"{ + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"value", + "type":[ + { + "type":"fixed", + 
"name":"fixed", + "namespace":"topLevelRecord.value", + "size":11, + "logicalType":"decimal", + "precision":25, + "scale":2 + }, + "null" + ] + } + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "topLevelRecord", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "value", + doc: None, + r#type: Schema::Union(vec![ + Schema::Complex(decimal), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + },], + attributes: Default::default(), + })) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "LongList", + "aliases": ["LinkedLongs"], + "fields" : [ + {"name": "value", "type": "long"}, + {"name": "next", "type": ["null", "LongList"]} + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "LongList", + namespace: None, + doc: None, + aliases: vec!["LinkedLongs"], + fields: vec![ + Field { + name: "value", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + }, + Field { + name: "next", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Ref("LongList")), + ]), + default: None, + } + ], + attributes: Attributes::default(), + })) + ); + + // Recursive schema are not supported + let err = AvroField::try_from(&schema).unwrap_err().to_string(); + assert_eq!(err, "Parser error: Failed to resolve .LongList"); + + let schema: Schema = serde_json::from_str( + r#"{ + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"id", + "type":[ + "int", + "null" + ] + }, + { + "name":"timestamp_col", + "type":[ + { + "type":"long", + "logicalType":"timestamp-micros" + }, + "null" + ] + } + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "topLevelRecord", + namespace: None, + 
doc: None, + aliases: vec![], + fields: vec![ + Field { + name: "id", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + }, + Field { + name: "timestamp_col", + doc: None, + r#type: Schema::Union(vec![ + Schema::Type(timestamp), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + } + ], + attributes: Default::default(), + })) + ); + let codec = AvroField::try_from(&schema).unwrap(); + assert_eq!( + codec.field(), + arrow_schema::Field::new( + "topLevelRecord", + DataType::Struct(Fields::from(vec![ + arrow_schema::Field::new("id", DataType::Int32, true), + arrow_schema::Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + true + ), + ])), + false + ) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc", + "fields": [ + {"name": "clientHash", "type": {"type": "fixed", "name": "MD5", "size": 16}}, + {"name": "clientProtocol", "type": ["null", "string"]}, + {"name": "serverHash", "type": "MD5"}, + {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]} + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "HandshakeRequest", + namespace: Some("org.apache.avro.ipc"), + doc: None, + aliases: vec![], + fields: vec![ + Field { + name: "clientHash", + doc: None, + r#type: Schema::Complex(ComplexType::Fixed(Fixed { + name: "MD5", + namespace: None, + aliases: vec![], + size: 16, + attributes: Default::default(), + })), + default: None, + }, + Field { + name: "clientProtocol", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]), + default: None, + }, + Field { + name: 
"serverHash", + doc: None, + r#type: Schema::TypeName(TypeName::Ref("MD5")), + default: None, + }, + Field { + name: "meta", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::Complex(ComplexType::Map(Map { + values: Box::new(Schema::TypeName(TypeName::Primitive( + PrimitiveType::Bytes + ))), + attributes: Default::default(), + })), + ]), + default: None, + } + ], + attributes: Default::default(), + })) + ); + } +} diff --git a/arrow-buffer/Cargo.toml b/arrow-buffer/Cargo.toml new file mode 100644 index 000000000000..68bfe8ddf732 --- /dev/null +++ b/arrow-buffer/Cargo.toml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "arrow-buffer" +version = { workspace = true } +description = "Buffer abstractions for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_buffer" +path = "src/lib.rs" +bench = false + +[dependencies] +bytes = { version = "1.4" } +num = { version = "0.4", default-features = false, features = ["std"] } +half = { version = "2.1", default-features = false } + +[dev-dependencies] +criterion = { version = "0.5", default-features = false } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } + +[build-dependencies] + +[[bench]] +name = "bit_mask" +harness = false + +[[bench]] +name = "i256" +harness = false + +[[bench]] +name = "offset" +harness = false diff --git a/arrow-buffer/benches/bit_mask.rs b/arrow-buffer/benches/bit_mask.rs new file mode 100644 index 000000000000..6907e336a418 --- /dev/null +++ b/arrow-buffer/benches/bit_mask.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_buffer::bit_mask::set_bits; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("bit_mask"); + + for offset_write in [0, 5] { + for offset_read in [0, 5] { + for len in [1, 17, 65] { + for datum in [0u8, 0xADu8] { + let x = (offset_write, offset_read, len, datum); + group.bench_with_input( + BenchmarkId::new( + "set_bits", + format!( + "offset_write_{}_offset_read_{}_len_{}_datum_{}", + x.0, x.1, x.2, x.3 + ), + ), + &x, + |b, &x| { + b.iter(|| { + set_bits( + black_box(&mut [0u8; 9]), + black_box(&[x.3; 9]), + black_box(x.0), + black_box(x.1), + black_box(x.2), + ) + }); + }, + ); + } + } + } + } + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-buffer/benches/i256.rs b/arrow-buffer/benches/i256.rs new file mode 100644 index 000000000000..ebb45e793bd0 --- /dev/null +++ b/arrow-buffer/benches/i256.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_buffer::i256; +use criterion::*; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::str::FromStr; + +const SIZE: usize = 1024; + +fn criterion_benchmark(c: &mut Criterion) { + let numbers = vec![ + i256::ZERO, + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(1233456789), + i256::from_i128(-1233456789), + i256::from_i128(i128::MAX), + i256::from_i128(i128::MIN), + i256::MIN, + i256::MAX, + ]; + + for number in numbers { + let t = black_box(number.to_string()); + c.bench_function(&format!("i256_parse({t})"), |b| { + b.iter(|| i256::from_str(&t).unwrap()); + }); + } + + let mut rng = StdRng::seed_from_u64(42); + + let numerators: Vec<_> = (0..SIZE) + .map(|_| { + let high = rng.gen_range(1000..i128::MAX); + let low = rng.gen(); + i256::from_parts(low, high) + }) + .collect(); + + let divisors: Vec<_> = numerators + .iter() + .map(|n| { + let quotient = rng.gen_range(1..100_i32); + n.wrapping_div(i256::from(quotient)) + }) + .collect(); + + c.bench_function("i256_div_rem small quotient", |b| { + b.iter(|| { + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); + } + }); + }); + + let divisors: Vec<_> = (0..SIZE) + .map(|_| i256::from(rng.gen_range(1..100_i32))) + .collect(); + + c.bench_function("i256_div_rem small divisor", |b| { + b.iter(|| { + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); + } + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-buffer/benches/offset.rs b/arrow-buffer/benches/offset.rs new file mode 100644 index 000000000000..1aea5024fbd1 --- /dev/null +++ b/arrow-buffer/benches/offset.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::{OffsetBuffer, OffsetBufferBuilder}; +use criterion::*; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +const SIZE: usize = 1024; + +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(42); + let lengths: Vec = black_box((0..SIZE).map(|_| rng.gen_range(0..40)).collect()); + + c.bench_function("OffsetBuffer::from_lengths", |b| { + b.iter(|| OffsetBuffer::::from_lengths(lengths.iter().copied())); + }); + + c.bench_function("OffsetBufferBuilder::push_length", |b| { + b.iter(|| { + let mut builder = OffsetBufferBuilder::::new(lengths.len()); + lengths.iter().for_each(|x| builder.push_length(*x)); + builder.finish() + }); + }); + + let offsets = OffsetBuffer::::from_lengths(lengths.iter().copied()).into_inner(); + + c.bench_function("OffsetBuffer::new", |b| { + b.iter(|| OffsetBuffer::new(black_box(offsets.clone()))); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow/src/alloc/alignment.rs b/arrow-buffer/src/alloc/alignment.rs similarity index 91% rename from arrow/src/alloc/alignment.rs rename to arrow-buffer/src/alloc/alignment.rs index 1bd15c54b990..de8a491c824a 100644 --- a/arrow/src/alloc/alignment.rs +++ b/arrow-buffer/src/alloc/alignment.rs @@ -18,7 +18,7 @@ // NOTE: Below code is written for spatial/temporal prefetcher optimizations. 
Memory allocation // should align well with usage pattern of cache access and block sizes on layers of storage levels from // registers to non-volatile memory. These alignments are all cache aware alignments incorporated -// from [cuneiform](https://crates.io/crates/cuneiform) crate. This approach mimicks Intel TBB's +// from [cuneiform](https://crates.io/crates/cuneiform) crate. This approach mimics Intel TBB's // cache_aligned_allocator which exploits cache locality and minimizes prefetch signals // resulting in less round trip time between the layers of storage. // For further info: https://software.intel.com/en-us/node/506094 @@ -80,15 +80,6 @@ pub const ALIGNMENT: usize = 1 << 5; #[cfg(target_arch = "sparc64")] pub const ALIGNMENT: usize = 1 << 6; -// On ARM cache line sizes are fixed. both v6 and v7. -// Need to add board specific or platform specific things later. -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "thumbv6")] -pub const ALIGNMENT: usize = 1 << 5; -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "thumbv7")] -pub const ALIGNMENT: usize = 1 << 5; - // Operating Systems cache size determines this. // Currently no way to determine this without runtime inference. /// Cache and allocation multiple alignment size @@ -107,9 +98,6 @@ pub const ALIGNMENT: usize = 1 << 5; // If you have smaller data with less padded functionality then use 32 with force option. 
// - https://devtalk.nvidia.com/default/topic/803600/variable-cache-line-width-/ /// Cache and allocation multiple alignment size -#[cfg(target_arch = "nvptx")] -pub const ALIGNMENT: usize = 1 << 7; -/// Cache and allocation multiple alignment size #[cfg(target_arch = "nvptx64")] pub const ALIGNMENT: usize = 1 << 7; @@ -117,3 +105,7 @@ pub const ALIGNMENT: usize = 1 << 7; /// Cache and allocation multiple alignment size #[cfg(target_arch = "aarch64")] pub const ALIGNMENT: usize = 1 << 6; + +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "loongarch64")] +pub const ALIGNMENT: usize = 1 << 6; diff --git a/arrow-buffer/src/alloc/mod.rs b/arrow-buffer/src/alloc/mod.rs new file mode 100644 index 000000000000..d7108d2969bb --- /dev/null +++ b/arrow-buffer/src/alloc/mod.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the low-level [`Allocation`] API for shared memory regions + +use std::alloc::Layout; +use std::fmt::{Debug, Formatter}; +use std::panic::RefUnwindSafe; +use std::sync::Arc; + +mod alignment; + +pub use alignment::ALIGNMENT; + +/// The owner of an allocation. 
/// The trait implementation is responsible for dropping the allocations once no more references exist.
///
/// It is a marker trait with no methods: the only behavior an implementor
/// contributes is its own `Drop`.
pub trait Allocation: RefUnwindSafe + Send + Sync {}

// Blanket implementation: every type meeting the supertrait bounds can act as
// the owner of an allocation.
impl<T: RefUnwindSafe + Send + Sync> Allocation for T {}

/// Mode of deallocating memory regions
pub(crate) enum Deallocation {
    /// An allocation using [`std::alloc`]
    Standard(Layout),
    /// An allocation from an external source like the FFI interface
    /// Deallocation will happen on `Allocation::drop`
    /// The size of the allocation is tracked here separately only
    /// for memory usage reporting via `Array::get_buffer_memory_size`
    Custom(Arc<dyn Allocation>, usize),
}

impl Debug for Deallocation {
    /// Formats the variant name plus its layout (standard) or tracked size
    /// (custom); the custom owner itself is not required to be `Debug` and is
    /// therefore not printed.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        match self {
            Deallocation::Standard(layout) => {
                write!(f, "Deallocation::Standard {layout:?}")
            }
            Deallocation::Custom(_, size) => {
                write!(f, "Deallocation::Custom {{ capacity: {size} }}")
            }
        }
    }
}

#[cfg(test)]
mod tests {
    // Relative import so the test does not depend on where this module is mounted.
    use super::Deallocation;

    /// Guards against the enum growing beyond three machine words.
    #[test]
    fn test_size_of_deallocation() {
        assert_eq!(
            std::mem::size_of::<Deallocation>(),
            3 * std::mem::size_of::<usize>()
        );
    }
}

// ---- patch metadata carried over verbatim from the original line ----
// diff --git a/arrow-buffer/src/arith.rs b/arrow-buffer/src/arith.rs
// new file mode 100644
// index 000000000000..a576b2677131
// --- /dev/null
// +++ b/arrow-buffer/src/arith.rs
// @@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.
// See the License for the
// specific language governing permissions and limitations
// under the License.

/// Derives `std::ops::$t` for `$ty` calling `$wrapping` or `$checked` variants
/// based on if debug_assertions enabled
///
/// Arguments, in order:
/// - `$ty`: the (`Copy`) type to implement the operator for
/// - `$t` / `$t_assign`: the operator traits, e.g. `Add` / `AddAssign`
/// - `$op` / `$op_assign`: the trait methods, e.g. `add` / `add_assign`
/// - `$wrapping`: inherent method `fn(self, Self) -> Self` used in release builds
/// - `$checked`: inherent method `fn(self, Self) -> Option<Self>` used when
///   `debug_assertions` is on; `None` panics with `"<ty> overflow"`
///
/// Implementations are generated for every owned/borrowed operand
/// combination, plus the compound-assignment trait.
macro_rules! derive_arith {
    ($ty:ty, $t:ident, $t_assign:ident, $op:ident, $op_assign:ident, $wrapping:ident, $checked:ident) => {
        impl std::ops::$t for $ty {
            type Output = $ty;

            #[cfg(debug_assertions)]
            fn $op(self, rhs: Self) -> Self::Output {
                self.$checked(rhs)
                    .expect(concat!(stringify!($ty), " overflow"))
            }

            #[cfg(not(debug_assertions))]
            fn $op(self, rhs: Self) -> Self::Output {
                self.$wrapping(rhs)
            }
        }

        impl std::ops::$t_assign for $ty {
            fn $op_assign(&mut self, rhs: Self) {
                // Delegate to the by-value operator so the checked/wrapping
                // selection lives in exactly one place.
                *self = std::ops::$t::$op(*self, rhs);
            }
        }

        impl<'a> std::ops::$t<$ty> for &'a $ty {
            type Output = $ty;

            fn $op(self, rhs: $ty) -> Self::Output {
                (*self).$op(rhs)
            }
        }

        impl<'a> std::ops::$t<&'a $ty> for $ty {
            type Output = $ty;

            fn $op(self, rhs: &'a $ty) -> Self::Output {
                self.$op(*rhs)
            }
        }

        impl<'a, 'b> std::ops::$t<&'b $ty> for &'a $ty {
            type Output = $ty;

            fn $op(self, rhs: &'b $ty) -> Self::Output {
                (*self).$op(*rhs)
            }
        }
    };
}

pub(crate) use derive_arith;

// ---- patch metadata carried over verbatim from the original line ----
// diff --git a/arrow-buffer/src/bigint/div.rs b/arrow-buffer/src/bigint/div.rs
// new file mode 100644
// index 000000000000..8a75dad0ffd8
// --- /dev/null
// +++ b/arrow-buffer/src/bigint/div.rs
// @@ -0,0 +1,302 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.
// The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! N-digit division
//!
//! Implementation heavily inspired by [uint]
//!
//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844

/// Unsigned, little-endian, n-digit division with remainder
///
/// Returns `(quotient, remainder)`. Divisors that fit in a single 64-bit
/// digit take a simple word-by-word path; anything larger uses Knuth's
/// Algorithm D.
///
/// # Panics
///
/// Panics if divisor is zero
pub fn div_rem<const N: usize>(numerator: &[u64; N], divisor: &[u64; N]) -> ([u64; N], [u64; N]) {
    let numerator_bits = bits(numerator);
    let divisor_bits = bits(divisor);
    assert_ne!(divisor_bits, 0, "division by zero");

    // A numerator with fewer significant bits is necessarily smaller
    if numerator_bits < divisor_bits {
        return ([0; N], *numerator);
    }

    if divisor_bits <= 64 {
        return div_rem_small(numerator, divisor[0]);
    }

    let numerator_words = (numerator_bits + 63) / 64;
    let divisor_words = (divisor_bits + 63) / 64;
    let n = divisor_words;
    let m = numerator_words - divisor_words;

    div_rem_knuth(numerator, divisor, n, m)
}

/// Return the least number of bits needed to represent the number
///
/// `arr` is little-endian; an all-zero input yields `0`.
fn bits(arr: &[u64]) -> usize {
    for (idx, v) in arr.iter().enumerate().rev() {
        if *v > 0 {
            return 64 - v.leading_zeros() as usize + 64 * idx;
        }
    }
    0
}

/// Division of numerator by a u64 divisor
///
/// Processes digits from most to least significant, feeding each remainder
/// in as the high word of the next step, so `rem < divisor` always holds when
/// [`div_rem_word`] is called.
fn div_rem_small<const N: usize>(numerator: &[u64; N], divisor: u64) -> ([u64; N], [u64; N]) {
    let mut rem = 0u64;
    let mut numerator = *numerator;
    numerator.iter_mut().rev().for_each(|d| {
        let (q, r) = div_rem_word(rem, *d, divisor);
        *d = q;
        rem = r;
    });

    let mut rem_padded = [0; N];
    rem_padded[0] = rem;
    (numerator, rem_padded)
}

/// Use Knuth Algorithm D to compute `numerator / divisor` returning the
/// quotient and remainder
///
/// `n` is the number of non-zero 64-bit words in `divisor`
/// `m` is the number of non-zero 64-bit words present in `numerator` beyond `divisor`, and
/// therefore the number of words in the quotient
///
/// The body indexes `divisor[n - 2]`, so `n` must be at least 2; [`div_rem`]
/// guarantees this by routing single-word divisors to [`div_rem_small`].
///
/// A good explanation of the algorithm can be found [here](https://ridiculousfish.com/blog/posts/labor-of-division-episode-iv.html)
fn div_rem_knuth<const N: usize>(
    numerator: &[u64; N],
    divisor: &[u64; N],
    n: usize,
    m: usize,
) -> ([u64; N], [u64; N]) {
    assert!(n + m <= N);

    // The algorithm works by incrementally generating guesses `q_hat`, for the next digit
    // of the quotient, starting from the most significant digit.
    //
    // This relies on the property that for any `q_hat` where
    //
    //      (q_hat << (j * 64)) * divisor <= numerator`
    //
    // We can set
    //
    //      q += q_hat << (j * 64)
    //      numerator -= (q_hat << (j * 64)) * divisor
    //
    // And then iterate until `numerator < divisor`

    // We normalize the divisor so that the highest bit in the highest digit of the
    // divisor is set, this ensures our initial guess of `q_hat` is at most 2 off from
    // the correct value for q[j]
    let shift = divisor[n - 1].leading_zeros();
    // As the shift is computed based on leading zeros, don't need to perform full_shl
    let divisor = shl_word(divisor, shift);
    // numerator may have fewer leading zeros than divisor, so must add another digit
    let mut numerator = full_shl(numerator, shift);

    // The two most significant digits of the divisor
    let b0 = divisor[n - 1];
    let b1 = divisor[n - 2];

    let mut q = [0; N];

    for j in (0..=m).rev() {
        let a0 = numerator[j + n];
        let a1 = numerator[j + n - 1];

        let mut q_hat = if a0 < b0 {
            // The first estimate is [a1, a0] / b0, it may be too large by at most 2
            let (mut q_hat, mut r_hat) = div_rem_word(a0, a1, b0);

            // r_hat = [a1, a0] - q_hat * b0
            //
            // Now we want to compute a more precise estimate [a2,a1,a0] / [b1,b0]
            // which can only be less or equal to the current q_hat
            //
            // q_hat is too large if:
            // [a2,a1,a0] < q_hat * [b1,b0]
            // [a2,r_hat] < q_hat * b1
            let a2 = numerator[j + n - 2];
            loop {
                let r = u128::from(q_hat) * u128::from(b1);
                let (lo, hi) = (r as u64, (r >> 64) as u64);
                if (hi, lo) <= (r_hat, a2) {
                    break;
                }

                q_hat -= 1;
                let (new_r_hat, overflow) = r_hat.overflowing_add(b0);
                r_hat = new_r_hat;

                if overflow {
                    break;
                }
            }
            q_hat
        } else {
            u64::MAX
        };

        // q_hat is now either the correct quotient digit, or in rare cases 1 too large

        // Compute numerator -= (q_hat * divisor) << (j * 64)
        let q_hat_v = full_mul_u64(&divisor, q_hat);
        let c = sub_assign(&mut numerator[j..], &q_hat_v[..n + 1]);

        // If underflow, q_hat was too large by 1
        if c {
            // Reduce q_hat by 1
            q_hat -= 1;

            // Add back one multiple of divisor
            let c = add_assign(&mut numerator[j..], &divisor[..n]);
            numerator[j + n] = numerator[j + n].wrapping_add(u64::from(c));
        }

        // q_hat is the correct value for q[j]
        q[j] = q_hat;
    }

    // The remainder is what is left in numerator, with the initial normalization shl reversed
    let remainder = full_shr(&numerator, shift);
    (q, remainder)
}

/// Perform narrowing division of a u128 by a u64 divisor, returning the quotient and remainder
///
/// This method may trap or panic if hi >= divisor, i.e. the quotient would not fit
/// into a 64-bit integer
fn div_rem_word(hi: u64, lo: u64, divisor: u64) -> (u64, u64) {
    debug_assert!(hi < divisor);
    debug_assert_ne!(divisor, 0);

    // LLVM fails to use the div instruction as it is not able to prove
    // that hi < divisor, and therefore the result will fit into 64-bits
    #[cfg(all(target_arch = "x86_64", not(miri)))]
    // SAFETY: the instruction only reads and writes the named registers
    // (`pure, nomem, nostack`). Callers uphold `hi < divisor` (checked above in
    // debug builds), which is the documented condition for it not to trap.
    unsafe {
        let mut quot = lo;
        let mut rem = hi;
        std::arch::asm!(
            "div {divisor}",
            divisor = in(reg) divisor,
            inout("rax") quot,
            inout("rdx") rem,
            options(pure, nomem, nostack)
        );
        (quot, rem)
    }
    #[cfg(any(not(target_arch = "x86_64"), miri))]
    {
        let x = (u128::from(hi) << 64) + u128::from(lo);
        let y = u128::from(divisor);
        ((x / y) as u64, (x % y) as u64)
    }
}

/// Perform `a += b`
///
/// Returns the carry out of the last digit processed.
fn add_assign(a: &mut [u64], b: &[u64]) -> bool {
    binop_slice(a, b, u64::overflowing_add)
}

/// Perform `a -= b`
///
/// Returns the borrow out of the last digit processed.
fn sub_assign(a: &mut [u64], b: &[u64]) -> bool {
    binop_slice(a, b, u64::overflowing_sub)
}

/// Converts an overflowing binary operation on scalars to one on slices
///
/// Digits are combined pairwise up to the shorter of the two slices; the
/// carry/borrow from each digit is folded into the right-hand operand of the
/// next one.
fn binop_slice(a: &mut [u64], b: &[u64], binop: impl Fn(u64, u64) -> (u64, bool) + Copy) -> bool {
    let mut c = false;
    a.iter_mut().zip(b.iter()).for_each(|(x, y)| {
        let (res1, overflow1) = y.overflowing_add(u64::from(c));
        let (res2, overflow2) = binop(*x, res1);
        *x = res2;
        c = overflow1 || overflow2;
    });
    c
}

/// Widening multiplication of an N-digit array with a u64
fn full_mul_u64<const N: usize>(a: &[u64; N], b: u64) -> ArrayPlusOne<u64, N> {
    let mut carry = 0;
    let mut out = [0; N];
    out.iter_mut().zip(a).for_each(|(o, v)| {
        let r = *v as u128 * b as u128 + carry as u128;
        *o = r as u64;
        carry = (r >> 64) as u64;
    });
    ArrayPlusOne(out, carry)
}

/// Left shift of an N-digit array by at most 63 bits
///
/// Bits shifted out of the top digit are discarded.
fn shl_word<const N: usize>(v: &[u64; N], shift: u32) -> [u64; N] {
    full_shl(v, shift).0
}

/// Widening left shift of an N-digit array by at most 63 bits
fn full_shl<const N: usize>(v: &[u64; N], shift: u32) -> ArrayPlusOne<u64, N> {
    debug_assert!(shift < 64);
    // Handled separately: `x >> (64 - 0)` would be an overflowing shift
    if shift == 0 {
        return ArrayPlusOne(*v, 0);
    }
    let mut out = [0u64; N];
    out[0] = v[0] << shift;
    for i in 1..N {
        out[i] = (v[i - 1] >> (64 - shift)) | (v[i] << shift)
    }
    let carry = v[N - 1] >> (64 - shift);
    ArrayPlusOne(out, carry)
}

/// Narrowing right shift of an (N+1)-digit array by at most 63 bits
fn full_shr<const N: usize>(a: &ArrayPlusOne<u64, N>, shift: u32) -> [u64; N] {
    debug_assert!(shift < 64);
    // Handled separately: `x << (64 - 0)` would be an overflowing shift
    if shift == 0 {
        return a.0;
    }
    let mut out = [0; N];
    for i in 0..N - 1 {
        out[i] = (a[i] >> shift) | (a[i + 1] << (64 - shift))
    }
    out[N - 1] = a[N - 1] >> shift;
    out
}

/// An array of N + 1 elements
///
/// This is a hack around lack of support for const arithmetic
#[repr(C)]
struct ArrayPlusOne<T, const N: usize>([T; N], T);

impl<T, const N: usize> std::ops::Deref for ArrayPlusOne<T, N> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        let x = self as *const Self;
        // SAFETY: `#[repr(C)]` places the `[T; N]` field first and the extra `T`
        // directly after it (same element type, hence no padding), so `self`
        // points at `N + 1` contiguous, initialized `T`s.
        unsafe { std::slice::from_raw_parts(x as *const T, N + 1) }
    }
}

impl<T, const N: usize> std::ops::DerefMut for ArrayPlusOne<T, N> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let x = self as *mut Self;
        // SAFETY: same layout argument as `deref`; `&mut self` guarantees
        // exclusive access for the lifetime of the returned slice.
        unsafe { std::slice::from_raw_parts_mut(x as *mut T, N + 1) }
    }
}

// ---- patch metadata carried over verbatim from the original line ----
// diff --git a/arrow-buffer/src/bigint/mod.rs b/arrow-buffer/src/bigint/mod.rs
// new file mode 100644
// index 000000000000..f5fab75dc5ef
// --- /dev/null
// +++ b/arrow-buffer/src/bigint/mod.rs
// @@ -0,0 +1,1267 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arith::derive_arith; +use crate::bigint::div::div_rem; +use num::cast::AsPrimitive; +use num::{BigInt, FromPrimitive, ToPrimitive}; +use std::cmp::Ordering; +use std::num::ParseIntError; +use std::ops::{BitAnd, BitOr, BitXor, Neg, Shl, Shr}; +use std::str::FromStr; + +mod div; + +/// An opaque error similar to [`std::num::ParseIntError`] +#[derive(Debug)] +pub struct ParseI256Error {} + +impl From for ParseI256Error { + fn from(_: ParseIntError) -> Self { + Self {} + } +} + +impl std::fmt::Display for ParseI256Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to parse as i256") + } +} +impl std::error::Error for ParseI256Error {} + +/// Error returned by i256::DivRem +enum DivRemError { + /// Division by zero + DivideByZero, + /// Division overflow + DivideOverflow, +} + +/// A signed 256-bit integer +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, Default, Eq, PartialEq, Hash)] +#[repr(C)] +pub struct i256 { + low: u128, + high: i128, +} + +impl std::fmt::Debug for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + +impl std::fmt::Display for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", BigInt::from_signed_bytes_le(&self.to_le_bytes())) + } +} + +impl FromStr for i256 { + type Err = ParseI256Error; + + fn from_str(s: &str) -> Result { + // i128 can store up to 38 decimal digits + if s.len() <= 38 { + return Ok(Self::from_i128(i128::from_str(s)?)); + } + + let (negative, s) 
= match s.as_bytes()[0] { + b'-' => (true, &s[1..]), + b'+' => (false, &s[1..]), + _ => (false, s), + }; + + // Trim leading 0s + let s = s.trim_start_matches('0'); + if s.is_empty() { + return Ok(i256::ZERO); + } + + if !s.as_bytes()[0].is_ascii_digit() { + // Ensures no duplicate sign + return Err(ParseI256Error {}); + } + + parse_impl(s, negative) + } +} + +impl From for i256 { + fn from(value: i8) -> Self { + Self::from_i128(value.into()) + } +} + +impl From for i256 { + fn from(value: i16) -> Self { + Self::from_i128(value.into()) + } +} + +impl From for i256 { + fn from(value: i32) -> Self { + Self::from_i128(value.into()) + } +} + +impl From for i256 { + fn from(value: i64) -> Self { + Self::from_i128(value.into()) + } +} + +/// Parse `s` with any sign and leading 0s removed +fn parse_impl(s: &str, negative: bool) -> Result { + if s.len() <= 38 { + let low = i128::from_str(s)?; + return Ok(match negative { + true => i256::from_parts(low.neg() as _, -1), + false => i256::from_parts(low as _, 0), + }); + } + + let split = s.len() - 38; + if !s.as_bytes()[split].is_ascii_digit() { + // Ensures not splitting codepoint and no sign + return Err(ParseI256Error {}); + } + let (hs, ls) = s.split_at(split); + + let mut low = i128::from_str(ls)?; + let high = parse_impl(hs, negative)?; + + if negative { + low = -low; + } + + let low = i256::from_i128(low); + + high.checked_mul(i256::from_i128(10_i128.pow(38))) + .and_then(|high| high.checked_add(low)) + .ok_or(ParseI256Error {}) +} + +impl PartialOrd for i256 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for i256 { + fn cmp(&self, other: &Self) -> Ordering { + // This is 25x faster than using a variable length encoding such + // as BigInt as it avoids allocation and branching + self.high.cmp(&other.high).then(self.low.cmp(&other.low)) + } +} + +impl i256 { + /// The additive identity for this integer type, i.e. `0`. 
+ pub const ZERO: Self = i256 { low: 0, high: 0 }; + + /// The multiplicative identity for this integer type, i.e. `1`. + pub const ONE: Self = i256 { low: 1, high: 0 }; + + /// The multiplicative inverse for this integer type, i.e. `-1`. + pub const MINUS_ONE: Self = i256 { + low: u128::MAX, + high: -1, + }; + + /// The maximum value that can be represented by this integer type + pub const MAX: Self = i256 { + low: u128::MAX, + high: i128::MAX, + }; + + /// The minimum value that can be represented by this integer type + pub const MIN: Self = i256 { + low: u128::MIN, + high: i128::MIN, + }; + + /// Create an integer value from its representation as a byte array in little-endian. + #[inline] + pub const fn from_le_bytes(b: [u8; 32]) -> Self { + let (low, high) = split_array(b); + Self { + high: i128::from_le_bytes(high), + low: u128::from_le_bytes(low), + } + } + + /// Create an integer value from its representation as a byte array in big-endian. + #[inline] + pub const fn from_be_bytes(b: [u8; 32]) -> Self { + let (high, low) = split_array(b); + Self { + high: i128::from_be_bytes(high), + low: u128::from_be_bytes(low), + } + } + + /// Create an `i256` value from a 128-bit value. + pub const fn from_i128(v: i128) -> Self { + Self::from_parts(v as u128, v >> 127) + } + + /// Create an integer value from its representation as string. + #[inline] + pub fn from_string(value_str: &str) -> Option { + value_str.parse().ok() + } + + /// Create an optional i256 from the provided `f64`. 
Returning `None` + /// if overflow occurred + pub fn from_f64(v: f64) -> Option { + BigInt::from_f64(v).and_then(|i| { + let (integer, overflow) = i256::from_bigint_with_overflow(i); + if overflow { + None + } else { + Some(integer) + } + }) + } + + /// Create an i256 from the provided low u128 and high i128 + #[inline] + pub const fn from_parts(low: u128, high: i128) -> Self { + Self { low, high } + } + + /// Returns this `i256` as a low u128 and high i128 + pub const fn to_parts(self) -> (u128, i128) { + (self.low, self.high) + } + + /// Converts this `i256` into an `i128` returning `None` if this would result + /// in truncation/overflow + pub fn to_i128(self) -> Option { + let as_i128 = self.low as i128; + + let high_negative = self.high < 0; + let low_negative = as_i128 < 0; + let high_valid = self.high == -1 || self.high == 0; + + (high_negative == low_negative && high_valid).then_some(self.low as i128) + } + + /// Wraps this `i256` into an `i128` + pub fn as_i128(self) -> i128 { + self.low as i128 + } + + /// Return the memory representation of this integer as a byte array in little-endian byte order. + #[inline] + pub const fn to_le_bytes(self) -> [u8; 32] { + let low = self.low.to_le_bytes(); + let high = self.high.to_le_bytes(); + let mut t = [0; 32]; + let mut i = 0; + while i != 16 { + t[i] = low[i]; + t[i + 16] = high[i]; + i += 1; + } + t + } + + /// Return the memory representation of this integer as a byte array in big-endian byte order. 
+ #[inline] + pub const fn to_be_bytes(self) -> [u8; 32] { + let low = self.low.to_be_bytes(); + let high = self.high.to_be_bytes(); + let mut t = [0; 32]; + let mut i = 0; + while i != 16 { + t[i] = high[i]; + t[i + 16] = low[i]; + i += 1; + } + t + } + + /// Create an i256 from the provided [`BigInt`] returning a bool indicating + /// if overflow occurred + fn from_bigint_with_overflow(v: BigInt) -> (Self, bool) { + let v_bytes = v.to_signed_bytes_le(); + match v_bytes.len().cmp(&32) { + Ordering::Less => { + let mut bytes = if num::Signed::is_negative(&v) { + [255_u8; 32] + } else { + [0; 32] + }; + bytes[0..v_bytes.len()].copy_from_slice(&v_bytes[..v_bytes.len()]); + (Self::from_le_bytes(bytes), false) + } + Ordering::Equal => (Self::from_le_bytes(v_bytes.try_into().unwrap()), false), + Ordering::Greater => (Self::from_le_bytes(v_bytes[..32].try_into().unwrap()), true), + } + } + + /// Computes the absolute value of this i256 + #[inline] + pub fn wrapping_abs(self) -> Self { + // -1 if negative, otherwise 0 + let sa = self.high >> 127; + let sa = Self::from_parts(sa as u128, sa); + + // Inverted if negative + Self::from_parts(self.low ^ sa.low, self.high ^ sa.high).wrapping_sub(sa) + } + + /// Computes the absolute value of this i256 returning `None` if `Self == Self::MIN` + #[inline] + pub fn checked_abs(self) -> Option { + (self != Self::MIN).then(|| self.wrapping_abs()) + } + + /// Negates this i256 + #[inline] + pub fn wrapping_neg(self) -> Self { + Self::from_parts(!self.low, !self.high).wrapping_add(i256::ONE) + } + + /// Negates this i256 returning `None` if `Self == Self::MIN` + #[inline] + pub fn checked_neg(self) -> Option { + (self != Self::MIN).then(|| self.wrapping_neg()) + } + + /// Performs wrapping addition + #[inline] + pub fn wrapping_add(self, other: Self) -> Self { + let (low, carry) = self.low.overflowing_add(other.low); + let high = self.high.wrapping_add(other.high).wrapping_add(carry as _); + Self { low, high } + } + + /// Performs 
checked addition + #[inline] + pub fn checked_add(self, other: Self) -> Option { + let r = self.wrapping_add(other); + ((other.is_negative() && r < self) || (!other.is_negative() && r >= self)).then_some(r) + } + + /// Performs wrapping subtraction + #[inline] + pub fn wrapping_sub(self, other: Self) -> Self { + let (low, carry) = self.low.overflowing_sub(other.low); + let high = self.high.wrapping_sub(other.high).wrapping_sub(carry as _); + Self { low, high } + } + + /// Performs checked subtraction + #[inline] + pub fn checked_sub(self, other: Self) -> Option { + let r = self.wrapping_sub(other); + ((other.is_negative() && r > self) || (!other.is_negative() && r <= self)).then_some(r) + } + + /// Performs wrapping multiplication + #[inline] + pub fn wrapping_mul(self, other: Self) -> Self { + let (low, high) = mulx(self.low, other.low); + + // Compute the high multiples, only impacting the high 128-bits + let hl = self.high.wrapping_mul(other.low as i128); + let lh = (self.low as i128).wrapping_mul(other.high); + + Self { + low, + high: (high as i128).wrapping_add(hl).wrapping_add(lh), + } + } + + /// Performs checked multiplication + #[inline] + pub fn checked_mul(self, other: Self) -> Option { + if self == i256::ZERO || other == i256::ZERO { + return Some(i256::ZERO); + } + + // Shift sign bit down to construct mask of all set bits if negative + let l_sa = self.high >> 127; + let r_sa = other.high >> 127; + let out_sa = (l_sa ^ r_sa) as u128; + + // Compute absolute values + let l_abs = self.wrapping_abs(); + let r_abs = other.wrapping_abs(); + + // Overflow if both high parts are non-zero + if l_abs.high != 0 && r_abs.high != 0 { + return None; + } + + // Perform checked multiplication on absolute values + let (low, high) = mulx(l_abs.low, r_abs.low); + + // Compute the high multiples, only impacting the high 128-bits + let hl = (l_abs.high as u128).checked_mul(r_abs.low)?; + let lh = l_abs.low.checked_mul(r_abs.high as u128)?; + + let high = 
high.checked_add(hl)?.checked_add(lh)?; + + // Reverse absolute value, if necessary + let (low, c) = (low ^ out_sa).overflowing_sub(out_sa); + let high = (high ^ out_sa).wrapping_sub(out_sa).wrapping_sub(c as u128) as i128; + + // Check for overflow in final conversion + (high.is_negative() == (self.is_negative() ^ other.is_negative())) + .then_some(Self { low, high }) + } + + /// Division operation, returns (quotient, remainder). + /// This basically implements [Long division]: `` + #[inline] + fn div_rem(self, other: Self) -> Result<(Self, Self), DivRemError> { + if other == Self::ZERO { + return Err(DivRemError::DivideByZero); + } + if other == Self::MINUS_ONE && self == Self::MIN { + return Err(DivRemError::DivideOverflow); + } + + let a = self.wrapping_abs(); + let b = other.wrapping_abs(); + + let (div, rem) = div_rem(&a.as_digits(), &b.as_digits()); + let div = Self::from_digits(div); + let rem = Self::from_digits(rem); + + Ok(( + if self.is_negative() == other.is_negative() { + div + } else { + div.wrapping_neg() + }, + if self.is_negative() { + rem.wrapping_neg() + } else { + rem + }, + )) + } + + /// Interpret this [`i256`] as 4 `u64` digits, least significant first + fn as_digits(self) -> [u64; 4] { + [ + self.low as u64, + (self.low >> 64) as u64, + self.high as u64, + (self.high as u128 >> 64) as u64, + ] + } + + /// Interpret 4 `u64` digits, least significant first, as a [`i256`] + fn from_digits(digits: [u64; 4]) -> Self { + Self::from_parts( + digits[0] as u128 | (digits[1] as u128) << 64, + digits[2] as i128 | (digits[3] as i128) << 64, + ) + } + + /// Performs wrapping division + #[inline] + pub fn wrapping_div(self, other: Self) -> Self { + match self.div_rem(other) { + Ok((v, _)) => v, + Err(DivRemError::DivideByZero) => panic!("attempt to divide by zero"), + Err(_) => Self::MIN, + } + } + + /// Performs checked division + #[inline] + pub fn checked_div(self, other: Self) -> Option { + self.div_rem(other).map(|(v, _)| v).ok() + } + + /// 
Performs wrapping remainder + #[inline] + pub fn wrapping_rem(self, other: Self) -> Self { + match self.div_rem(other) { + Ok((_, v)) => v, + Err(DivRemError::DivideByZero) => panic!("attempt to divide by zero"), + Err(_) => Self::ZERO, + } + } + + /// Performs checked remainder + #[inline] + pub fn checked_rem(self, other: Self) -> Option { + self.div_rem(other).map(|(_, v)| v).ok() + } + + /// Performs checked exponentiation + #[inline] + pub fn checked_pow(self, mut exp: u32) -> Option { + if exp == 0 { + return Some(i256::from_i128(1)); + } + + let mut base = self; + let mut acc: Self = i256::from_i128(1); + + while exp > 1 { + if (exp & 1) == 1 { + acc = acc.checked_mul(base)?; + } + exp /= 2; + base = base.checked_mul(base)?; + } + // since exp!=0, finally the exp must be 1. + // Deal with the final bit of the exponent separately, since + // squaring the base afterwards is not necessary and may cause a + // needless overflow. + acc.checked_mul(base) + } + + /// Performs wrapping exponentiation + #[inline] + pub fn wrapping_pow(self, mut exp: u32) -> Self { + if exp == 0 { + return i256::from_i128(1); + } + + let mut base = self; + let mut acc: Self = i256::from_i128(1); + + while exp > 1 { + if (exp & 1) == 1 { + acc = acc.wrapping_mul(base); + } + exp /= 2; + base = base.wrapping_mul(base); + } + + // since exp!=0, finally the exp must be 1. + // Deal with the final bit of the exponent separately, since + // squaring the base afterwards is not necessary and may cause a + // needless overflow. + acc.wrapping_mul(base) + } + + /// Returns a number [`i256`] representing sign of this [`i256`]. 
+ /// + /// 0 if the number is zero + /// 1 if the number is positive + /// -1 if the number is negative + pub const fn signum(self) -> Self { + if self.is_positive() { + i256::ONE + } else if self.is_negative() { + i256::MINUS_ONE + } else { + i256::ZERO + } + } + + /// Returns `true` if this [`i256`] is negative + #[inline] + pub const fn is_negative(self) -> bool { + self.high.is_negative() + } + + /// Returns `true` if this [`i256`] is positive + pub const fn is_positive(self) -> bool { + self.high.is_positive() || self.high == 0 && self.low != 0 + } +} + +/// Temporary workaround due to lack of stable const array slicing +/// See +const fn split_array(vals: [u8; N]) -> ([u8; M], [u8; M]) { + let mut a = [0; M]; + let mut b = [0; M]; + let mut i = 0; + while i != M { + a[i] = vals[i]; + b[i] = vals[i + M]; + i += 1; + } + (a, b) +} + +/// Performs an unsigned multiplication of `a * b` returning a tuple of +/// `(low, high)` where `low` contains the lower 128-bits of the result +/// and `high` the higher 128-bits +/// +/// This mirrors the x86 mulx instruction but for 128-bit types +#[inline] +fn mulx(a: u128, b: u128) -> (u128, u128) { + let split = |a: u128| (a & (u64::MAX as u128), a >> 64); + + const MASK: u128 = u64::MAX as _; + + let (a_low, a_high) = split(a); + let (b_low, b_high) = split(b); + + // Carry stores the upper 64-bits of low and lower 64-bits of high + let (mut low, mut carry) = split(a_low * b_low); + carry += a_high * b_low; + + // Update low and high with corresponding parts of carry + low += carry << 64; + let mut high = carry >> 64; + + // Update carry with overflow from low + carry = low >> 64; + low &= MASK; + + // Perform multiply including overflow from low + carry += b_high * a_low; + + // Update low and high with values from carry + low += carry << 64; + high += carry >> 64; + + // Perform 4th multiplication + high += a_high * b_high; + + (low, high) +} + +derive_arith!( + i256, + Add, + AddAssign, + add, + add_assign, + 
wrapping_add, + checked_add +); +derive_arith!( + i256, + Sub, + SubAssign, + sub, + sub_assign, + wrapping_sub, + checked_sub +); +derive_arith!( + i256, + Mul, + MulAssign, + mul, + mul_assign, + wrapping_mul, + checked_mul +); +derive_arith!( + i256, + Div, + DivAssign, + div, + div_assign, + wrapping_div, + checked_div +); +derive_arith!( + i256, + Rem, + RemAssign, + rem, + rem_assign, + wrapping_rem, + checked_rem +); + +impl Neg for i256 { + type Output = i256; + + #[cfg(debug_assertions)] + fn neg(self) -> Self::Output { + self.checked_neg().expect("i256 overflow") + } + + #[cfg(not(debug_assertions))] + fn neg(self) -> Self::Output { + self.wrapping_neg() + } +} + +impl BitAnd for i256 { + type Output = i256; + + #[inline] + fn bitand(self, rhs: Self) -> Self::Output { + Self { + low: self.low & rhs.low, + high: self.high & rhs.high, + } + } +} + +impl BitOr for i256 { + type Output = i256; + + #[inline] + fn bitor(self, rhs: Self) -> Self::Output { + Self { + low: self.low | rhs.low, + high: self.high | rhs.high, + } + } +} + +impl BitXor for i256 { + type Output = i256; + + #[inline] + fn bitxor(self, rhs: Self) -> Self::Output { + Self { + low: self.low ^ rhs.low, + high: self.high ^ rhs.high, + } + } +} + +impl Shl for i256 { + type Output = i256; + + #[inline] + fn shl(self, rhs: u8) -> Self::Output { + if rhs == 0 { + self + } else if rhs < 128 { + Self { + high: self.high << rhs | (self.low >> (128 - rhs)) as i128, + low: self.low << rhs, + } + } else { + Self { + high: (self.low << (rhs - 128)) as i128, + low: 0, + } + } + } +} + +impl Shr for i256 { + type Output = i256; + + #[inline] + fn shr(self, rhs: u8) -> Self::Output { + if rhs == 0 { + self + } else if rhs < 128 { + Self { + high: self.high >> rhs, + low: self.low >> rhs | ((self.high as u128) << (128 - rhs)), + } + } else { + Self { + high: self.high >> 127, + low: (self.high >> (rhs - 128)) as u128, + } + } + } +} + +macro_rules! 
define_as_primitive { + ($native_ty:ty) => { + impl AsPrimitive for $native_ty { + fn as_(self) -> i256 { + i256::from_i128(self as i128) + } + } + }; +} + +define_as_primitive!(i8); +define_as_primitive!(i16); +define_as_primitive!(i32); +define_as_primitive!(i64); +define_as_primitive!(u8); +define_as_primitive!(u16); +define_as_primitive!(u32); +define_as_primitive!(u64); + +impl ToPrimitive for i256 { + fn to_i64(&self) -> Option { + let as_i128 = self.low as i128; + + let high_negative = self.high < 0; + let low_negative = as_i128 < 0; + let high_valid = self.high == -1 || self.high == 0; + + if high_negative == low_negative && high_valid { + let (low_bytes, high_bytes) = split_array(u128::to_le_bytes(self.low)); + let high = i64::from_le_bytes(high_bytes); + let low = i64::from_le_bytes(low_bytes); + + let high_negative = high < 0; + let low_negative = low < 0; + let high_valid = self.high == -1 || self.high == 0; + + (high_negative == low_negative && high_valid).then_some(low) + } else { + None + } + } + + fn to_u64(&self) -> Option { + let as_i128 = self.low as i128; + + let high_negative = self.high < 0; + let low_negative = as_i128 < 0; + let high_valid = self.high == -1 || self.high == 0; + + if high_negative == low_negative && high_valid { + self.low.to_u64() + } else { + None + } + } +} + +#[cfg(all(test, not(miri)))] // llvm.x86.subborrow.64 not supported by MIRI +mod tests { + use super::*; + use num::Signed; + use rand::{thread_rng, Rng}; + + #[test] + fn test_signed_cmp() { + let a = i256::from_parts(i128::MAX as u128, 12); + let b = i256::from_parts(i128::MIN as u128, 12); + assert!(a < b); + + let a = i256::from_parts(i128::MAX as u128, 12); + let b = i256::from_parts(i128::MIN as u128, -12); + assert!(a > b); + } + + #[test] + fn test_to_i128() { + let vals = [ + BigInt::from_i128(-1).unwrap(), + BigInt::from_i128(i128::MAX).unwrap(), + BigInt::from_i128(i128::MIN).unwrap(), + BigInt::from_u128(u128::MIN).unwrap(), + 
BigInt::from_u128(u128::MAX).unwrap(), + ]; + + for v in vals { + let (t, overflow) = i256::from_bigint_with_overflow(v.clone()); + assert!(!overflow); + assert_eq!(t.to_i128(), v.to_i128(), "{v} vs {t}"); + } + } + + /// Tests operations against the two provided [`i256`] + fn test_ops(il: i256, ir: i256) { + let bl = BigInt::from_signed_bytes_le(&il.to_le_bytes()); + let br = BigInt::from_signed_bytes_le(&ir.to_le_bytes()); + + // Comparison + assert_eq!(il.cmp(&ir), bl.cmp(&br), "{bl} cmp {br}"); + + // Conversions + assert_eq!(i256::from_le_bytes(il.to_le_bytes()), il); + assert_eq!(i256::from_be_bytes(il.to_be_bytes()), il); + assert_eq!(i256::from_le_bytes(ir.to_le_bytes()), ir); + assert_eq!(i256::from_be_bytes(ir.to_be_bytes()), ir); + + // To i128 + assert_eq!(il.to_i128(), bl.to_i128(), "{bl}"); + assert_eq!(ir.to_i128(), br.to_i128(), "{br}"); + + // Absolute value + let (abs, overflow) = i256::from_bigint_with_overflow(bl.abs()); + assert_eq!(il.wrapping_abs(), abs); + assert_eq!(il.checked_abs().is_none(), overflow); + + let (abs, overflow) = i256::from_bigint_with_overflow(br.abs()); + assert_eq!(ir.wrapping_abs(), abs); + assert_eq!(ir.checked_abs().is_none(), overflow); + + // Negation + let (neg, overflow) = i256::from_bigint_with_overflow(bl.clone().neg()); + assert_eq!(il.wrapping_neg(), neg); + assert_eq!(il.checked_neg().is_none(), overflow); + + // Negation + let (neg, overflow) = i256::from_bigint_with_overflow(br.clone().neg()); + assert_eq!(ir.wrapping_neg(), neg); + assert_eq!(ir.checked_neg().is_none(), overflow); + + // Addition + let actual = il.wrapping_add(ir); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone() + br.clone()); + assert_eq!(actual, expected); + + let checked = il.checked_add(ir); + match overflow { + true => assert!(checked.is_none()), + false => assert_eq!(checked, Some(actual)), + } + + // Subtraction + let actual = il.wrapping_sub(ir); + let (expected, overflow) = 
i256::from_bigint_with_overflow(bl.clone() - br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let checked = il.checked_sub(ir); + match overflow { + true => assert!(checked.is_none()), + false => assert_eq!(checked, Some(actual), "{bl} - {br} = {expected}"), + } + + // Multiplication + let actual = il.wrapping_mul(ir); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone() * br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let checked = il.checked_mul(ir); + match overflow { + true => assert!( + checked.is_none(), + "{il} * {ir} = {actual} vs {bl} * {br} = {expected}" + ), + false => assert_eq!( + checked, + Some(actual), + "{il} * {ir} = {actual} vs {bl} * {br} = {expected}" + ), + } + + // Division + if ir != i256::ZERO { + let actual = il.wrapping_div(ir); + let expected = bl.clone() / br.clone(); + let checked = il.checked_div(ir); + + if ir == i256::MINUS_ONE && il == i256::MIN { + // BigInt produces an integer over i256::MAX + assert_eq!(actual, i256::MIN); + assert!(checked.is_none()); + } else { + assert_eq!(actual.to_string(), expected.to_string()); + assert_eq!(checked.unwrap().to_string(), expected.to_string()); + } + } else { + // `wrapping_div` panics on division by zero + assert!(il.checked_div(ir).is_none()); + } + + // Remainder + if ir != i256::ZERO { + let actual = il.wrapping_rem(ir); + let expected = bl.clone() % br.clone(); + let checked = il.checked_rem(ir); + + assert_eq!(actual.to_string(), expected.to_string(), "{il} % {ir}"); + + if ir == i256::MINUS_ONE && il == i256::MIN { + assert!(checked.is_none()); + } else { + assert_eq!(checked.unwrap().to_string(), expected.to_string()); + } + } else { + // `wrapping_rem` panics on division by zero + assert!(il.checked_rem(ir).is_none()); + } + + // Exponentiation + for exp in vec![0, 1, 2, 3, 8, 100].into_iter() { + let actual = il.wrapping_pow(exp); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone().pow(exp)); 
+ assert_eq!(actual.to_string(), expected.to_string()); + + let checked = il.checked_pow(exp); + match overflow { + true => assert!( + checked.is_none(), + "{il} ^ {exp} = {actual} vs {bl} * {exp} = {expected}" + ), + false => assert_eq!( + checked, + Some(actual), + "{il} ^ {exp} = {actual} vs {bl} ^ {exp} = {expected}" + ), + } + } + + // Bit operations + let actual = il & ir; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() & br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let actual = il | ir; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() | br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let actual = il ^ ir; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() ^ br); + assert_eq!(actual.to_string(), expected.to_string()); + + for shift in [0_u8, 1, 4, 126, 128, 129, 254, 255] { + let actual = il << shift; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() << shift); + assert_eq!(actual.to_string(), expected.to_string()); + + let actual = il >> shift; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() >> shift); + assert_eq!(actual.to_string(), expected.to_string()); + } + } + + #[test] + fn test_i256() { + let candidates = [ + i256::ZERO, + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(2), + i256::from_i128(-2), + i256::from_parts(u128::MAX, 1), + i256::from_parts(u128::MAX, -1), + i256::from_parts(0, 1), + i256::from_parts(0, -1), + i256::from_parts(1, -1), + i256::from_parts(1, 1), + i256::from_parts(0, i128::MAX), + i256::from_parts(0, i128::MIN), + i256::from_parts(1, i128::MAX), + i256::from_parts(1, i128::MIN), + i256::from_parts(u128::MAX, i128::MIN), + i256::from_parts(100, 32), + i256::MIN, + i256::MAX, + i256::MIN >> 1, + i256::MAX >> 1, + i256::ONE << 127, + i256::ONE << 128, + i256::ONE << 129, + i256::MINUS_ONE << 127, + i256::MINUS_ONE << 128, + i256::MINUS_ONE << 129, + ]; + + for il in candidates { + for ir in 
candidates { + test_ops(il, ir) + } + } + } + + #[test] + fn test_signed_ops() { + // signum + assert_eq!(i256::from_i128(1).signum(), i256::ONE); + assert_eq!(i256::from_i128(0).signum(), i256::ZERO); + assert_eq!(i256::from_i128(-0).signum(), i256::ZERO); + assert_eq!(i256::from_i128(-1).signum(), i256::MINUS_ONE); + + // is_positive + assert!(i256::from_i128(1).is_positive()); + assert!(!i256::from_i128(0).is_positive()); + assert!(!i256::from_i128(-0).is_positive()); + assert!(!i256::from_i128(-1).is_positive()); + + // is_negative + assert!(!i256::from_i128(1).is_negative()); + assert!(!i256::from_i128(0).is_negative()); + assert!(!i256::from_i128(-0).is_negative()); + assert!(i256::from_i128(-1).is_negative()); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_i256_fuzz() { + let mut rng = thread_rng(); + + for _ in 0..1000 { + let mut l = [0_u8; 32]; + let len = rng.gen_range(0..32); + l.iter_mut().take(len).for_each(|x| *x = rng.gen()); + + let mut r = [0_u8; 32]; + let len = rng.gen_range(0..32); + r.iter_mut().take(len).for_each(|x| *x = rng.gen()); + + test_ops(i256::from_le_bytes(l), i256::from_le_bytes(r)) + } + } + + #[test] + fn test_i256_to_primitive() { + let a = i256::MAX; + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i128::MAX); + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i64::MAX as i128); + assert_eq!(a.to_i64().unwrap(), i64::MAX); + assert_eq!(a.to_u64().unwrap(), i64::MAX as u64); + + let a = i256::from_i128(i64::MAX as i128 + 1); + assert!(a.to_i64().is_none()); + assert_eq!(a.to_u64().unwrap(), i64::MAX as u64 + 1); + + let a = i256::MIN; + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i128::MIN); + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i64::MIN as i128); + assert_eq!(a.to_i64().unwrap(), i64::MIN); + assert!(a.to_u64().is_none()); + + let a 
= i256::from_i128(i64::MIN as i128 - 1); + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + } + + #[test] + fn test_i256_as_i128() { + let a = i256::from_i128(i128::MAX).wrapping_add(i256::from_i128(1)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MIN); + + let a = i256::from_i128(i128::MAX).wrapping_add(i256::from_i128(2)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MIN + 1); + + let a = i256::from_i128(i128::MIN).wrapping_sub(i256::from_i128(1)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MAX); + + let a = i256::from_i128(i128::MIN).wrapping_sub(i256::from_i128(2)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MAX - 1); + } + + #[test] + fn test_string_roundtrip() { + let roundtrip_cases = [ + i256::ZERO, + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(123456789), + i256::from_i128(-123456789), + i256::from_i128(i128::MIN), + i256::from_i128(i128::MAX), + i256::MIN, + i256::MAX, + ]; + for case in roundtrip_cases { + let formatted = case.to_string(); + let back: i256 = formatted.parse().unwrap(); + assert_eq!(case, back); + } + } + + #[test] + fn test_from_string() { + let cases = [ + ( + "000000000000000000000000000000000000000011", + Some(i256::from_i128(11)), + ), + ( + "-000000000000000000000000000000000000000011", + Some(i256::from_i128(-11)), + ), + ( + "-0000000000000000000000000000000000000000123456789", + Some(i256::from_i128(-123456789)), + ), + ("-", None), + ("+", None), + ("--1", None), + ("-+1", None), + ("000000000000000000000000000000000000000", Some(i256::ZERO)), + ("0000000000000000000000000000000000000000-11", None), + ("11-1111111111111111111111111111111111111", None), + ( + "115792089237316195423570985008687907853269984665640564039457584007913129639936", + None, + ), + ]; + for (case, expected) in cases { + assert_eq!(i256::from_string(case), expected) + } + } + + #[allow(clippy::op_ref)] + fn test_reference_op(il: i256, ir: i256) { + let r1 = il + ir; + let r2 = &il + ir; + let r3 = il + &ir; + 
let r4 = &il + &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il - ir; + let r2 = &il - ir; + let r3 = il - &ir; + let r4 = &il - &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il * ir; + let r2 = &il * ir; + let r3 = il * &ir; + let r4 = &il * &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il / ir; + let r2 = &il / ir; + let r3 = il / &ir; + let r4 = &il / &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + } + + #[test] + fn test_i256_reference_op() { + let candidates = [ + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(2), + i256::from_i128(-2), + i256::from_i128(3), + i256::from_i128(-3), + ]; + + for il in candidates { + for ir in candidates { + test_reference_op(il, ir) + } + } + } +} diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs new file mode 100644 index 000000000000..49a75b468dbe --- /dev/null +++ b/arrow-buffer/src/buffer/boolean.rs @@ -0,0 +1,428 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::bit_chunk_iterator::BitChunks; +use crate::bit_iterator::{BitIndexIterator, BitIterator, BitSliceIterator}; +use crate::{ + bit_util, buffer_bin_and, buffer_bin_or, buffer_bin_xor, buffer_unary_not, + BooleanBufferBuilder, Buffer, MutableBuffer, +}; + +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +/// A slice-able [`Buffer`] containing bit-packed booleans +#[derive(Debug, Clone, Eq)] +pub struct BooleanBuffer { + buffer: Buffer, + offset: usize, + len: usize, +} + +impl PartialEq for BooleanBuffer { + fn eq(&self, other: &Self) -> bool { + if self.len != other.len { + return false; + } + + let lhs = self.bit_chunks().iter_padded(); + let rhs = other.bit_chunks().iter_padded(); + lhs.zip(rhs).all(|(a, b)| a == b) + } +} + +impl BooleanBuffer { + /// Create a new [`BooleanBuffer`] from a [`Buffer`], an `offset` and `length` in bits + /// + /// # Panics + /// + /// This method will panic if `buffer` is not large enough + pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { + let total_len = offset.saturating_add(len); + let bit_len = buffer.len().saturating_mul(8); + assert!(total_len <= bit_len); + Self { + buffer, + offset, + len, + } + } + + /// Create a new [`BooleanBuffer`] of `length` where all values are `true` + pub fn new_set(length: usize) -> Self { + let mut builder = BooleanBufferBuilder::new(length); + builder.append_n(length, true); + builder.finish() + } + + /// Create a new [`BooleanBuffer`] of `length` where all values are `false` + pub fn new_unset(length: usize) -> Self { + let buffer = MutableBuffer::new_null(length).into_buffer(); + Self { + buffer, + offset: 0, + len: length, + } + } + + /// Invokes `f` with indexes `0..len` collecting the boolean results into a new `BooleanBuffer` + pub fn collect_bool bool>(len: usize, f: F) -> Self { + let buffer = MutableBuffer::collect_bool(len, f); + Self::new(buffer.into(), 0, len) + } + + /// Returns the number of set bits in this buffer + pub fn count_set_bits(&self) -> usize { 
+ self.buffer.count_set_bits_offset(self.offset, self.len) + } + + /// Returns a `BitChunks` instance which can be used to iterate over + /// this buffer's bits in `u64` chunks + #[inline] + pub fn bit_chunks(&self) -> BitChunks { + BitChunks::new(self.values(), self.offset, self.len) + } + + /// Returns `true` if the bit at index `i` is set + /// + /// # Panics + /// + /// Panics if `i >= self.len()` + #[inline] + #[deprecated(note = "use BooleanBuffer::value")] + pub fn is_set(&self, i: usize) -> bool { + self.value(i) + } + + /// Returns the offset of this [`BooleanBuffer`] in bits + #[inline] + pub fn offset(&self) -> usize { + self.offset + } + + /// Returns the length of this [`BooleanBuffer`] in bits + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Returns true if this [`BooleanBuffer`] is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the boolean value at index `i`. + /// + /// # Panics + /// + /// Panics if `i >= self.len()` + #[inline] + pub fn value(&self, idx: usize) -> bool { + assert!(idx < self.len); + unsafe { self.value_unchecked(idx) } + } + + /// Returns the boolean value at index `i`. 
+ /// + /// # Safety + /// This doesn't check bounds, the caller must ensure that index < self.len() + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> bool { + unsafe { bit_util::get_bit_raw(self.buffer.as_ptr(), i + self.offset) } + } + + /// Returns the packed values of this [`BooleanBuffer`] not including any offset + #[inline] + pub fn values(&self) -> &[u8] { + &self.buffer + } + + /// Slices this [`BooleanBuffer`] by the provided `offset` and `length` + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced BooleanBuffer cannot exceed the existing length" + ); + Self { + buffer: self.buffer.clone(), + offset: self.offset + offset, + len, + } + } + + /// Returns a [`Buffer`] containing the sliced contents of this [`BooleanBuffer`] + /// + /// Equivalent to `self.buffer.bit_slice(self.offset, self.len)` + pub fn sliced(&self) -> Buffer { + self.buffer.bit_slice(self.offset, self.len) + } + + /// Returns true if this [`BooleanBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. 
This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + pub fn ptr_eq(&self, other: &Self) -> bool { + self.buffer.as_ptr() == other.buffer.as_ptr() + && self.offset == other.offset + && self.len == other.len + } + + /// Returns the inner [`Buffer`] + #[inline] + pub fn inner(&self) -> &Buffer { + &self.buffer + } + + /// Returns the inner [`Buffer`], consuming self + pub fn into_inner(self) -> Buffer { + self.buffer + } + + /// Returns an iterator over the bits in this [`BooleanBuffer`] + pub fn iter(&self) -> BitIterator<'_> { + self.into_iter() + } + + /// Returns an iterator over the set bit positions in this [`BooleanBuffer`] + pub fn set_indices(&self) -> BitIndexIterator<'_> { + BitIndexIterator::new(self.values(), self.offset, self.len) + } + + /// Returns a [`BitSliceIterator`] yielding contiguous ranges of set bits + pub fn set_slices(&self) -> BitSliceIterator<'_> { + BitSliceIterator::new(self.values(), self.offset, self.len) + } +} + +impl Not for &BooleanBuffer { + type Output = BooleanBuffer; + + fn not(self) -> Self::Output { + BooleanBuffer { + buffer: buffer_unary_not(&self.buffer, self.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl BitAnd<&BooleanBuffer> for &BooleanBuffer { + type Output = BooleanBuffer; + + fn bitand(self, rhs: &BooleanBuffer) -> Self::Output { + assert_eq!(self.len, rhs.len); + BooleanBuffer { + buffer: buffer_bin_and(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl BitOr<&BooleanBuffer> for &BooleanBuffer { + type Output = BooleanBuffer; + + fn bitor(self, rhs: &BooleanBuffer) -> Self::Output { + assert_eq!(self.len, rhs.len); + BooleanBuffer { + buffer: buffer_bin_or(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl BitXor<&BooleanBuffer> for &BooleanBuffer { + type Output = BooleanBuffer; + + fn bitxor(self, rhs: &BooleanBuffer) -> 
Self::Output { + assert_eq!(self.len, rhs.len); + BooleanBuffer { + buffer: buffer_bin_xor(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl<'a> IntoIterator for &'a BooleanBuffer { + type Item = bool; + type IntoIter = BitIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitIterator::new(self.values(), self.offset, self.len) + } +} + +impl From<&[bool]> for BooleanBuffer { + fn from(value: &[bool]) -> Self { + let mut builder = BooleanBufferBuilder::new(value.len()); + builder.append_slice(value); + builder.finish() + } +} + +impl From> for BooleanBuffer { + fn from(value: Vec) -> Self { + value.as_slice().into() + } +} + +impl FromIterator for BooleanBuffer { + fn from_iter>(iter: T) -> Self { + let iter = iter.into_iter(); + let (hint, _) = iter.size_hint(); + let mut builder = BooleanBufferBuilder::new(hint); + iter.for_each(|b| builder.append(b)); + builder.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_boolean_new() { + let bytes = &[0, 1, 2, 3, 4]; + let buf = Buffer::from(bytes); + let offset = 0; + let len = 24; + + let boolean_buf = BooleanBuffer::new(buf.clone(), offset, len); + assert_eq!(bytes, boolean_buf.values()); + assert_eq!(offset, boolean_buf.offset()); + assert_eq!(len, boolean_buf.len()); + + assert_eq!(2, boolean_buf.count_set_bits()); + assert_eq!(&buf, boolean_buf.inner()); + assert_eq!(buf, boolean_buf.clone().into_inner()); + + assert!(!boolean_buf.is_empty()) + } + + #[test] + fn test_boolean_data_equality() { + let boolean_buf1 = BooleanBuffer::new(Buffer::from(&[0, 1, 4, 3, 5]), 0, 32); + let boolean_buf2 = BooleanBuffer::new(Buffer::from(&[0, 1, 4, 3, 5]), 0, 32); + assert_eq!(boolean_buf1, boolean_buf2); + + // slice with same offset and same length should still preserve equality + let boolean_buf3 = boolean_buf1.slice(8, 16); + assert_ne!(boolean_buf1, boolean_buf3); + let boolean_buf4 = boolean_buf1.slice(0, 32); + 
assert_eq!(boolean_buf1, boolean_buf4); + + // unequal because of different elements + let boolean_buf2 = BooleanBuffer::new(Buffer::from(&[0, 0, 2, 3, 4]), 0, 32); + assert_ne!(boolean_buf1, boolean_buf2); + + // unequal because of different length + let boolean_buf2 = BooleanBuffer::new(Buffer::from(&[0, 1, 4, 3, 5]), 0, 24); + assert_ne!(boolean_buf1, boolean_buf2); + + // ptr_eq + assert!(boolean_buf1.ptr_eq(&boolean_buf1)); + assert!(boolean_buf2.ptr_eq(&boolean_buf2)); + assert!(!boolean_buf1.ptr_eq(&boolean_buf2)); + } + + #[test] + fn test_boolean_slice() { + let bytes = &[0, 3, 2, 6, 2]; + let boolean_buf1 = BooleanBuffer::new(Buffer::from(bytes), 0, 32); + let boolean_buf2 = BooleanBuffer::new(Buffer::from(bytes), 0, 32); + + let boolean_slice1 = boolean_buf1.slice(16, 16); + let boolean_slice2 = boolean_buf2.slice(0, 16); + assert_eq!(boolean_slice1.values(), boolean_slice2.values()); + + assert_eq!(bytes, boolean_slice1.values()); + assert_eq!(16, boolean_slice1.offset); + assert_eq!(16, boolean_slice1.len); + + assert_eq!(bytes, boolean_slice2.values()); + assert_eq!(0, boolean_slice2.offset); + assert_eq!(16, boolean_slice2.len); + } + + #[test] + fn test_boolean_bitand() { + let offset = 0; + let len = 40; + + let buf1 = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf1 = &BooleanBuffer::new(buf1, offset, len); + + let buf2 = Buffer::from(&[0, 1, 1, 1, 0]); + let boolean_buf2 = &BooleanBuffer::new(buf2, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[0, 1, 1, 0, 0]), offset, len); + assert_eq!(boolean_buf1 & boolean_buf2, expected); + } + + #[test] + fn test_boolean_bitor() { + let offset = 0; + let len = 40; + + let buf1 = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf1 = &BooleanBuffer::new(buf1, offset, len); + + let buf2 = Buffer::from(&[0, 1, 1, 1, 0]); + let boolean_buf2 = &BooleanBuffer::new(buf2, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[0, 1, 1, 1, 0]), offset, len); + assert_eq!(boolean_buf1 | 
boolean_buf2, expected); + } + + #[test] + fn test_boolean_bitxor() { + let offset = 0; + let len = 40; + + let buf1 = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf1 = &BooleanBuffer::new(buf1, offset, len); + + let buf2 = Buffer::from(&[0, 1, 1, 1, 0]); + let boolean_buf2 = &BooleanBuffer::new(buf2, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[0, 0, 0, 1, 0]), offset, len); + assert_eq!(boolean_buf1 ^ boolean_buf2, expected); + } + + #[test] + fn test_boolean_not() { + let offset = 0; + let len = 40; + + let buf = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf = &BooleanBuffer::new(buf, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[255, 254, 254, 255, 255]), offset, len); + assert_eq!(!boolean_buf, expected); + } + + #[test] + fn test_boolean_from_slice_bool() { + let v = [true, false, false]; + let buf = BooleanBuffer::from(&v[..]); + assert_eq!(buf.offset(), 0); + assert_eq!(buf.len(), 3); + assert_eq!(buf.values().len(), 1); + assert!(buf.value(0)); + } +} diff --git a/arrow/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs similarity index 56% rename from arrow/src/buffer/immutable.rs rename to arrow-buffer/src/buffer/immutable.rs index 28042a3817be..8d1a46583fca 100644 --- a/arrow/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -15,55 +15,91 @@ // specific language governing permissions and limitations // under the License. 
+use std::alloc::Layout; use std::fmt::Debug; -use std::iter::FromIterator; use std::ptr::NonNull; use std::sync::Arc; -use std::{convert::AsRef, usize}; -use crate::alloc::{Allocation, Deallocation}; +use crate::alloc::{Allocation, Deallocation, ALIGNMENT}; use crate::util::bit_chunk_iterator::{BitChunks, UnalignedBitChunk}; -use crate::{bytes::Bytes, datatypes::ArrowNativeType}; +use crate::BufferBuilder; +use crate::{bytes::Bytes, native::ArrowNativeType}; use super::ops::bitwise_unary_op_helper; -use super::MutableBuffer; +use super::{MutableBuffer, ScalarBuffer}; /// Buffer represents a contiguous memory region that can be shared with other buffers and across /// thread boundaries. -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, Debug)] pub struct Buffer { /// the internal byte buffer. data: Arc, - /// The offset into the buffer. - offset: usize, + /// Pointer into `data` valid + /// + /// We store a pointer instead of an offset to avoid pointer arithmetic + /// which causes LLVM to fail to vectorise code correctly + ptr: *const u8, /// Byte length of the buffer. + /// + /// Must be less than or equal to `data.len()` length: usize, } +impl PartialEq for Buffer { + fn eq(&self, other: &Self) -> bool { + self.as_slice().eq(other.as_slice()) + } +} + +impl Eq for Buffer {} + +unsafe impl Send for Buffer where Bytes: Send {} +unsafe impl Sync for Buffer where Bytes: Sync {} + impl Buffer { /// Auxiliary method to create a new Buffer #[inline] pub fn from_bytes(bytes: Bytes) -> Self { let length = bytes.len(); + let ptr = bytes.as_ptr(); Buffer { data: Arc::new(bytes), - offset: 0, + ptr, length, } } + /// Returns the offset, in bytes, of `Self::ptr` to `Self::data` + /// + /// self.ptr and self.data can be different after slicing or advancing the buffer. + pub fn ptr_offset(&self) -> usize { + // Safety: `ptr` is always in bounds of `data`. 
+ unsafe { self.ptr.offset_from(self.data.ptr().as_ptr()) as usize } + } + + /// Returns the pointer to the start of the buffer without the offset. + pub fn data_ptr(&self) -> NonNull { + self.data.ptr() + } + + /// Create a [`Buffer`] from the provided [`Vec`] without copying + #[inline] + pub fn from_vec(vec: Vec) -> Self { + MutableBuffer::from(vec).into() + } + /// Initializes a [Buffer] from a slice of items. - pub fn from_slice_ref>(items: &T) -> Self { + pub fn from_slice_ref>(items: T) -> Self { let slice = items.as_ref(); - let capacity = slice.len() * std::mem::size_of::(); + let capacity = std::mem::size_of_val(slice); let mut buffer = MutableBuffer::with_capacity(capacity); buffer.extend_from_slice(slice); buffer.into() } - /// Creates a buffer from an existing memory region (must already be byte-aligned), this + /// Creates a buffer from an existing aligned memory region (must already be byte-aligned), this /// `Buffer` will free this piece of memory when dropped. /// /// # Arguments @@ -76,9 +112,11 @@ impl Buffer { /// /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. + #[deprecated(note = "Use Buffer::from_vec")] pub unsafe fn from_raw_parts(ptr: NonNull, len: usize, capacity: usize) -> Self { assert!(len <= capacity); - Buffer::build_with_arguments(ptr, len, Deallocation::Arrow(capacity)) + let layout = Layout::from_size_align(capacity, ALIGNMENT).unwrap(); + Buffer::build_with_arguments(ptr, len, Deallocation::Standard(layout)) } /// Creates a buffer from an existing memory region. 
Ownership of the memory is tracked via reference counting @@ -98,7 +136,7 @@ impl Buffer { len: usize, owner: Arc, ) -> Self { - Buffer::build_with_arguments(ptr, len, Deallocation::Custom(owner)) + Buffer::build_with_arguments(ptr, len, Deallocation::Custom(owner, len)) } /// Auxiliary method to create a new Buffer @@ -108,9 +146,10 @@ impl Buffer { deallocation: Deallocation, ) -> Self { let bytes = Bytes::new(ptr, len, deallocation); + let ptr = bytes.as_ptr(); Buffer { + ptr, data: Arc::new(bytes), - offset: 0, length: len, } } @@ -136,23 +175,44 @@ impl Buffer { /// Returns the byte slice stored in this buffer pub fn as_slice(&self) -> &[u8] { - &self.data[self.offset..(self.offset + self.length)] + unsafe { std::slice::from_raw_parts(self.ptr, self.length) } + } + + pub(crate) fn deallocation(&self) -> &Deallocation { + self.data.deallocation() } /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`. /// Doing so allows the same memory region to be shared between buffers. + /// /// # Panics + /// /// Panics iff `offset` is larger than `len`. pub fn slice(&self, offset: usize) -> Self { + let mut s = self.clone(); + s.advance(offset); + s + } + + /// Increases the offset of this buffer by `offset` + /// + /// # Panics + /// + /// Panics iff `offset` is larger than `len`. 
+ #[inline] + pub fn advance(&mut self, offset: usize) { assert!( - offset <= self.len(), - "the offset of the new Buffer cannot exceed the existing length" + offset <= self.length, + "the offset of the new Buffer cannot exceed the existing length: offset={} length={}", + offset, + self.length ); - Self { - data: self.data.clone(), - offset: self.offset + offset, - length: self.length - offset, - } + self.length -= offset; + // Safety: + // This cannot overflow as + // `self.offset + self.length < self.data.len()` + // `offset < self.length` + self.ptr = unsafe { self.ptr.add(offset) }; } /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`, @@ -162,12 +222,16 @@ impl Buffer { /// Panics iff `(offset + length)` is larger than the existing length. pub fn slice_with_length(&self, offset: usize, length: usize) -> Self { assert!( - offset + length <= self.len(), - "the offset of the new Buffer cannot exceed the existing length" + offset.saturating_add(length) <= self.length, + "the offset of the new Buffer cannot exceed the existing length: slice offset={offset} length={length} selflen={}", + self.length ); + // Safety: + // offset + length <= self.length + let ptr = unsafe { self.ptr.add(offset) }; Self { data: self.data.clone(), - offset: self.offset + offset, + ptr, length, } } @@ -178,10 +242,10 @@ impl Buffer { /// stored anywhere, to avoid dangling pointers. #[inline] pub fn as_ptr(&self) -> *const u8 { - unsafe { self.data.ptr().as_ptr().add(self.offset) } + self.ptr } - /// View buffer as typed slice. + /// View buffer as a slice of a specific type. /// /// # Panics /// @@ -215,6 +279,7 @@ impl Buffer { } /// Returns the number of 1-bits in this buffer. 
+ #[deprecated(note = "use count_set_bits_offset instead")] pub fn count_set_bits(&self) -> usize { let len_in_bits = self.len() * 8; // self.offset is already taken into consideration by the bit_chunks implementation @@ -226,23 +291,114 @@ impl Buffer { pub fn count_set_bits_offset(&self, offset: usize, len: usize) -> usize { UnalignedBitChunk::new(self.as_slice(), offset, len).count_ones() } + + /// Returns `MutableBuffer` for mutating the buffer if this buffer is not shared. + /// Returns `Err` if this is shared or its allocation is from an external source or + /// it is not allocated with alignment [`ALIGNMENT`] + pub fn into_mutable(self) -> Result { + let ptr = self.ptr; + let length = self.length; + Arc::try_unwrap(self.data) + .and_then(|bytes| { + // The pointer of underlying buffer should not be offset. + assert_eq!(ptr, bytes.ptr().as_ptr()); + MutableBuffer::from_bytes(bytes).map_err(Arc::new) + }) + .map_err(|bytes| Buffer { + data: bytes, + ptr, + length, + }) + } + + /// Returns `Vec` for mutating the buffer + /// + /// Returns `Err(self)` if this buffer does not have the same [`Layout`] as + /// the destination Vec or contains a non-zero offset + pub fn into_vec(self) -> Result, Self> { + let layout = match self.data.deallocation() { + Deallocation::Standard(l) => l, + _ => return Err(self), // Custom allocation + }; + + if self.ptr != self.data.as_ptr() { + return Err(self); // Data is offset + } + + let v_capacity = layout.size() / std::mem::size_of::(); + match Layout::array::(v_capacity) { + Ok(expected) if layout == &expected => {} + _ => return Err(self), // Incorrect layout + } + + let length = self.length; + let ptr = self.ptr; + let v_len = self.length / std::mem::size_of::(); + + Arc::try_unwrap(self.data) + .map(|bytes| unsafe { + let ptr = bytes.ptr().as_ptr() as _; + std::mem::forget(bytes); + // Safety + // Verified that bytes layout matches that of Vec + Vec::from_raw_parts(ptr, v_len, v_capacity) + }) + .map_err(|bytes| Buffer { + 
data: bytes, + ptr, + length, + }) + } + + /// Returns true if this [`Buffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.ptr == other.ptr && self.length == other.length + } } -/// Creating a `Buffer` instance by copying the memory from a `AsRef<[u8]>` into a newly -/// allocated memory region. -impl> From for Buffer { - fn from(p: T) -> Self { - // allocate aligned memory buffer - let slice = p.as_ref(); - let len = slice.len(); - let mut buffer = MutableBuffer::new(len); - buffer.extend_from_slice(slice); - buffer.into() +/// Note that here we deliberately do not implement +/// `impl> From for Buffer` +/// As it would accept `Buffer::from(vec![...])` that would cause an unexpected copy. +/// Instead, we ask user to be explicit when copying is occurring, e.g., `Buffer::from(vec![...].to_byte_slice())`. +/// For zero-copy conversion, user should use `Buffer::from_vec(vec![...])`. +/// +/// Since we removed impl for `AsRef`, we added the following three specific implementations to reduce API breakage. +/// See for more discussion on this. 
+impl From<&[u8]> for Buffer { + fn from(p: &[u8]) -> Self { + Self::from_slice_ref(p) + } +} + +impl From<[u8; N]> for Buffer { + fn from(p: [u8; N]) -> Self { + Self::from_slice_ref(p) + } +} + +impl From<&[u8; N]> for Buffer { + fn from(p: &[u8; N]) -> Self { + Self::from_slice_ref(p) + } +} + +impl From> for Buffer { + fn from(value: Vec) -> Self { + Self::from_vec(value) + } +} + +impl From> for Buffer { + fn from(value: ScalarBuffer) -> Self { + value.into_inner() } } /// Creating a `Buffer` instance by storing the boolean values into the buffer -impl std::iter::FromIterator for Buffer { +impl FromIterator for Buffer { fn from_iter(iter: I) -> Self where I: IntoIterator, @@ -266,12 +422,18 @@ impl From for Buffer { } } +impl From> for Buffer { + fn from(mut value: BufferBuilder) -> Self { + value.finish() + } +} + impl Buffer { /// Creates a [`Buffer`] from an [`Iterator`] with a trusted (upper) length. /// Prefer this to `collect` whenever possible, as it is ~60% faster. /// # Example /// ``` - /// # use arrow::buffer::Buffer; + /// # use arrow_buffer::buffer::Buffer; /// let v = vec![1u32]; /// let iter = v.iter().map(|x| x * 2); /// let buffer = unsafe { Buffer::from_trusted_len_iter(iter) }; @@ -301,40 +463,24 @@ impl Buffer { pub unsafe fn try_from_trusted_len_iter< E, T: ArrowNativeType, - I: Iterator>, + I: Iterator>, >( iterator: I, - ) -> std::result::Result { + ) -> Result { Ok(MutableBuffer::try_from_trusted_len_iter(iterator)?.into()) } } impl FromIterator for Buffer { fn from_iter>(iter: I) -> Self { - let mut iterator = iter.into_iter(); - let size = std::mem::size_of::(); - - // first iteration, which will likely reserve sufficient space for the buffer. 
- let mut buffer = match iterator.next() { - None => MutableBuffer::new(0), - Some(element) => { - let (lower, _) = iterator.size_hint(); - let mut buffer = MutableBuffer::new(lower.saturating_add(1) * size); - unsafe { - std::ptr::write(buffer.as_mut_ptr() as *mut T, element); - buffer.set_len(size); - } - buffer - } - }; - - buffer.extend_from_iter(iterator); - buffer.into() + let vec = Vec::from_iter(iter); + Buffer::from_vec(vec) } } #[cfg(test)] mod tests { + use crate::i256; use std::panic::{RefUnwindSafe, UnwindSafe}; use std::thread; @@ -417,9 +563,7 @@ mod tests { } #[test] - #[should_panic( - expected = "the offset of the new Buffer cannot exceed the existing length" - )] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn test_slice_offset_out_of_bound() { let buf = Buffer::from(&[2, 4, 6, 8, 10]); buf.slice(6); @@ -427,7 +571,7 @@ mod tests { #[test] fn test_access_concurrently() { - let buffer = Buffer::from(vec![1, 2, 3, 4, 5]); + let buffer = Buffer::from([1, 2, 3, 4, 5]); let buffer2 = buffer.clone(); assert_eq!([1, 2, 3, 4, 5], buffer.as_slice()); @@ -466,11 +610,17 @@ mod tests { #[test] fn test_count_bits() { - assert_eq!(0, Buffer::from(&[0b00000000]).count_set_bits()); - assert_eq!(8, Buffer::from(&[0b11111111]).count_set_bits()); - assert_eq!(3, Buffer::from(&[0b00001101]).count_set_bits()); - assert_eq!(6, Buffer::from(&[0b01001001, 0b01010010]).count_set_bits()); - assert_eq!(16, Buffer::from(&[0b11111111, 0b11111111]).count_set_bits()); + assert_eq!(0, Buffer::from(&[0b00000000]).count_set_bits_offset(0, 8)); + assert_eq!(8, Buffer::from(&[0b11111111]).count_set_bits_offset(0, 8)); + assert_eq!(3, Buffer::from(&[0b00001101]).count_set_bits_offset(0, 8)); + assert_eq!( + 6, + Buffer::from(&[0b01001001, 0b01010010]).count_set_bits_offset(0, 16) + ); + assert_eq!( + 16, + Buffer::from(&[0b11111111, 0b11111111]).count_set_bits_offset(0, 16) + ); } #[test] @@ -479,31 +629,31 @@ mod tests { 0, 
Buffer::from(&[0b11111111, 0b00000000]) .slice(1) - .count_set_bits() + .count_set_bits_offset(0, 8) ); assert_eq!( 8, Buffer::from(&[0b11111111, 0b11111111]) .slice_with_length(1, 1) - .count_set_bits() + .count_set_bits_offset(0, 8) ); assert_eq!( 3, Buffer::from(&[0b11111111, 0b11111111, 0b00001101]) .slice(2) - .count_set_bits() + .count_set_bits_offset(0, 8) ); assert_eq!( 6, Buffer::from(&[0b11111111, 0b01001001, 0b01010010]) .slice_with_length(1, 2) - .count_set_bits() + .count_set_bits_offset(0, 16) ); assert_eq!( 16, Buffer::from(&[0b11111111, 0b11111111, 0b11111111, 0b11111111]) .slice(2) - .count_set_bits() + .count_set_bits_offset(0, 16) ); } @@ -574,4 +724,140 @@ mod tests { let slice = buffer.typed_data::(); assert_eq!(slice, &[2, 3, 4, 5]); } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn slice_overflow() { + let buffer = Buffer::from(MutableBuffer::from_len_zeroed(12)); + buffer.slice_with_length(2, usize::MAX); + } + + #[test] + fn test_vec_interop() { + // Test empty vec + let a: Vec = Vec::new(); + let b = Buffer::from_vec(a); + b.into_vec::().unwrap(); + + // Test vec with capacity + let a: Vec = Vec::with_capacity(20); + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 0); + assert_eq!(back.capacity(), 20); + + // Test vec with values + let mut a: Vec = Vec::with_capacity(3); + a.extend_from_slice(&[1, 2, 3]); + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 3); + assert_eq!(back.capacity(), 3); + + // Test vec with values and spare capacity + let mut a: Vec = Vec::with_capacity(20); + a.extend_from_slice(&[1, 4, 7, 8, 9, 3, 6]); + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 7); + assert_eq!(back.capacity(), 20); + + // Test incorrect alignment + let a: Vec = Vec::new(); + let b = Buffer::from_vec(a); + let b = b.into_vec::().unwrap_err(); + 
b.into_vec::().unwrap_err(); + + // Test convert between types with same alignment + // This is an implementation quirk, but isn't harmful + // as ArrowNativeType are trivially transmutable + let a: Vec = vec![1, 2, 3, 4]; + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 4); + assert_eq!(back.capacity(), 4); + + // i256 has the same layout as i128 so this is valid + let mut b: Vec = Vec::with_capacity(4); + b.extend_from_slice(&[1, 2, 3, 4]); + let b = Buffer::from_vec(b); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 2); + assert_eq!(back.capacity(), 2); + + // Invalid layout + let b: Vec = vec![1, 2, 3]; + let b = Buffer::from_vec(b); + b.into_vec::().unwrap_err(); + + // Invalid layout + let mut b: Vec = Vec::with_capacity(5); + b.extend_from_slice(&[1, 2, 3, 4]); + let b = Buffer::from_vec(b); + b.into_vec::().unwrap_err(); + + // Truncates length + // This is an implementation quirk, but isn't harmful + let mut b: Vec = Vec::with_capacity(4); + b.extend_from_slice(&[1, 2, 3]); + let b = Buffer::from_vec(b); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 1); + assert_eq!(back.capacity(), 2); + + // Cannot use aligned allocation + let b = Buffer::from(MutableBuffer::new(10)); + let b = b.into_vec::().unwrap_err(); + b.into_vec::().unwrap_err(); + + // Test slicing + let mut a: Vec = Vec::with_capacity(20); + a.extend_from_slice(&[1, 4, 7, 8, 9, 3, 6]); + let b = Buffer::from_vec(a); + let slice = b.slice_with_length(0, 64); + + // Shared reference fails + let slice = slice.into_vec::().unwrap_err(); + drop(b); + + // Succeeds as no outstanding shared reference + let back = slice.into_vec::().unwrap(); + assert_eq!(&back, &[1, 4, 7, 8]); + assert_eq!(back.capacity(), 20); + + // Slicing by non-multiple length truncates + let mut a: Vec = Vec::with_capacity(8); + a.extend_from_slice(&[1, 4, 7, 3]); + + let b = Buffer::from_vec(a); + let slice = b.slice_with_length(0, 34); + drop(b); + + 
let back = slice.into_vec::().unwrap(); + assert_eq!(&back, &[1, 4]); + assert_eq!(back.capacity(), 8); + + // Offset prevents conversion + let a: Vec = vec![1, 3, 4, 6]; + let b = Buffer::from_vec(a).slice(2); + b.into_vec::().unwrap_err(); + + let b = MutableBuffer::new(16).into_buffer(); + let b = b.into_vec::().unwrap_err(); // Invalid layout + let b = b.into_vec::().unwrap_err(); // Invalid layout + b.into_mutable().unwrap(); + + let b = Buffer::from_vec(vec![1_u32, 3, 5]); + let b = b.into_mutable().unwrap(); + let b = Buffer::from(b); + let b = b.into_vec::().unwrap(); + assert_eq!(b, &[1, 3, 5]); + } + + #[test] + #[should_panic(expected = "capacity overflow")] + fn test_from_iter_overflow() { + let iter_len = usize::MAX / std::mem::size_of::() + 1; + let _ = Buffer::from_iter(std::iter::repeat(0_u64).take(iter_len)); + } } diff --git a/arrow-buffer/src/buffer/mod.rs b/arrow-buffer/src/buffer/mod.rs new file mode 100644 index 000000000000..d33e68795e4e --- /dev/null +++ b/arrow-buffer/src/buffer/mod.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Types of shared memory region + +mod offset; +pub use offset::*; +mod immutable; +pub use immutable::*; +mod mutable; +pub use mutable::*; +mod ops; +pub use ops::*; +mod scalar; +pub use scalar::*; +mod boolean; +pub use boolean::*; +mod null; +pub use null::*; +mod run; +pub use run::*; diff --git a/arrow/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs similarity index 70% rename from arrow/src/buffer/mutable.rs rename to arrow-buffer/src/buffer/mutable.rs index 1c662ec23eef..7fcbd89dd262 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -15,28 +15,35 @@ // specific language governing permissions and limitations // under the License. -use super::Buffer; -use crate::alloc::Deallocation; +use std::alloc::{handle_alloc_error, Layout}; +use std::mem; +use std::ptr::NonNull; + +use crate::alloc::{Deallocation, ALIGNMENT}; use crate::{ - alloc, bytes::Bytes, - datatypes::{ArrowNativeType, ToByteSlice}, + native::{ArrowNativeType, ToByteSlice}, util::bit_util, }; -use std::ptr::NonNull; + +use super::Buffer; /// A [`MutableBuffer`] is Arrow's interface to build a [`Buffer`] out of items or slices of items. +/// /// [`Buffer`]s created from [`MutableBuffer`] (via `into`) are guaranteed to have its pointer aligned /// along cache lines and in multiple of 64 bytes. +/// /// Use [MutableBuffer::push] to insert an item, [MutableBuffer::extend_from_slice] /// to insert many items, and `into` to convert it to [`Buffer`]. 
/// -/// For a safe, strongly typed API consider using [`crate::array::BufferBuilder`] +/// For a safe, strongly typed API consider using [`Vec`] and [`ScalarBuffer`](crate::ScalarBuffer) +/// +/// Note: this may be deprecated in a future release ([#1176](https://github.com/apache/arrow-rs/issues/1176)) /// /// # Example /// /// ``` -/// # use arrow::buffer::{Buffer, MutableBuffer}; +/// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); /// buffer.push(256u32); /// buffer.extend_from_slice(&[1u32]); @@ -49,25 +56,41 @@ pub struct MutableBuffer { data: NonNull, // invariant: len <= capacity len: usize, - capacity: usize, + layout: Layout, } impl MutableBuffer { /// Allocate a new [MutableBuffer] with initial capacity to be at least `capacity`. + /// + /// See [`MutableBuffer::with_capacity`]. #[inline] pub fn new(capacity: usize) -> Self { Self::with_capacity(capacity) } /// Allocate a new [MutableBuffer] with initial capacity to be at least `capacity`. + /// + /// # Panics + /// + /// If `capacity`, when rounded up to the nearest multiple of [`ALIGNMENT`], is greater + /// then `isize::MAX`, then this function will panic. #[inline] pub fn with_capacity(capacity: usize) -> Self { let capacity = bit_util::round_upto_multiple_of_64(capacity); - let ptr = alloc::allocate_aligned(capacity); + let layout = Layout::from_size_align(capacity, ALIGNMENT) + .expect("failed to create layout for MutableBuffer"); + let data = match layout.size() { + 0 => dangling_ptr(), + _ => { + // Safety: Verified size != 0 + let raw_ptr = unsafe { std::alloc::alloc(layout) }; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + }; Self { - data: ptr, + data, len: 0, - capacity, + layout, } } @@ -75,7 +98,7 @@ impl MutableBuffer { /// all bytes are guaranteed to be `0u8`. 
/// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::from_len_zeroed(127); /// assert_eq!(buffer.len(), 127); /// assert!(buffer.capacity() >= 127); @@ -83,13 +106,37 @@ impl MutableBuffer { /// assert_eq!(data[126], 0u8); /// ``` pub fn from_len_zeroed(len: usize) -> Self { - let new_capacity = bit_util::round_upto_multiple_of_64(len); - let ptr = alloc::allocate_aligned_zeroed(new_capacity); - Self { - data: ptr, - len, - capacity: new_capacity, - } + let layout = Layout::from_size_align(len, ALIGNMENT).unwrap(); + let data = match layout.size() { + 0 => dangling_ptr(), + _ => { + // Safety: Verified size != 0 + let raw_ptr = unsafe { std::alloc::alloc_zeroed(layout) }; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + }; + Self { data, len, layout } + } + + /// Create a [`MutableBuffer`] from the provided [`Vec`] without copying + #[inline] + #[deprecated(note = "Use From>")] + pub fn from_vec(vec: Vec) -> Self { + Self::from(vec) + } + + /// Allocates a new [MutableBuffer] from given `Bytes`. + pub(crate) fn from_bytes(bytes: Bytes) -> Result { + let layout = match bytes.deallocation() { + Deallocation::Standard(layout) => *layout, + _ => return Err(bytes), + }; + + let len = bytes.len(); + let data = bytes.ptr(); + mem::forget(bytes); + + Ok(Self { data, len, layout }) } /// creates a new [MutableBuffer] with capacity and length capable of holding `len` bits. @@ -106,7 +153,7 @@ impl MutableBuffer { /// the buffer directly (e.g., modifying the buffer by holding a mutable reference /// from `data_mut()`). 
pub fn with_bitset(mut self, end: usize, val: bool) -> Self { - assert!(end <= self.capacity); + assert!(end <= self.layout.size()); let v = if val { 255 } else { 0 }; unsafe { std::ptr::write_bytes(self.data.as_ptr(), v, end); @@ -121,7 +168,14 @@ impl MutableBuffer { /// `len` of the buffer and so can be used to initialize the memory region from /// `len` to `capacity`. pub fn set_null_bits(&mut self, start: usize, count: usize) { - assert!(start + count <= self.capacity); + assert!( + start.saturating_add(count) <= self.layout.size(), + "range start index {start} and count {count} out of bounds for \ + buffer of length {}", + self.layout.size(), + ); + + // Safety: `self.data[start..][..count]` is in-bounds and well-aligned for `u8` unsafe { std::ptr::write_bytes(self.data.as_ptr().add(start), 0, count); } @@ -131,7 +185,7 @@ impl MutableBuffer { /// `self.len + additional > capacity`. /// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); /// buffer.reserve(253); // allocates for the first time /// (0..253u8).for_each(|i| buffer.push(i)); // no reallocation @@ -143,19 +197,35 @@ impl MutableBuffer { #[inline(always)] pub fn reserve(&mut self, additional: usize) { let required_cap = self.len + additional; - if required_cap > self.capacity { - // JUSTIFICATION - // Benefit - // necessity - // Soundness - // `self.data` is valid for `self.capacity`. 
- let (ptr, new_capacity) = - unsafe { reallocate(self.data, self.capacity, required_cap) }; - self.data = ptr; - self.capacity = new_capacity; + if required_cap > self.layout.size() { + let new_capacity = bit_util::round_upto_multiple_of_64(required_cap); + let new_capacity = std::cmp::max(new_capacity, self.layout.size() * 2); + self.reallocate(new_capacity) } } + #[cold] + fn reallocate(&mut self, capacity: usize) { + let new_layout = Layout::from_size_align(capacity, self.layout.align()).unwrap(); + if new_layout.size() == 0 { + if self.layout.size() != 0 { + // Safety: data was allocated with layout + unsafe { std::alloc::dealloc(self.as_mut_ptr(), self.layout) }; + self.layout = new_layout + } + return; + } + + let data = match self.layout.size() { + // Safety: new_layout is not empty + 0 => unsafe { std::alloc::alloc(new_layout) }, + // Safety: verified new layout is valid and not empty + _ => unsafe { std::alloc::realloc(self.as_mut_ptr(), self.layout, capacity) }, + }; + self.data = NonNull::new(data).unwrap_or_else(|| handle_alloc_error(new_layout)); + self.layout = new_layout; + } + /// Truncates this buffer to `len` bytes /// /// If `len` is greater than the buffer's current length, this has no effect @@ -171,7 +241,7 @@ impl MutableBuffer { /// growing it (potentially reallocating it) and writing `value` in the newly available bytes. 
/// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); /// buffer.resize(253, 2); // allocates for the first time /// assert_eq!(buffer.as_slice()[252], 2u8); @@ -195,7 +265,7 @@ impl MutableBuffer { /// /// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// // 2 cache lines /// let mut buffer = MutableBuffer::new(128); /// assert_eq!(buffer.capacity(), 128); @@ -207,17 +277,8 @@ impl MutableBuffer { /// ``` pub fn shrink_to_fit(&mut self) { let new_capacity = bit_util::round_upto_multiple_of_64(self.len); - if new_capacity < self.capacity { - // JUSTIFICATION - // Benefit - // necessity - // Soundness - // `self.data` is valid for `self.capacity`. - let ptr = - unsafe { alloc::reallocate(self.data, self.capacity, new_capacity) }; - - self.data = ptr; - self.capacity = new_capacity; + if new_capacity < self.layout.size() { + self.reallocate(new_capacity) } } @@ -238,7 +299,7 @@ impl MutableBuffer { /// The invariant `buffer.len() <= buffer.capacity()` is always upheld. #[inline] pub const fn capacity(&self) -> usize { - self.capacity + self.layout.size() } /// Clear all existing data from this buffer. @@ -281,14 +342,12 @@ impl MutableBuffer { #[inline] pub(super) fn into_buffer(self) -> Buffer { - let bytes = unsafe { - Bytes::new(self.data, self.len, Deallocation::Arrow(self.capacity)) - }; + let bytes = unsafe { Bytes::new(self.data, self.len, Deallocation::Standard(self.layout)) }; std::mem::forget(self); Buffer::from_bytes(bytes) } - /// View this buffer as a slice of a specific type. + /// View this buffer as a mutable slice of a specific type. 
/// /// # Panics /// @@ -298,8 +357,22 @@ impl MutableBuffer { // SAFETY // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect // implementation outside this crate, and this method checks alignment - let (prefix, offsets, suffix) = - unsafe { self.as_slice_mut().align_to_mut::() }; + let (prefix, offsets, suffix) = unsafe { self.as_slice_mut().align_to_mut::() }; + assert!(prefix.is_empty() && suffix.is_empty()); + offsets + } + + /// View buffer as a immutable slice of a specific type. + /// + /// # Panics + /// + /// This function panics if the underlying buffer is not aligned + /// correctly for type `T`. + pub fn typed_data(&self) -> &[T] { + // SAFETY + // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect + // implementation outside this crate, and this method checks alignment + let (prefix, offsets, suffix) = unsafe { self.as_slice().align_to::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -307,15 +380,14 @@ impl MutableBuffer { /// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed. /// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let mut buffer = MutableBuffer::new(0); /// buffer.extend_from_slice(&[2u32, 0]); /// assert_eq!(buffer.len(), 8) // u32 has 4 bytes /// ``` #[inline] pub fn extend_from_slice(&mut self, items: &[T]) { - let len = items.len(); - let additional = len * std::mem::size_of::(); + let additional = mem::size_of_val(items); self.reserve(additional); unsafe { // this assumes that `[ToByteSlice]` can be copied directly @@ -331,7 +403,7 @@ impl MutableBuffer { /// Extends the buffer with a new item, increasing its capacity if needed. 
/// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let mut buffer = MutableBuffer::new(0); /// buffer.push(256u32); /// assert_eq!(buffer.len(), 4) // u32 has 4 bytes @@ -350,7 +422,7 @@ impl MutableBuffer { /// Extends the buffer with a new item, without checking for sufficient capacity /// # Safety - /// Caller must ensure that the capacity()-len()>=size_of() + /// Caller must ensure that the capacity()-len()>=`size_of`() #[inline] pub unsafe fn push_unchecked(&mut self, item: T) { let additional = std::mem::size_of::(); @@ -369,24 +441,62 @@ impl MutableBuffer { /// # Safety /// The caller must ensure that the buffer was properly initialized up to `len`. #[inline] - pub(crate) unsafe fn set_len(&mut self, len: usize) { + pub unsafe fn set_len(&mut self, len: usize) { assert!(len <= self.capacity()); self.len = len; } + + /// Invokes `f` with values `0..len` collecting the boolean results into a new `MutableBuffer` + /// + /// This is similar to `from_trusted_len_iter_bool`, however, can be significantly faster + /// as it eliminates the conditional `Iterator::next` + #[inline] + pub fn collect_bool bool>(len: usize, mut f: F) -> Self { + let mut buffer = Self::new(bit_util::ceil(len, 64) * 8); + + let chunks = len / 64; + let remainder = len % 64; + for chunk in 0..chunks { + let mut packed = 0; + for bit_idx in 0..64 { + let i = bit_idx + chunk * 64; + packed |= (f(i) as u64) << bit_idx; + } + + // SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + + if remainder != 0 { + let mut packed = 0; + for bit_idx in 0..remainder { + let i = bit_idx + chunks * 64; + packed |= (f(i) as u64) << bit_idx; + } + + // SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + + buffer.truncate(bit_util::ceil(len, 8)); + buffer + } } -/// # Safety -/// `ptr` must be allocated for `old_capacity`. 
-#[cold] -unsafe fn reallocate( - ptr: NonNull, - old_capacity: usize, - new_capacity: usize, -) -> (NonNull, usize) { - let new_capacity = bit_util::round_upto_multiple_of_64(new_capacity); - let new_capacity = std::cmp::max(new_capacity, old_capacity * 2); - let ptr = alloc::reallocate(ptr, old_capacity, new_capacity); - (ptr, new_capacity) +#[inline] +fn dangling_ptr() -> NonNull { + // SAFETY: ALIGNMENT is a non-zero usize which is then casted + // to a *mut T. Therefore, `ptr` is not null and the conditions for + // calling new_unchecked() are respected. + #[cfg(miri)] + { + // Since miri implies a nightly rust version we can use the unstable strict_provenance feature + unsafe { NonNull::new_unchecked(std::ptr::without_provenance_mut(ALIGNMENT)) } + } + #[cfg(not(miri))] + { + unsafe { NonNull::new_unchecked(ALIGNMENT as *mut u8) } + } } impl Extend for MutableBuffer { @@ -397,6 +507,21 @@ impl Extend for MutableBuffer { } } +impl From> for MutableBuffer { + fn from(value: Vec) -> Self { + // Safety + // Vec::as_ptr guaranteed to not be null and ArrowNativeType are trivially transmutable + let data = unsafe { NonNull::new_unchecked(value.as_ptr() as _) }; + let len = value.len() * mem::size_of::(); + // Safety + // Vec guaranteed to have a valid layout matching that of `Layout::array` + // This is based on `RawVec::current_memory` + let layout = unsafe { Layout::array::(value.capacity()).unwrap_unchecked() }; + mem::forget(value); + Self { data, len, layout } + } +} + impl MutableBuffer { #[inline] pub(super) fn extend_from_iter>( @@ -411,7 +536,7 @@ impl MutableBuffer { // this is necessary because of https://github.com/rust-lang/rust/issues/32155 let mut len = SetLenOnDrop::new(&mut self.len); let mut dst = unsafe { self.data.as_ptr().add(len.local_len) }; - let capacity = self.capacity; + let capacity = self.layout.size(); while len.local_len + item_size <= capacity { if let Some(item) = iterator.next() { @@ -434,7 +559,7 @@ impl MutableBuffer { /// Prefer 
this to `collect` whenever possible, as it is faster ~60% faster. /// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let v = vec![1u32]; /// let iter = v.iter().map(|x| x * 2); /// let buffer = unsafe { MutableBuffer::from_trusted_len_iter(iter) }; @@ -475,10 +600,10 @@ impl MutableBuffer { } /// Creates a [`MutableBuffer`] from a boolean [`Iterator`] with a trusted (upper) length. - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let v = vec![false, true, false]; /// let iter = v.iter().map(|x| *x || true); /// let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(iter) }; @@ -492,42 +617,11 @@ impl MutableBuffer { // we can't specialize `extend` for `TrustedLen` like `Vec` does. // 2. `from_trusted_len_iter_bool` is faster. #[inline] - pub unsafe fn from_trusted_len_iter_bool>( - mut iterator: I, - ) -> Self { + pub unsafe fn from_trusted_len_iter_bool>(mut iterator: I) -> Self { let (_, upper) = iterator.size_hint(); - let upper = upper.expect("from_trusted_len_iter requires an upper limit"); - - let mut result = { - let byte_capacity: usize = upper.saturating_add(7) / 8; - MutableBuffer::new(byte_capacity) - }; - - 'a: loop { - let mut byte_accum: u8 = 0; - let mut mask: u8 = 1; - - //collect (up to) 8 bits into a byte - while mask != 0 { - if let Some(value) = iterator.next() { - byte_accum |= match value { - true => mask, - false => 0, - }; - mask <<= 1; - } else { - if mask != 1 { - // Add last byte - result.push_unchecked(byte_accum); - } - break 'a; - } - } + let len = upper.expect("from_trusted_len_iter requires an upper limit"); - // Soundness: from_trusted_len - result.push_unchecked(byte_accum); - } - result + Self::collect_bool(len, |_| iterator.next().unwrap()) } /// Creates a [`MutableBuffer`] from an [`Iterator`] with 
a trusted (upper) length or errors @@ -540,10 +634,10 @@ impl MutableBuffer { pub unsafe fn try_from_trusted_len_iter< E, T: ArrowNativeType, - I: Iterator>, + I: Iterator>, >( iterator: I, - ) -> std::result::Result { + ) -> Result { let item_size = std::mem::size_of::(); let (_, upper) = iterator.size_hint(); let upper = upper.expect("try_from_trusted_len_iter requires an upper limit"); @@ -574,6 +668,12 @@ impl MutableBuffer { } } +impl Default for MutableBuffer { + fn default() -> Self { + Self::with_capacity(0) + } +} + impl std::ops::Deref for MutableBuffer { type Target = [u8]; @@ -590,7 +690,10 @@ impl std::ops::DerefMut for MutableBuffer { impl Drop for MutableBuffer { fn drop(&mut self) { - unsafe { alloc::free_aligned(self.data, self.capacity) }; + if self.layout.size() != 0 { + // Safety: data was allocated with standard allocator with given layout + unsafe { std::alloc::dealloc(self.data.as_ptr() as _, self.layout) }; + } } } @@ -599,7 +702,7 @@ impl PartialEq for MutableBuffer { if self.len != other.len { return false; } - if self.capacity != other.capacity { + if self.layout != other.layout { return false; } self.as_slice() == other.as_slice() @@ -686,6 +789,14 @@ impl std::iter::FromIterator for MutableBuffer { } } +impl std::iter::FromIterator for MutableBuffer { + fn from_iter>(iter: I) -> Self { + let mut buffer = Self::default(); + buffer.extend_from_iter(iter.into_iter()); + buffer + } +} + #[cfg(test)] mod tests { use super::*; @@ -698,6 +809,19 @@ mod tests { assert!(buf.is_empty()); } + #[test] + fn test_mutable_default() { + let buf = MutableBuffer::default(); + assert_eq!(0, buf.capacity()); + assert_eq!(0, buf.len()); + assert!(buf.is_empty()); + + let mut buf = MutableBuffer::default(); + buf.extend_from_slice(b"hello"); + assert_eq!(5, buf.len()); + assert_eq!(b"hello", buf.as_slice()); + } + #[test] fn test_mutable_extend_from_slice() { let mut buf = MutableBuffer::new(100); @@ -772,7 +896,7 @@ mod tests { #[test] fn 
test_from_trusted_len_iter() { let iter = vec![1u32, 2].into_iter(); - let buf = unsafe { Buffer::from_trusted_len_iter(iter) }; + let buf = unsafe { MutableBuffer::from_trusted_len_iter(iter) }; assert_eq!(8, buf.len()); assert_eq!(&[1u8, 0, 0, 0, 2, 0, 0, 0], buf.as_slice()); } @@ -860,4 +984,45 @@ mod tests { buffer.shrink_to_fit(); assert!(buffer.capacity() >= 64 && buffer.capacity() < 128); } + + #[test] + fn test_mutable_set_null_bits() { + let mut buffer = MutableBuffer::new(8).with_bitset(8, true); + + for i in 0..=buffer.capacity() { + buffer.set_null_bits(i, 0); + assert_eq!(buffer[..8], [255; 8][..]); + } + + buffer.set_null_bits(1, 4); + assert_eq!(buffer[..8], [255, 0, 0, 0, 0, 255, 255, 255][..]); + } + + #[test] + #[should_panic = "out of bounds for buffer of length"] + fn test_mutable_set_null_bits_oob() { + let mut buffer = MutableBuffer::new(64); + buffer.set_null_bits(1, buffer.capacity()); + } + + #[test] + #[should_panic = "out of bounds for buffer of length"] + fn test_mutable_set_null_bits_oob_by_overflow() { + let mut buffer = MutableBuffer::new(0); + buffer.set_null_bits(1, usize::MAX); + } + + #[test] + fn from_iter() { + let buffer = [1u16, 2, 3, 4].into_iter().collect::(); + assert_eq!(buffer.len(), 4 * mem::size_of::()); + assert_eq!(buffer.as_slice(), &[1, 0, 2, 0, 3, 0, 4, 0]); + } + + #[test] + #[should_panic(expected = "failed to create layout for MutableBuffer: LayoutError")] + fn test_with_capacity_panics_above_max_capacity() { + let max_capacity = isize::MAX as usize - (isize::MAX as usize % ALIGNMENT); + let _ = MutableBuffer::with_capacity(max_capacity + 1); + } } diff --git a/arrow-buffer/src/buffer/null.rs b/arrow-buffer/src/buffer/null.rs new file mode 100644 index 000000000000..c79aef398059 --- /dev/null +++ b/arrow-buffer/src/buffer/null.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bit_iterator::{BitIndexIterator, BitIterator, BitSliceIterator}; +use crate::buffer::BooleanBuffer; +use crate::{Buffer, MutableBuffer}; + +/// A [`BooleanBuffer`] used to encode validity for arrow arrays +/// +/// As per the [Arrow specification], array validity is encoded in a packed bitmask with a +/// `true` value indicating the corresponding slot is not null, and `false` indicating +/// that it is null. 
+/// +/// [Arrow specification]: https://arrow.apache.org/docs/format/Columnar.html#validity-bitmaps +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct NullBuffer { + buffer: BooleanBuffer, + null_count: usize, +} + +impl NullBuffer { + /// Create a new [`NullBuffer`] computing the null count + pub fn new(buffer: BooleanBuffer) -> Self { + let null_count = buffer.len() - buffer.count_set_bits(); + Self { buffer, null_count } + } + + /// Create a new [`NullBuffer`] of length `len` where all values are null + pub fn new_null(len: usize) -> Self { + Self { + buffer: BooleanBuffer::new_unset(len), + null_count: len, + } + } + + /// Create a new [`NullBuffer`] of length `len` where all values are valid + /// + /// Note: it is more efficient to not set the null buffer if it is known to be all valid + pub fn new_valid(len: usize) -> Self { + Self { + buffer: BooleanBuffer::new_set(len), + null_count: 0, + } + } + + /// Create a new [`NullBuffer`] with the provided `buffer` and `null_count` + /// + /// # Safety + /// + /// `buffer` must contain `null_count` `0` bits + pub unsafe fn new_unchecked(buffer: BooleanBuffer, null_count: usize) -> Self { + Self { buffer, null_count } + } + + /// Computes the union of the nulls in two optional [`NullBuffer`] + /// + /// This is commonly used by binary operations where the result is NULL if either + /// of the input values is NULL. 
Handling the null mask separately in this way + /// can yield significant performance improvements over an iterator approach + pub fn union(lhs: Option<&NullBuffer>, rhs: Option<&NullBuffer>) -> Option { + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Some(Self::new(lhs.inner() & rhs.inner())), + (Some(n), None) | (None, Some(n)) => Some(n.clone()), + (None, None) => None, + } + } + + /// Returns true if all nulls in `other` also exist in self + pub fn contains(&self, other: &NullBuffer) -> bool { + if other.null_count == 0 { + return true; + } + let lhs = self.inner().bit_chunks().iter_padded(); + let rhs = other.inner().bit_chunks().iter_padded(); + lhs.zip(rhs).all(|(l, r)| (l & !r) == 0) + } + + /// Returns a new [`NullBuffer`] where each bit in the current null buffer + /// is repeated `count` times. This is useful for masking the nulls of + /// the child of a FixedSizeListArray based on its parent + pub fn expand(&self, count: usize) -> Self { + let capacity = self.buffer.len().checked_mul(count).unwrap(); + let mut buffer = MutableBuffer::new_null(capacity); + + // Expand each bit within `null_mask` into `element_len` + // bits, constructing the implicit mask of the child elements + for i in 0..self.buffer.len() { + if self.is_null(i) { + continue; + } + for j in 0..count { + crate::bit_util::set_bit(buffer.as_mut(), i * count + j) + } + } + Self { + buffer: BooleanBuffer::new(buffer.into(), 0, capacity), + null_count: self.null_count * count, + } + } + + /// Returns the length of this [`NullBuffer`] + #[inline] + pub fn len(&self) -> usize { + self.buffer.len() + } + + /// Returns the offset of this [`NullBuffer`] in bits + #[inline] + pub fn offset(&self) -> usize { + self.buffer.offset() + } + + /// Returns true if this [`NullBuffer`] is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + /// Returns the null count for this [`NullBuffer`] + #[inline] + pub fn null_count(&self) -> usize { + self.null_count + } + + /// 
Returns `true` if the value at `idx` is not null + #[inline] + pub fn is_valid(&self, idx: usize) -> bool { + self.buffer.value(idx) + } + + /// Returns `true` if the value at `idx` is null + #[inline] + pub fn is_null(&self, idx: usize) -> bool { + !self.is_valid(idx) + } + + /// Returns the packed validity of this [`NullBuffer`] not including any offset + #[inline] + pub fn validity(&self) -> &[u8] { + self.buffer.values() + } + + /// Slices this [`NullBuffer`] by the provided `offset` and `length` + pub fn slice(&self, offset: usize, len: usize) -> Self { + Self::new(self.buffer.slice(offset, len)) + } + + /// Returns an iterator over the bits in this [`NullBuffer`] + /// + /// * `true` indicates that the corresponding value is not NULL + /// * `false` indicates that the corresponding value is NULL + /// + /// Note: [`Self::valid_indices`] will be significantly faster for most use-cases + pub fn iter(&self) -> BitIterator<'_> { + self.buffer.iter() + } + + /// Returns a [`BitIndexIterator`] over the valid indices in this [`NullBuffer`] + /// + /// Valid indices indicate the corresponding value is not NULL + pub fn valid_indices(&self) -> BitIndexIterator<'_> { + self.buffer.set_indices() + } + + /// Returns a [`BitSliceIterator`] yielding contiguous ranges of valid indices + /// + /// Valid indices indicate the corresponding value is not NULL + pub fn valid_slices(&self) -> BitSliceIterator<'_> { + self.buffer.set_slices() + } + + /// Calls the provided closure for each index in this null mask that is set + #[inline] + pub fn try_for_each_valid_idx Result<(), E>>( + &self, + f: F, + ) -> Result<(), E> { + if self.null_count == self.len() { + return Ok(()); + } + self.valid_indices().try_for_each(f) + } + + /// Returns the inner [`BooleanBuffer`] + #[inline] + pub fn inner(&self) -> &BooleanBuffer { + &self.buffer + } + + /// Returns the inner [`BooleanBuffer`] + #[inline] + pub fn into_inner(self) -> BooleanBuffer { + self.buffer + } + + /// Returns the 
underlying [`Buffer`] + #[inline] + pub fn buffer(&self) -> &Buffer { + self.buffer.inner() + } +} + +impl<'a> IntoIterator for &'a NullBuffer { + type Item = bool; + type IntoIter = BitIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.buffer.iter() + } +} + +impl From for NullBuffer { + fn from(value: BooleanBuffer) -> Self { + Self::new(value) + } +} + +impl From<&[bool]> for NullBuffer { + fn from(value: &[bool]) -> Self { + BooleanBuffer::from(value).into() + } +} + +impl From> for NullBuffer { + fn from(value: Vec) -> Self { + BooleanBuffer::from(value).into() + } +} + +impl FromIterator for NullBuffer { + fn from_iter>(iter: T) -> Self { + BooleanBuffer::from_iter(iter).into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_size() { + // This tests that the niche optimisation eliminates the overhead of an option + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::>() + ); + } +} diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs new file mode 100644 index 000000000000..e9087d30098c --- /dev/null +++ b/arrow-buffer/src/buffer/offset.rs @@ -0,0 +1,242 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::buffer::ScalarBuffer; +use crate::{ArrowNativeType, MutableBuffer, OffsetBufferBuilder}; +use std::ops::Deref; + +/// A non-empty buffer of monotonically increasing, positive integers. +/// +/// [`OffsetBuffer`] are used to represent ranges of offsets. An +/// `OffsetBuffer` of `N+1` items contains `N` such ranges. The start +/// offset for element `i` is `offsets[i]` and the end offset is +/// `offsets[i+1]`. Equal offsets represent an empty range. +/// +/// # Example +/// +/// This example shows how 5 distinct ranges, are represented using a +/// 6 entry `OffsetBuffer`. The first entry `(0, 3)` represents the +/// three offsets `0, 1, 2`. The entry `(3,3)` represent no offsets +/// (e.g. an empty list). +/// +/// ```text +/// ┌───────┐ ┌───┐ +/// │ (0,3) │ │ 0 │ +/// ├───────┤ ├───┤ +/// │ (3,3) │ │ 3 │ +/// ├───────┤ ├───┤ +/// │ (3,4) │ │ 3 │ +/// ├───────┤ ├───┤ +/// │ (4,5) │ │ 4 │ +/// ├───────┤ ├───┤ +/// │ (5,7) │ │ 5 │ +/// └───────┘ ├───┤ +/// │ 7 │ +/// └───┘ +/// +/// Offsets Buffer +/// Logical +/// Offsets +/// +/// (offsets[i], +/// offsets[i+1]) +/// ``` +#[derive(Debug, Clone)] +pub struct OffsetBuffer(ScalarBuffer); + +impl OffsetBuffer { + /// Create a new [`OffsetBuffer`] from the provided [`ScalarBuffer`] + /// + /// # Panics + /// + /// Panics if `buffer` is not a non-empty buffer containing + /// monotonically increasing values greater than or equal to zero + pub fn new(buffer: ScalarBuffer) -> Self { + assert!(!buffer.is_empty(), "offsets cannot be empty"); + assert!( + buffer[0] >= O::usize_as(0), + "offsets must be greater than 0" + ); + assert!( + buffer.windows(2).all(|w| w[0] <= w[1]), + "offsets must be monotonically increasing" + ); + Self(buffer) + } + + /// Create a new [`OffsetBuffer`] from the provided [`ScalarBuffer`] + /// + /// # Safety + /// + /// `buffer` must be a non-empty buffer containing monotonically increasing + /// values greater than or equal to zero + pub unsafe fn new_unchecked(buffer: ScalarBuffer) 
-> Self { + Self(buffer) + } + + /// Create a new [`OffsetBuffer`] containing a single 0 value + pub fn new_empty() -> Self { + let buffer = MutableBuffer::from_len_zeroed(std::mem::size_of::()); + Self(buffer.into_buffer().into()) + } + + /// Create a new [`OffsetBuffer`] containing `len + 1` `0` values + pub fn new_zeroed(len: usize) -> Self { + let len_bytes = len + .checked_add(1) + .and_then(|o| o.checked_mul(std::mem::size_of::())) + .expect("overflow"); + let buffer = MutableBuffer::from_len_zeroed(len_bytes); + Self(buffer.into_buffer().into()) + } + + /// Create a new [`OffsetBuffer`] from the iterator of slice lengths + /// + /// ``` + /// # use arrow_buffer::OffsetBuffer; + /// let offsets = OffsetBuffer::::from_lengths([1, 3, 5]); + /// assert_eq!(offsets.as_ref(), &[0, 1, 4, 9]); + /// ``` + /// + /// # Panics + /// + /// Panics on overflow + pub fn from_lengths(lengths: I) -> Self + where + I: IntoIterator, + { + let iter = lengths.into_iter(); + let mut out = Vec::with_capacity(iter.size_hint().0 + 1); + out.push(O::usize_as(0)); + + let mut acc = 0_usize; + for length in iter { + acc = acc.checked_add(length).expect("usize overflow"); + out.push(O::usize_as(acc)) + } + // Check for overflow + O::from_usize(acc).expect("offset overflow"); + Self(out.into()) + } + + /// Returns the inner [`ScalarBuffer`] + pub fn inner(&self) -> &ScalarBuffer { + &self.0 + } + + /// Returns the inner [`ScalarBuffer`], consuming self + pub fn into_inner(self) -> ScalarBuffer { + self.0 + } + + /// Returns a zero-copy slice of this buffer with length `len` and starting at `offset` + pub fn slice(&self, offset: usize, len: usize) -> Self { + Self(self.0.slice(offset, len.saturating_add(1))) + } + + /// Returns true if this [`OffsetBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. 
This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.0.ptr_eq(&other.0) + } +} + +impl Deref for OffsetBuffer { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[T]> for OffsetBuffer { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +impl From> for OffsetBuffer { + fn from(value: OffsetBufferBuilder) -> Self { + value.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic(expected = "offsets cannot be empty")] + fn empty_offsets() { + OffsetBuffer::new(Vec::::new().into()); + } + + #[test] + #[should_panic(expected = "offsets must be greater than 0")] + fn negative_offsets() { + OffsetBuffer::new(vec![-1, 0, 1].into()); + } + + #[test] + fn offsets() { + OffsetBuffer::new(vec![0, 1, 2, 3].into()); + + let offsets = OffsetBuffer::::new_zeroed(3); + assert_eq!(offsets.as_ref(), &[0; 4]); + + let offsets = OffsetBuffer::::new_zeroed(0); + assert_eq!(offsets.as_ref(), &[0; 1]); + } + + #[test] + #[should_panic(expected = "overflow")] + fn offsets_new_zeroed_overflow() { + OffsetBuffer::::new_zeroed(usize::MAX); + } + + #[test] + #[should_panic(expected = "offsets must be monotonically increasing")] + fn non_monotonic_offsets() { + OffsetBuffer::new(vec![1, 2, 0].into()); + } + + #[test] + fn from_lengths() { + let buffer = OffsetBuffer::::from_lengths([2, 6, 3, 7, 2]); + assert_eq!(buffer.as_ref(), &[0, 2, 8, 11, 18, 20]); + + let half_max = i32::MAX / 2; + let buffer = OffsetBuffer::::from_lengths([half_max as usize, half_max as usize]); + assert_eq!(buffer.as_ref(), &[0, half_max, half_max * 2]); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn from_lengths_offset_overflow() { + OffsetBuffer::::from_lengths([i32::MAX as usize, 1]); + } + + #[test] + #[should_panic(expected = "usize overflow")] + fn from_lengths_usize_overflow() { + 
OffsetBuffer::::from_lengths([usize::MAX, 1]); + } +} diff --git a/arrow/src/buffer/ops.rs b/arrow-buffer/src/buffer/ops.rs similarity index 71% rename from arrow/src/buffer/ops.rs rename to arrow-buffer/src/buffer/ops.rs index 7000f39767cb..c69e5c6deb10 100644 --- a/arrow/src/buffer/ops.rs +++ b/arrow-buffer/src/buffer/ops.rs @@ -20,26 +20,19 @@ use crate::util::bit_util::ceil; /// Apply a bitwise operation `op` to four inputs and return the result as a Buffer. /// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. -#[allow(clippy::too_many_arguments)] -pub(crate) fn bitwise_quaternary_op_helper( - first: &Buffer, - first_offset_in_bits: usize, - second: &Buffer, - second_offset_in_bits: usize, - third: &Buffer, - third_offset_in_bits: usize, - fourth: &Buffer, - fourth_offset_in_bits: usize, +pub fn bitwise_quaternary_op_helper( + buffers: [&Buffer; 4], + offsets: [usize; 4], len_in_bits: usize, op: F, ) -> Buffer where F: Fn(u64, u64, u64, u64) -> u64, { - let first_chunks = first.bit_chunks(first_offset_in_bits, len_in_bits); - let second_chunks = second.bit_chunks(second_offset_in_bits, len_in_bits); - let third_chunks = third.bit_chunks(third_offset_in_bits, len_in_bits); - let fourth_chunks = fourth.bit_chunks(fourth_offset_in_bits, len_in_bits); + let first_chunks = buffers[0].bit_chunks(offsets[0], len_in_bits); + let second_chunks = buffers[1].bit_chunks(offsets[1], len_in_bits); + let third_chunks = buffers[2].bit_chunks(offsets[2], len_in_bits); + let fourth_chunks = buffers[3].bit_chunks(offsets[3], len_in_bits); let chunks = first_chunks .iter() @@ -73,10 +66,10 @@ pub fn bitwise_bin_op_helper( right: &Buffer, right_offset_in_bits: usize, len_in_bits: usize, - op: F, + mut op: F, ) -> Buffer where - F: Fn(u64, u64) -> u64, + F: FnMut(u64, u64) -> u64, { let left_chunks = left.bit_chunks(left_offset_in_bits, len_in_bits); let right_chunks = right.bit_chunks(right_offset_in_bits, len_in_bits); @@ -104,10 
+97,10 @@ pub fn bitwise_unary_op_helper( left: &Buffer, offset_in_bits: usize, len_in_bits: usize, - op: F, + mut op: F, ) -> Buffer where - F: Fn(u64) -> u64, + F: FnMut(u64) -> u64, { // reserve capacity and set length so we can get a typed view of u64 chunks let mut result = @@ -132,6 +125,8 @@ where result.into() } +/// Apply a bitwise and to two inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. pub fn buffer_bin_and( left: &Buffer, left_offset_in_bits: usize, @@ -149,6 +144,8 @@ pub fn buffer_bin_and( ) } +/// Apply a bitwise or to two inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. pub fn buffer_bin_or( left: &Buffer, left_offset_in_bits: usize, @@ -166,10 +163,46 @@ pub fn buffer_bin_or( ) } -pub fn buffer_unary_not( +/// Apply a bitwise xor to two inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. +pub fn buffer_bin_xor( left: &Buffer, - offset_in_bits: usize, + left_offset_in_bits: usize, + right: &Buffer, + right_offset_in_bits: usize, len_in_bits: usize, ) -> Buffer { + bitwise_bin_op_helper( + left, + left_offset_in_bits, + right, + right_offset_in_bits, + len_in_bits, + |a, b| a ^ b, + ) +} + +/// Apply a bitwise and_not to two inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. +pub fn buffer_bin_and_not( + left: &Buffer, + left_offset_in_bits: usize, + right: &Buffer, + right_offset_in_bits: usize, + len_in_bits: usize, +) -> Buffer { + bitwise_bin_op_helper( + left, + left_offset_in_bits, + right, + right_offset_in_bits, + len_in_bits, + |a, b| a & !b, + ) +} + +/// Apply a bitwise not to one input and return the result as a Buffer. 
+/// The input is treated as a bitmap, meaning that offset and length are specified in number of bits. +pub fn buffer_unary_not(left: &Buffer, offset_in_bits: usize, len_in_bits: usize) -> Buffer { bitwise_unary_op_helper(left, offset_in_bits, len_in_bits, |a| !a) } diff --git a/arrow-buffer/src/buffer/run.rs b/arrow-buffer/src/buffer/run.rs new file mode 100644 index 000000000000..3dbbe344a025 --- /dev/null +++ b/arrow-buffer/src/buffer/run.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::buffer::ScalarBuffer; +use crate::ArrowNativeType; + +/// A slice-able buffer of monotonically increasing, positive integers used to store run-ends +/// +/// # Logical vs Physical +/// +/// A [`RunEndBuffer`] is used to encode runs of the same value, the index of each run is +/// called the physical index. The logical index is then the corresponding index in the logical +/// run-encoded array, i.e. a single run of length `3`, would have the logical indices `0..3`. +/// +/// Each value in [`RunEndBuffer::values`] is the cumulative length of all runs in the +/// logical array, up to that physical index. +/// +/// Consider a [`RunEndBuffer`] containing `[3, 4, 6]`. 
The maximum physical index is `2`, +/// as there are `3` values, and the maximum logical index is `5`, as the maximum run end +/// is `6`. The physical indices are therefore `[0, 0, 0, 1, 2, 2]` +/// +/// ```text +/// ┌─────────┐ ┌─────────┐ ┌─────────┐ +/// │ 3 │ │ 0 │ ─┬──────▶ │ 0 │ +/// ├─────────┤ ├─────────┤ │ ├─────────┤ +/// │ 4 │ │ 1 │ ─┤ ┌────▶ │ 1 │ +/// ├─────────┤ ├─────────┤ │ │ ├─────────┤ +/// │ 6 │ │ 2 │ ─┘ │ ┌──▶ │ 2 │ +/// └─────────┘ ├─────────┤ │ │ └─────────┘ +/// run ends │ 3 │ ───┘ │ physical indices +/// ├─────────┤ │ +/// │ 4 │ ─────┤ +/// ├─────────┤ │ +/// │ 5 │ ─────┘ +/// └─────────┘ +/// logical indices +/// ``` +/// +/// # Slicing +/// +/// In order to provide zero-copy slicing, this container stores a separate offset and length +/// +/// For example, a [`RunEndBuffer`] containing values `[3, 6, 8]` with offset and length `4` would +/// describe the physical indices `1, 1, 2, 2` +/// +/// For example, a [`RunEndBuffer`] containing values `[6, 8, 9]` with offset `2` and length `5` +/// would describe the physical indices `0, 0, 0, 0, 1` +/// +/// [Run-End encoded layout]: https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout +#[derive(Debug, Clone)] +pub struct RunEndBuffer { + run_ends: ScalarBuffer, + len: usize, + offset: usize, +} + +impl RunEndBuffer +where + E: ArrowNativeType, +{ + /// Create a new [`RunEndBuffer`] from a [`ScalarBuffer`], an `offset` and `len` + /// + /// # Panics + /// + /// - `buffer` does not contain strictly increasing values greater than zero + /// - the last value of `buffer` is less than `offset + len` + pub fn new(run_ends: ScalarBuffer, offset: usize, len: usize) -> Self { + assert!( + run_ends.windows(2).all(|w| w[0] < w[1]), + "run-ends not strictly increasing" + ); + + if len != 0 { + assert!(!run_ends.is_empty(), "non-empty slice but empty run-ends"); + let end = E::from_usize(offset.saturating_add(len)).unwrap(); + assert!( + *run_ends.first().unwrap() > E::usize_as(0), + 
"run-ends not greater than 0" + ); + assert!( + *run_ends.last().unwrap() >= end, + "slice beyond bounds of run-ends" + ); + } + + Self { + run_ends, + offset, + len, + } + } + + /// Create a new [`RunEndBuffer`] from an [`ScalarBuffer`], an `offset` and `len` + /// + /// # Safety + /// + /// - `buffer` must contain strictly increasing values greater than zero + /// - The last value of `buffer` must be greater than or equal to `offset + len` + pub unsafe fn new_unchecked(run_ends: ScalarBuffer, offset: usize, len: usize) -> Self { + Self { + run_ends, + offset, + len, + } + } + + /// Returns the logical offset into the run-ends stored by this buffer + #[inline] + pub fn offset(&self) -> usize { + self.offset + } + + /// Returns the logical length of the run-ends stored by this buffer + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Returns true if this buffer is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the values of this [`RunEndBuffer`] not including any offset + #[inline] + pub fn values(&self) -> &[E] { + &self.run_ends + } + + /// Returns the maximum run-end encoded in the underlying buffer + #[inline] + pub fn max_value(&self) -> usize { + self.values().last().copied().unwrap_or_default().as_usize() + } + + /// Performs a binary search to find the physical index for the given logical index + /// + /// The result is arbitrary if `logical_index >= self.len()` + pub fn get_physical_index(&self, logical_index: usize) -> usize { + let logical_index = E::usize_as(self.offset + logical_index); + let cmp = |p: &E| p.partial_cmp(&logical_index).unwrap(); + + match self.run_ends.binary_search_by(cmp) { + Ok(idx) => idx + 1, + Err(idx) => idx, + } + } + + /// Returns the physical index at which the logical array starts + pub fn get_start_physical_index(&self) -> usize { + if self.offset == 0 || self.len == 0 { + return 0; + } + // Fallback to binary search + self.get_physical_index(0) + } + + /// Returns 
the physical index at which the logical array ends + pub fn get_end_physical_index(&self) -> usize { + if self.len == 0 { + return 0; + } + if self.max_value() == self.offset + self.len { + return self.values().len() - 1; + } + // Fallback to binary search + self.get_physical_index(self.len - 1) + } + + /// Slices this [`RunEndBuffer`] by the provided `offset` and `length` + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced RunEndBuffer cannot exceed the existing length" + ); + Self { + run_ends: self.run_ends.clone(), + offset: self.offset + offset, + len, + } + } + + /// Returns the inner [`ScalarBuffer`] + pub fn inner(&self) -> &ScalarBuffer { + &self.run_ends + } + + /// Returns the inner [`ScalarBuffer`], consuming self + pub fn into_inner(self) -> ScalarBuffer { + self.run_ends + } +} + +#[cfg(test)] +mod tests { + use crate::buffer::RunEndBuffer; + + #[test] + fn test_zero_length_slice() { + let buffer = RunEndBuffer::new(vec![1_i32, 4_i32].into(), 0, 4); + assert_eq!(buffer.get_start_physical_index(), 0); + assert_eq!(buffer.get_end_physical_index(), 1); + assert_eq!(buffer.get_physical_index(3), 1); + + for offset in 0..4 { + let sliced = buffer.slice(offset, 0); + assert_eq!(sliced.get_start_physical_index(), 0); + assert_eq!(sliced.get_end_physical_index(), 0); + } + + let buffer = RunEndBuffer::new(Vec::::new().into(), 0, 0); + assert_eq!(buffer.get_start_physical_index(), 0); + assert_eq!(buffer.get_end_physical_index(), 0); + } +} diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs new file mode 100644 index 000000000000..343b8549e93d --- /dev/null +++ b/arrow-buffer/src/buffer/scalar.rs @@ -0,0 +1,339 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::alloc::Deallocation; +use crate::buffer::Buffer; +use crate::native::ArrowNativeType; +use crate::{BufferBuilder, MutableBuffer, OffsetBuffer}; +use std::fmt::Formatter; +use std::marker::PhantomData; +use std::ops::Deref; + +/// A strongly-typed [`Buffer`] supporting zero-copy cloning and slicing +/// +/// The easiest way to think about `ScalarBuffer` is being equivalent to a `Arc>`, +/// with the following differences: +/// +/// - slicing and cloning is O(1). 
+/// - it supports external allocated memory +/// +/// ``` +/// # use arrow_buffer::ScalarBuffer; +/// // Zero-copy conversion from Vec +/// let buffer = ScalarBuffer::from(vec![1, 2, 3]); +/// assert_eq!(&buffer, &[1, 2, 3]); +/// +/// // Zero-copy slicing +/// let sliced = buffer.slice(1, 2); +/// assert_eq!(&sliced, &[2, 3]); +/// ``` +#[derive(Clone)] +pub struct ScalarBuffer { + /// Underlying data buffer + buffer: Buffer, + phantom: PhantomData, +} + +impl std::fmt::Debug for ScalarBuffer { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("ScalarBuffer").field(&self.as_ref()).finish() + } +} + +impl ScalarBuffer { + /// Create a new [`ScalarBuffer`] from a [`Buffer`], and an `offset` + /// and `length` in units of `T` + /// + /// # Panics + /// + /// This method will panic if + /// + /// * `offset` or `len` would result in overflow + /// * `buffer` is not aligned to a multiple of `std::mem::align_of::` + /// * `bytes` is not large enough for the requested slice + pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { + let size = std::mem::size_of::(); + let byte_offset = offset.checked_mul(size).expect("offset overflow"); + let byte_len = len.checked_mul(size).expect("length overflow"); + buffer.slice_with_length(byte_offset, byte_len).into() + } + + /// Returns a zero-copy slice of this buffer with length `len` and starting at `offset` + pub fn slice(&self, offset: usize, len: usize) -> Self { + Self::new(self.buffer.clone(), offset, len) + } + + /// Returns the inner [`Buffer`] + pub fn inner(&self) -> &Buffer { + &self.buffer + } + + /// Returns the inner [`Buffer`], consuming self + pub fn into_inner(self) -> Buffer { + self.buffer + } + + /// Returns true if this [`ScalarBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. 
This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.buffer.ptr_eq(&other.buffer) + } +} + +impl Deref for ScalarBuffer { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + // SAFETY: Verified alignment in From + unsafe { + std::slice::from_raw_parts( + self.buffer.as_ptr() as *const T, + self.buffer.len() / std::mem::size_of::(), + ) + } + } +} + +impl AsRef<[T]> for ScalarBuffer { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +impl From for ScalarBuffer { + fn from(value: MutableBuffer) -> Self { + Buffer::from(value).into() + } +} + +impl From for ScalarBuffer { + fn from(buffer: Buffer) -> Self { + let align = std::mem::align_of::(); + let is_aligned = buffer.as_ptr().align_offset(align) == 0; + + match buffer.deallocation() { + Deallocation::Standard(_) => assert!( + is_aligned, + "Memory pointer is not aligned with the specified scalar type" + ), + Deallocation::Custom(_, _) => + assert!(is_aligned, "Memory pointer from external source (e.g, FFI) is not aligned with the specified scalar type. 
Before importing buffer through FFI, please make sure the allocation is aligned."), + } + + Self { + buffer, + phantom: Default::default(), + } + } +} + +impl From> for ScalarBuffer { + fn from(value: OffsetBuffer) -> Self { + value.into_inner() + } +} + +impl From> for ScalarBuffer { + fn from(value: Vec) -> Self { + Self { + buffer: Buffer::from_vec(value), + phantom: Default::default(), + } + } +} + +impl From> for Vec { + fn from(value: ScalarBuffer) -> Self { + value + .buffer + .into_vec() + .unwrap_or_else(|buffer| buffer.typed_data::().into()) + } +} + +impl From> for ScalarBuffer { + fn from(mut value: BufferBuilder) -> Self { + let len = value.len(); + Self::new(value.finish(), 0, len) + } +} + +impl FromIterator for ScalarBuffer { + fn from_iter>(iter: I) -> Self { + iter.into_iter().collect::>().into() + } +} + +impl<'a, T: ArrowNativeType> IntoIterator for &'a ScalarBuffer { + type Item = &'a T; + type IntoIter = std::slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.as_ref().iter() + } +} + +impl + ?Sized> PartialEq for ScalarBuffer { + fn eq(&self, other: &S) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl PartialEq> for [T; N] { + fn eq(&self, other: &ScalarBuffer) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl PartialEq> for [T] { + fn eq(&self, other: &ScalarBuffer) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl PartialEq> for Vec { + fn eq(&self, other: &ScalarBuffer) -> bool { + self.as_slice().eq(other.as_ref()) + } +} + +#[cfg(test)] +mod tests { + use std::{ptr::NonNull, sync::Arc}; + + use super::*; + + #[test] + fn test_basic() { + let expected = [0_i32, 1, 2]; + let buffer = Buffer::from_iter(expected.iter().cloned()); + let typed = ScalarBuffer::::new(buffer.clone(), 0, 3); + assert_eq!(*typed, expected); + + let typed = ScalarBuffer::::new(buffer.clone(), 1, 2); + assert_eq!(*typed, expected[1..]); + + let typed = ScalarBuffer::::new(buffer.clone(), 1, 0); + 
assert!(typed.is_empty()); + + let typed = ScalarBuffer::::new(buffer, 3, 0); + assert!(typed.is_empty()); + } + + #[test] + fn test_debug() { + let buffer = ScalarBuffer::from(vec![1, 2, 3]); + assert_eq!(format!("{buffer:?}"), "ScalarBuffer([1, 2, 3])"); + } + + #[test] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] + fn test_unaligned() { + let expected = [0_i32, 1, 2]; + let buffer = Buffer::from_iter(expected.iter().cloned()); + let buffer = buffer.slice(1); + ScalarBuffer::::new(buffer, 0, 2); + } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn test_length_out_of_bounds() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 1, 3); + } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn test_offset_out_of_bounds() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 4, 0); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn test_length_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, usize::MAX, 1); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn test_start_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, usize::MAX / 4 + 1, 0); + } + + #[test] + #[should_panic(expected = "length overflow")] + fn test_end_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 0, usize::MAX / 4 + 1); + } + + #[test] + fn convert_from_buffer_builder() { + let input = vec![1, 2, 3, 4]; + let buffer_builder = BufferBuilder::from(input.clone()); + let scalar_buffer = ScalarBuffer::from(buffer_builder); + assert_eq!(scalar_buffer.as_ref(), input); + } + + #[test] + fn into_vec() { + let input = vec![1u8, 2, 3, 4]; + + // No copy + let input_buffer = Buffer::from_vec(input.clone()); + let input_ptr = 
input_buffer.as_ptr(); + let input_len = input_buffer.len(); + let scalar_buffer = ScalarBuffer::::new(input_buffer, 0, input_len); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec.as_slice(), input.as_slice()); + assert_eq!(vec.as_ptr(), input_ptr); + + // Custom allocation - makes a copy + let mut input_clone = input.clone(); + let input_ptr = NonNull::new(input_clone.as_mut_ptr()).unwrap(); + let dealloc = Arc::new(()); + let buffer = + unsafe { Buffer::from_custom_allocation(input_ptr, input_clone.len(), dealloc as _) }; + let scalar_buffer = ScalarBuffer::::new(buffer, 0, input.len()); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec, input.as_slice()); + assert_ne!(vec.as_ptr(), input_ptr.as_ptr()); + + // Offset - makes a copy + let input_buffer = Buffer::from_vec(input.clone()); + let input_ptr = input_buffer.as_ptr(); + let input_len = input_buffer.len(); + let scalar_buffer = ScalarBuffer::::new(input_buffer, 1, input_len - 1); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec.as_slice(), &input[1..]); + assert_ne!(vec.as_ptr(), input_ptr); + + // Inner buffer Arc ref count != 0 - makes a copy + let buffer = Buffer::from_slice_ref(input.as_slice()); + let scalar_buffer = ScalarBuffer::::new(buffer, 0, input.len()); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec, input.as_slice()); + assert_ne!(vec.as_ptr(), input.as_ptr()); + } +} diff --git a/arrow/src/array/builder/boolean_buffer_builder.rs b/arrow-buffer/src/builder/boolean.rs similarity index 65% rename from arrow/src/array/builder/boolean_buffer_builder.rs rename to arrow-buffer/src/builder/boolean.rs index 5b6d1ce48478..ca178ae5ce4e 100644 --- a/arrow/src/array/builder/boolean_buffer_builder.rs +++ b/arrow-buffer/src/builder/boolean.rs @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. 
-use crate::buffer::{Buffer, MutableBuffer}; - -use super::Range; - -use crate::util::bit_util; +use crate::{bit_mask, bit_util, BooleanBuffer, Buffer, MutableBuffer}; +use std::ops::Range; +/// Builder for [`BooleanBuffer`] #[derive(Debug)] pub struct BooleanBufferBuilder { buffer: MutableBuffer, @@ -28,6 +26,7 @@ pub struct BooleanBufferBuilder { } impl BooleanBufferBuilder { + /// Creates a new `BooleanBufferBuilder` #[inline] pub fn new(capacity: usize) -> Self { let byte_capacity = bit_util::ceil(capacity, 8); @@ -35,11 +34,24 @@ impl BooleanBufferBuilder { Self { buffer, len: 0 } } + /// Creates a new `BooleanBufferBuilder` from [`MutableBuffer`] of `len` + pub fn new_from_buffer(buffer: MutableBuffer, len: usize) -> Self { + assert!(len <= buffer.len() * 8); + let mut s = Self { + len: buffer.len() * 8, + buffer, + }; + s.truncate(len); + s + } + + /// Returns the length of the buffer #[inline] pub fn len(&self) -> usize { self.len } + /// Sets a bit in the buffer at `index` #[inline] pub fn set_bit(&mut self, index: usize, v: bool) { if v { @@ -49,21 +61,25 @@ impl BooleanBufferBuilder { } } + /// Gets a bit in the buffer at `index` #[inline] pub fn get_bit(&self, index: usize) -> bool { bit_util::get_bit(self.buffer.as_slice(), index) } + /// Returns true if empty #[inline] pub fn is_empty(&self) -> bool { self.len == 0 } + /// Returns the capacity of the buffer #[inline] pub fn capacity(&self) -> usize { self.buffer.capacity() * 8 } + /// Advances the buffer by `additional` bits #[inline] pub fn advance(&mut self, additional: usize) { let new_len = self.len + additional; @@ -74,6 +90,26 @@ impl BooleanBufferBuilder { self.len = new_len; } + /// Truncates the builder to the given length + /// + /// If `len` is greater than the buffer's current length, this has no effect + #[inline] + pub fn truncate(&mut self, len: usize) { + if len > self.len { + return; + } + + let new_len_bytes = bit_util::ceil(len, 8); + self.buffer.truncate(new_len_bytes); + self.len 
= len; + + let remainder = self.len % 8; + if remainder != 0 { + let mask = (1_u8 << remainder).wrapping_sub(1); + *self.buffer.as_mut().last_mut().unwrap() &= mask; + } + } + /// Reserve space to at least `additional` new bits. /// Capacity will be `>= self.len() + additional`. /// New bytes are uninitialized and reading them is undefined behavior. @@ -91,11 +127,13 @@ impl BooleanBufferBuilder { /// growing it (potentially reallocating it) and writing `false` in the newly available bits. #[inline] pub fn resize(&mut self, len: usize) { - let len_bytes = bit_util::ceil(len, 8); - self.buffer.resize(len_bytes, 0); - self.len = len; + match len.checked_sub(self.len) { + Some(delta) => self.advance(delta), + None => self.truncate(len), + } } + /// Appends a boolean `v` into the buffer #[inline] pub fn append(&mut self, v: bool) { self.advance(1); @@ -104,17 +142,32 @@ impl BooleanBufferBuilder { } } + /// Appends n `additional` bits of value `v` into the buffer #[inline] pub fn append_n(&mut self, additional: usize, v: bool) { - self.advance(additional); - if additional > 0 && v { - let offset = self.len() - additional; - (0..additional).for_each(|i| unsafe { - bit_util::set_bit_raw(self.buffer.as_mut_ptr(), offset + i) - }) + match v { + true => { + let new_len = self.len + additional; + let new_len_bytes = bit_util::ceil(new_len, 8); + let cur_remainder = self.len % 8; + let new_remainder = new_len % 8; + + if cur_remainder != 0 { + // Pad last byte with 1s + *self.buffer.as_slice_mut().last_mut().unwrap() |= !((1 << cur_remainder) - 1) + } + self.buffer.resize(new_len_bytes, 0xFF); + if new_remainder != 0 { + // Clear remaining bits + *self.buffer.as_slice_mut().last_mut().unwrap() &= (1 << new_remainder) - 1 + } + self.len = new_len; + } + false => self.advance(additional), } } + /// Appends a slice of booleans into the buffer #[inline] pub fn append_slice(&mut self, slice: &[bool]) { let additional = slice.len(); @@ -139,7 +192,7 @@ impl BooleanBufferBuilder { 
let offset_write = self.len; let len = range.end - range.start; self.advance(len); - crate::util::bit_mask::set_bits( + bit_mask::set_bits( self.buffer.as_slice_mut(), to_set, offset_write, @@ -148,16 +201,33 @@ impl BooleanBufferBuilder { ); } + /// Append [`BooleanBuffer`] to this [`BooleanBufferBuilder`] + pub fn append_buffer(&mut self, buffer: &BooleanBuffer) { + let range = buffer.offset()..buffer.offset() + buffer.len(); + self.append_packed_range(range, buffer.values()) + } + /// Returns the packed bits pub fn as_slice(&self) -> &[u8] { self.buffer.as_slice() } + /// Returns the packed bits + pub fn as_slice_mut(&mut self) -> &mut [u8] { + self.buffer.as_slice_mut() + } + + /// Creates a [`BooleanBuffer`] #[inline] - pub fn finish(&mut self) -> Buffer { + pub fn finish(&mut self) -> BooleanBuffer { let buf = std::mem::replace(&mut self.buffer, MutableBuffer::new(0)); - self.len = 0; - buf.into() + let len = std::mem::replace(&mut self.len, 0); + BooleanBuffer::new(buf.into(), 0, len) + } + + /// Builds the [BooleanBuffer] without resetting the builder. 
+ pub fn finish_cloned(&self) -> BooleanBuffer { + BooleanBuffer::new(Buffer::from_slice_ref(self.as_slice()), 0, self.len) } } @@ -168,6 +238,13 @@ impl From for Buffer { } } +impl From for BooleanBuffer { + #[inline] + fn from(builder: BooleanBufferBuilder) -> Self { + BooleanBuffer::new(builder.buffer.into(), 0, builder.len) + } +} + #[cfg(test)] mod tests { use super::*; @@ -182,7 +259,7 @@ mod tests { assert_eq!(4, b.len()); assert_eq!(512, b.capacity()); let buffer = b.finish(); - assert_eq!(1, buffer.len()); + assert_eq!(4, buffer.len()); // Overallocate capacity let mut b = BooleanBufferBuilder::new(8); @@ -190,7 +267,7 @@ mod tests { assert_eq!(4, b.len()); assert_eq!(512, b.capacity()); let buffer = b.finish(); - assert_eq!(1, buffer.len()); + assert_eq!(4, buffer.len()); } #[test] @@ -202,7 +279,7 @@ mod tests { buffer.append(true); buffer.set_bit(0, false); assert_eq!(buffer.len(), 4); - assert_eq!(buffer.finish().as_slice(), &[0b1010_u8]); + assert_eq!(buffer.finish().values(), &[0b1010_u8]); } #[test] @@ -214,7 +291,7 @@ mod tests { buffer.append(true); buffer.set_bit(3, false); assert_eq!(buffer.len(), 4); - assert_eq!(buffer.finish().as_slice(), &[0b0011_u8]); + assert_eq!(buffer.finish().values(), &[0b0011_u8]); } #[test] @@ -226,7 +303,7 @@ mod tests { buffer.append(true); buffer.set_bit(1, false); assert_eq!(buffer.len(), 4); - assert_eq!(buffer.finish().as_slice(), &[0b1001_u8]); + assert_eq!(buffer.finish().values(), &[0b1001_u8]); } #[test] @@ -240,7 +317,7 @@ mod tests { buffer.set_bit(1, false); buffer.set_bit(2, false); assert_eq!(buffer.len(), 5); - assert_eq!(buffer.finish().as_slice(), &[0b10001_u8]); + assert_eq!(buffer.finish().values(), &[0b10001_u8]); } #[test] @@ -251,7 +328,7 @@ mod tests { buffer.set_bit(3, false); buffer.set_bit(9, false); assert_eq!(buffer.len(), 10); - assert_eq!(buffer.finish().as_slice(), &[0b11110110_u8, 0b01_u8]); + assert_eq!(buffer.finish().values(), &[0b11110110_u8, 0b01_u8]); } #[test] @@ -267,7 +344,7 
@@ mod tests { buffer.set_bit(14, true); buffer.set_bit(13, false); assert_eq!(buffer.len(), 15); - assert_eq!(buffer.finish().as_slice(), &[0b01010110_u8, 0b1011100_u8]); + assert_eq!(buffer.finish().values(), &[0b01010110_u8, 0b1011100_u8]); } #[test] @@ -332,7 +409,7 @@ mod tests { let start = a.min(b); let end = a.max(b); - buffer.append_packed_range(start..end, compacted_src.as_slice()); + buffer.append_packed_range(start..end, compacted_src.values()); all_bools.extend_from_slice(&src[start..end]); } @@ -362,6 +439,45 @@ mod tests { assert_eq!(builder.as_slice(), &[0b11101111, 0b00000001]); } + #[test] + fn test_truncate() { + let b = MutableBuffer::from_iter([true, true, true, true]); + let mut builder = BooleanBufferBuilder::new_from_buffer(b, 2); + builder.advance(2); + let finished = builder.finish(); + assert_eq!(finished.values(), &[0b00000011]); + + let mut builder = BooleanBufferBuilder::new(10); + builder.append_n(5, true); + builder.resize(3); + builder.advance(2); + let finished = builder.finish(); + assert_eq!(finished.values(), &[0b00000111]); + + let mut builder = BooleanBufferBuilder::new(10); + builder.append_n(16, true); + assert_eq!(builder.as_slice(), &[0xFF, 0xFF]); + builder.truncate(20); + assert_eq!(builder.as_slice(), &[0xFF, 0xFF]); + builder.truncate(14); + assert_eq!(builder.as_slice(), &[0xFF, 0b00111111]); + builder.append(false); + builder.append(true); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111]); + builder.append_packed_range(0..3, &[0xFF]); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111, 0b00000111]); + builder.truncate(17); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111, 0b00000001]); + builder.append_packed_range(0..2, &[2]); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111, 0b0000101]); + builder.truncate(8); + assert_eq!(builder.as_slice(), &[0xFF]); + builder.resize(14); + assert_eq!(builder.as_slice(), &[0xFF, 0x00]); + builder.truncate(0); + assert_eq!(builder.as_slice(), &[]); + } + #[test] fn 
test_boolean_builder_increases_buffer_len() { // 00000010 01001000 @@ -377,7 +493,7 @@ mod tests { } let buf2 = builder.finish(); - assert_eq!(buf.len(), buf2.len()); - assert_eq!(buf.as_slice(), buf2.as_slice()); + assert_eq!(buf.len(), buf2.inner().len()); + assert_eq!(buf.as_slice(), buf2.values()); } } diff --git a/arrow/src/array/builder/buffer_builder.rs b/arrow-buffer/src/builder/mod.rs similarity index 58% rename from arrow/src/array/builder/buffer_builder.rs rename to arrow-buffer/src/builder/mod.rs index a6a81dfd6c0e..f7e0e29dace4 100644 --- a/arrow/src/array/builder/buffer_builder.rs +++ b/arrow-buffer/src/builder/mod.rs @@ -15,35 +15,37 @@ // specific language governing permissions and limitations // under the License. -use std::mem; +//! Buffer builders -use crate::buffer::{Buffer, MutableBuffer}; -use crate::datatypes::ArrowNativeType; +mod boolean; +mod null; +mod offset; -use super::PhantomData; +pub use boolean::*; +pub use null::*; +pub use offset::*; -/// Builder for creating a [`Buffer`](crate::buffer::Buffer) object. +use crate::{ArrowNativeType, Buffer, MutableBuffer}; +use std::{iter, marker::PhantomData}; + +/// Builder for creating a [Buffer] object. /// -/// A [`Buffer`](crate::buffer::Buffer) is the underlying data -/// structure of Arrow's [`Arrays`](crate::array::Array). +/// A [Buffer] is the underlying data structure of Arrow's Arrays. /// /// For all supported types, there are type definitions for the -/// generic version of `BufferBuilder`, e.g. `UInt8BufferBuilder`. +/// generic version of `BufferBuilder`, e.g. `BufferBuilder`. 
/// /// # Example: /// /// ``` -/// use arrow::array::UInt8BufferBuilder; +/// # use arrow_buffer::builder::BufferBuilder; /// -/// # fn main() -> arrow::error::Result<()> { -/// let mut builder = UInt8BufferBuilder::new(100); +/// let mut builder = BufferBuilder::::new(100); /// builder.append_slice(&[42, 43, 44]); /// builder.append(45); /// let buffer = builder.finish(); /// /// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 43, 44, 45]); -/// # Ok(()) -/// # } /// ``` #[derive(Debug)] pub struct BufferBuilder { @@ -67,15 +69,15 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// /// assert!(builder.capacity() >= 10); /// ``` #[inline] pub fn new(capacity: usize) -> Self { - let buffer = MutableBuffer::new(capacity * mem::size_of::()); + let buffer = MutableBuffer::new(capacity * std::mem::size_of::()); Self { buffer, @@ -84,14 +86,24 @@ impl BufferBuilder { } } + /// Creates a new builder from a [`MutableBuffer`] + pub fn new_from_buffer(buffer: MutableBuffer) -> Self { + let buffer_len = buffer.len(); + Self { + buffer, + len: buffer_len / std::mem::size_of::(), + _marker: PhantomData, + } + } + /// Returns the current number of array elements in the internal buffer. 
/// /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(42); /// /// assert_eq!(builder.len(), 1); @@ -105,9 +117,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(42); /// /// assert_eq!(builder.is_empty(), false); @@ -136,16 +148,16 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.advance(2); /// /// assert_eq!(builder.len(), 2); /// ``` #[inline] pub fn advance(&mut self, i: usize) { - self.buffer.extend_zeros(i * mem::size_of::()); + self.buffer.extend_zeros(i * std::mem::size_of::()); self.len += i; } @@ -154,16 +166,16 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.reserve(10); /// /// assert!(builder.capacity() >= 20); /// ``` #[inline] pub fn reserve(&mut self, n: usize) { - self.buffer.reserve(n * mem::size_of::()); + self.buffer.reserve(n * std::mem::size_of::()); } /// Appends a value of type `T` into the builder, @@ -172,9 +184,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(42); /// /// 
assert_eq!(builder.len(), 1); @@ -192,9 +204,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_n(10, 42); /// /// assert_eq!(builder.len(), 10); @@ -202,10 +214,7 @@ impl BufferBuilder { #[inline] pub fn append_n(&mut self, n: usize, v: T) { self.reserve(n); - for _ in 0..n { - self.buffer.push(v); - } - self.len += n; + self.extend(iter::repeat(v).take(n)) } /// Appends `n`, zero-initialized values @@ -213,16 +222,16 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt32BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt32BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_n_zeroed(3); /// /// assert_eq!(builder.len(), 3); /// assert_eq!(builder.as_slice(), &[0, 0, 0]) #[inline] pub fn append_n_zeroed(&mut self, n: usize) { - self.buffer.extend_zeros(n * mem::size_of::()); + self.buffer.extend_zeros(n * std::mem::size_of::()); self.len += n; } @@ -231,9 +240,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_slice(&[42, 44, 46]); /// /// assert_eq!(builder.len(), 3); @@ -247,9 +256,9 @@ impl BufferBuilder { /// View the contents of this buffer as a slice /// /// ``` - /// use arrow::array::Float64BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = Float64BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(1.3); /// builder.append_n(2, 2.3); /// @@ -270,9 +279,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use 
arrow::array::Float32BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = Float32BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// /// builder.append_slice(&[1., 2., 3.4]); /// assert_eq!(builder.as_slice(), &[1., 2., 3.4]); @@ -297,9 +306,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt16BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt16BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// /// builder.append_slice(&[42, 44, 46]); /// assert_eq!(builder.as_slice(), &[42, 44, 46]); @@ -312,7 +321,7 @@ impl BufferBuilder { /// ``` #[inline] pub fn truncate(&mut self, len: usize) { - self.buffer.truncate(len * mem::size_of::()); + self.buffer.truncate(len * std::mem::size_of::()); self.len = len; } @@ -327,20 +336,17 @@ impl BufferBuilder { .1 .expect("append_trusted_len_iter expects upper bound"); self.reserve(len); - for v in iter { - self.buffer.push(v) - } - self.len += len; + self.extend(iter); } - /// Resets this builder and returns an immutable [`Buffer`](crate::buffer::Buffer). + /// Resets this builder and returns an immutable [Buffer]. 
/// /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_slice(&[42, 44, 46]); /// /// let buffer = builder.finish(); @@ -349,133 +355,67 @@ impl BufferBuilder { /// ``` #[inline] pub fn finish(&mut self) -> Buffer { - let buf = std::mem::replace(&mut self.buffer, MutableBuffer::new(0)); + let buf = std::mem::take(&mut self.buffer); self.len = 0; buf.into() } } -#[cfg(test)] -mod tests { - use crate::array::array::Array; - use crate::array::builder::ArrayBuilder; - use crate::array::Int32BufferBuilder; - use crate::array::Int8Builder; - use crate::array::UInt8BufferBuilder; - - #[test] - fn test_builder_i32_empty() { - let mut b = Int32BufferBuilder::new(5); - assert_eq!(0, b.len()); - assert_eq!(16, b.capacity()); - let a = b.finish(); - assert_eq!(0, a.len()); +impl Default for BufferBuilder { + fn default() -> Self { + Self::new(0) } +} - #[test] - fn test_builder_i32_alloc_zero_bytes() { - let mut b = Int32BufferBuilder::new(0); - b.append(123); - let a = b.finish(); - assert_eq!(4, a.len()); +impl Extend for BufferBuilder { + fn extend>(&mut self, iter: I) { + self.buffer.extend(iter.into_iter().inspect(|_| { + self.len += 1; + })) } +} - #[test] - fn test_builder_i32() { - let mut b = Int32BufferBuilder::new(5); - for i in 0..5 { - b.append(i); - } - assert_eq!(16, b.capacity()); - let a = b.finish(); - assert_eq!(20, a.len()); +impl From> for BufferBuilder { + fn from(value: Vec) -> Self { + Self::new_from_buffer(MutableBuffer::from(value)) } +} - #[test] - fn test_builder_i32_grow_buffer() { - let mut b = Int32BufferBuilder::new(2); - assert_eq!(16, b.capacity()); - for i in 0..20 { - b.append(i); - } - assert_eq!(32, b.capacity()); - let a = b.finish(); - assert_eq!(80, a.len()); +impl FromIterator for BufferBuilder { + fn from_iter>(iter: I) -> Self { + let mut 
builder = Self::default(); + builder.extend(iter); + builder } +} - #[test] - fn test_builder_finish() { - let mut b = Int32BufferBuilder::new(5); - assert_eq!(16, b.capacity()); - for i in 0..10 { - b.append(i); - } - let mut a = b.finish(); - assert_eq!(40, a.len()); - assert_eq!(0, b.len()); - assert_eq!(0, b.capacity()); - - // Try build another buffer after cleaning up. - for i in 0..20 { - b.append(i) - } - assert_eq!(32, b.capacity()); - a = b.finish(); - assert_eq!(80, a.len()); - } +#[cfg(test)] +mod tests { + use super::*; + use std::mem; #[test] - fn test_reserve() { - let mut b = UInt8BufferBuilder::new(2); - assert_eq!(64, b.capacity()); - b.reserve(64); - assert_eq!(64, b.capacity()); - b.reserve(65); - assert_eq!(128, b.capacity()); - - let mut b = Int32BufferBuilder::new(2); - assert_eq!(16, b.capacity()); - b.reserve(16); - assert_eq!(16, b.capacity()); - b.reserve(17); - assert_eq!(32, b.capacity()); + fn default() { + let builder = BufferBuilder::::default(); + assert!(builder.is_empty()); + assert!(builder.buffer.is_empty()); + assert_eq!(builder.buffer.capacity(), 0); } #[test] - fn test_append_slice() { - let mut b = UInt8BufferBuilder::new(0); - b.append_slice(b"Hello, "); - b.append_slice(b"World!"); - let buffer = b.finish(); - assert_eq!(13, buffer.len()); - - let mut b = Int32BufferBuilder::new(0); - b.append_slice(&[32, 54]); - let buffer = b.finish(); - assert_eq!(8, buffer.len()); + fn from_iter() { + let input = [1u16, 2, 3, 4]; + let builder = input.into_iter().collect::>(); + assert_eq!(builder.len(), 4); + assert_eq!(builder.buffer.len(), 4 * mem::size_of::()); } #[test] - fn test_append_values() { - let mut a = Int8Builder::new(); - a.append_value(1); - a.append_null(); - a.append_value(-2); - assert_eq!(a.len(), 3); - - // append values - let values = &[1, 2, 3, 4]; - let is_valid = &[true, true, false, true]; - a.append_values(values, is_valid); - - assert_eq!(a.len(), 7); - let array = a.finish(); - assert_eq!(array.value(0), 
1); - assert!(array.is_null(1)); - assert_eq!(array.value(2), -2); - assert_eq!(array.value(3), 1); - assert_eq!(array.value(4), 2); - assert!(array.is_null(5)); - assert_eq!(array.value(6), 4); + fn extend() { + let input = [1, 2]; + let mut builder = input.into_iter().collect::>(); + assert_eq!(builder.len(), 2); + builder.extend([3, 4]); + assert_eq!(builder.len(), 4); } } diff --git a/arrow/src/array/builder/null_buffer_builder.rs b/arrow-buffer/src/builder/null.rs similarity index 74% rename from arrow/src/array/builder/null_buffer_builder.rs rename to arrow-buffer/src/builder/null.rs index ef2e4c50ab9c..298b479e87df 100644 --- a/arrow/src/array/builder/null_buffer_builder.rs +++ b/arrow-buffer/src/builder/null.rs @@ -15,17 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::Buffer; - -use super::BooleanBufferBuilder; +use crate::{BooleanBufferBuilder, MutableBuffer, NullBuffer}; /// Builder for creating the null bit buffer. +/// /// This builder only materializes the buffer when we append `false`. /// If you only append `true`s to the builder, what you get will be /// `None` when calling [`finish`](#method.finish). /// This optimization is **very** important for the performance. #[derive(Debug)] -pub(super) struct NullBufferBuilder { +pub struct NullBufferBuilder { bitmap_builder: Option, /// Store the length of the buffer before materializing. len: usize, @@ -43,6 +42,29 @@ impl NullBufferBuilder { } } + /// Creates a new builder with given length. + pub fn new_with_len(len: usize) -> Self { + Self { + bitmap_builder: None, + len, + capacity: len, + } + } + + /// Creates a new builder from a `MutableBuffer`. 
+ pub fn new_from_buffer(buffer: MutableBuffer, len: usize) -> Self { + let capacity = buffer.len() * 8; + + assert!(len <= capacity); + + let bitmap_builder = Some(BooleanBufferBuilder::new_from_buffer(buffer, len)); + Self { + bitmap_builder, + len, + capacity, + } + } + /// Appends `n` `true`s into the builder /// to indicate that these `n` items are not nulls. #[inline] @@ -106,14 +128,22 @@ impl NullBufferBuilder { /// Builds the null buffer and resets the builder. /// Returns `None` if the builder only contains `true`s. - pub fn finish(&mut self) -> Option { - let buf = self.bitmap_builder.as_mut().map(|b| b.finish()); - self.bitmap_builder = None; + pub fn finish(&mut self) -> Option { self.len = 0; - buf + Some(NullBuffer::new(self.bitmap_builder.take()?.finish())) + } + + /// Builds the [NullBuffer] without resetting the builder. + pub fn finish_cloned(&self) -> Option { + let buffer = self.bitmap_builder.as_ref()?.finish_cloned(); + Some(NullBuffer::new(buffer)) + } + + /// Returns the inner bitmap builder as slice + pub fn as_slice(&self) -> Option<&[u8]> { + Some(self.bitmap_builder.as_ref()?.as_slice()) } - #[inline] fn materialize_if_needed(&mut self) { if self.bitmap_builder.is_none() { self.materialize() @@ -128,17 +158,28 @@ impl NullBufferBuilder { self.bitmap_builder = Some(b); } } + + /// Return a mutable reference to the inner bitmap slice. + pub fn as_slice_mut(&mut self) -> Option<&mut [u8]> { + self.bitmap_builder.as_mut().map(|b| b.as_slice_mut()) + } + + /// Return the allocated size of this builder, in bytes, useful for memory accounting. + pub fn allocated_size(&self) -> usize { + self.bitmap_builder + .as_ref() + .map(|b| b.capacity()) + .unwrap_or(0) + } } impl NullBufferBuilder { + /// Return the number of bits in the buffer. pub fn len(&self) -> usize { - if let Some(b) = &self.bitmap_builder { - b.len() - } else { - self.len - } + self.bitmap_builder.as_ref().map_or(self.len, |b| b.len()) } + /// Check if the builder is empty. 
pub fn is_empty(&self) -> bool { self.len() == 0 } @@ -158,7 +199,7 @@ mod tests { assert_eq!(6, builder.len()); let buf = builder.finish().unwrap(); - assert_eq!(Buffer::from(&[0b110010_u8]), buf); + assert_eq!(&[0b110010_u8], buf.validity()); } #[test] @@ -170,7 +211,7 @@ mod tests { assert_eq!(6, builder.len()); let buf = builder.finish().unwrap(); - assert_eq!(Buffer::from(&[0b0_u8]), buf); + assert_eq!(&[0b0_u8], buf.validity()); } #[test] @@ -199,6 +240,6 @@ mod tests { builder.append_slice(&[true, true, false, true]); let buf = builder.finish().unwrap(); - assert_eq!(Buffer::from(&[0b1011_u8]), buf); + assert_eq!(&[0b1011_u8], buf.validity()); } } diff --git a/arrow-buffer/src/builder/offset.rs b/arrow-buffer/src/builder/offset.rs new file mode 100644 index 000000000000..1ef0e3170c96 --- /dev/null +++ b/arrow-buffer/src/builder/offset.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::ops::Deref; + +use crate::{ArrowNativeType, OffsetBuffer}; + +/// Builder of [`OffsetBuffer`] +#[derive(Debug)] +pub struct OffsetBufferBuilder { + offsets: Vec, + last_offset: usize, +} + +impl OffsetBufferBuilder { + /// Create a new builder with space for `capacity + 1` offsets + pub fn new(capacity: usize) -> Self { + let mut offsets = Vec::with_capacity(capacity + 1); + offsets.push(O::usize_as(0)); + Self { + offsets, + last_offset: 0, + } + } + + /// Push a slice of `length` bytes + /// + /// # Panics + /// + /// Panics if adding `length` would overflow `usize` + #[inline] + pub fn push_length(&mut self, length: usize) { + self.last_offset = self.last_offset.checked_add(length).expect("overflow"); + self.offsets.push(O::usize_as(self.last_offset)) + } + + /// Reserve space for at least `additional` further offsets + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + } + + /// Takes the builder itself and returns an [`OffsetBuffer`] + /// + /// # Panics + /// + /// Panics if offsets overflow `O` + pub fn finish(self) -> OffsetBuffer { + O::from_usize(self.last_offset).expect("overflow"); + unsafe { OffsetBuffer::new_unchecked(self.offsets.into()) } + } + + /// Builds the [OffsetBuffer] without resetting the builder. 
+ /// + /// # Panics + /// + /// Panics if offsets overflow `O` + pub fn finish_cloned(&self) -> OffsetBuffer { + O::from_usize(self.last_offset).expect("overflow"); + unsafe { OffsetBuffer::new_unchecked(self.offsets.clone().into()) } + } +} + +impl Deref for OffsetBufferBuilder { + type Target = [O]; + + fn deref(&self) -> &Self::Target { + self.offsets.as_ref() + } +} + +#[cfg(test)] +mod tests { + use crate::OffsetBufferBuilder; + + #[test] + fn test_basic() { + let mut builder = OffsetBufferBuilder::::new(5); + assert_eq!(builder.len(), 1); + assert_eq!(&*builder, &[0]); + let finished = builder.finish_cloned(); + assert_eq!(finished.len(), 1); + assert_eq!(&*finished, &[0]); + + builder.push_length(2); + builder.push_length(6); + builder.push_length(0); + builder.push_length(13); + + let finished = builder.finish(); + assert_eq!(&*finished, &[0, 2, 8, 8, 21]); + } + + #[test] + #[should_panic(expected = "overflow")] + fn test_usize_overflow() { + let mut builder = OffsetBufferBuilder::::new(5); + builder.push_length(1); + builder.push_length(usize::MAX); + builder.finish(); + } + + #[test] + #[should_panic(expected = "overflow")] + fn test_i32_overflow() { + let mut builder = OffsetBufferBuilder::::new(5); + builder.push_length(1); + builder.push_length(i32::MAX as usize); + builder.finish(); + } +} diff --git a/arrow/src/bytes.rs b/arrow-buffer/src/bytes.rs similarity index 71% rename from arrow/src/bytes.rs rename to arrow-buffer/src/bytes.rs index 75137a55295b..ba61342d8e39 100644 --- a/arrow/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -23,15 +23,15 @@ use core::slice; use std::ptr::NonNull; use std::{fmt::Debug, fmt::Formatter}; -use crate::alloc; use crate::alloc::Deallocation; /// A continuous, fixed-size, immutable memory region that knows how to de-allocate itself. +/// /// This structs' API is inspired by the `bytes::Bytes`, but it is not limited to using rust's /// global allocator nor u8 alignment. 
/// -/// In the most common case, this buffer is allocated using [`allocate_aligned`](crate::alloc::allocate_aligned) -/// and deallocated accordingly [`free_aligned`](crate::alloc::free_aligned). +/// In the most common case, this buffer is allocated using [`alloc`](std::alloc::alloc) +/// with an alignment of [`ALIGNMENT`](crate::alloc::ALIGNMENT) /// /// When the region is allocated by a different allocator, [Deallocation::Custom], this calls the /// custom deallocator to deallocate the region when it is no longer needed. @@ -53,18 +53,14 @@ impl Bytes { /// /// * `ptr` - Pointer to raw parts /// * `len` - Length of raw parts in **bytes** - /// * `capacity` - Total allocated memory for the pointer `ptr`, in **bytes** + /// * `deallocation` - Type of allocation /// /// # Safety /// /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. #[inline] - pub(crate) unsafe fn new( - ptr: std::ptr::NonNull, - len: usize, - deallocation: Deallocation, - ) -> Bytes { + pub(crate) unsafe fn new(ptr: NonNull, len: usize, deallocation: Deallocation) -> Bytes { Bytes { ptr, len, @@ -93,12 +89,17 @@ impl Bytes { pub fn capacity(&self) -> usize { match self.deallocation { - Deallocation::Arrow(capacity) => capacity, - // we cannot determine this in general, - // and thus we state that this is externally-owned memory - Deallocation::Custom(_) => 0, + Deallocation::Standard(layout) => layout.size(), + // we only know the size of the custom allocation + // its underlying capacity might be larger + Deallocation::Custom(_, size) => size, } } + + #[inline] + pub(crate) fn deallocation(&self) -> &Deallocation { + &self.deallocation + } } // Deallocation is Send + Sync, repeating the bound here makes that refactoring safe @@ -110,11 +111,12 @@ impl Drop for Bytes { #[inline] fn drop(&mut self) { match &self.deallocation { - Deallocation::Arrow(capacity) => { - 
unsafe { alloc::free_aligned::(self.ptr, *capacity) }; - } + Deallocation::Standard(layout) => match layout.size() { + 0 => {} // Nothing to do + _ => unsafe { std::alloc::dealloc(self.ptr.as_ptr(), *layout) }, + }, // The automatic drop implementation will free the memory once the reference count reaches zero - Deallocation::Custom(_allocation) => (), + Deallocation::Custom(_allocation, _size) => (), } } } @@ -142,3 +144,32 @@ impl Debug for Bytes { write!(f, " }}") } } + +impl From for Bytes { + fn from(value: bytes::Bytes) -> Self { + let len = value.len(); + Self { + len, + ptr: NonNull::new(value.as_ptr() as _).unwrap(), + deallocation: Deallocation::Custom(std::sync::Arc::new(value), len), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_bytes() { + let bytes = bytes::Bytes::from(vec![1, 2, 3, 4]); + let arrow_bytes: Bytes = bytes.clone().into(); + + assert_eq!(bytes.as_ptr(), arrow_bytes.as_ptr()); + + drop(bytes); + drop(arrow_bytes); + + let _ = Bytes::from(bytes::Bytes::new()); + } +} diff --git a/arrow-buffer/src/interval.rs b/arrow-buffer/src/interval.rs new file mode 100644 index 000000000000..fa87fec6ea3a --- /dev/null +++ b/arrow-buffer/src/interval.rs @@ -0,0 +1,579 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arith::derive_arith; +use std::ops::Neg; + +/// Value of an IntervalMonthDayNano array +/// +/// ## Representation +/// +/// This type is stored as a single 128 bit integer, interpreted as three +/// different signed integral fields: +/// +/// 1. The number of months (32 bits) +/// 2. The number days (32 bits) +/// 2. The number of nanoseconds (64 bits). +/// +/// Nanoseconds does not allow for leap seconds. +/// +/// Each field is independent (e.g. there is no constraint that the quantity of +/// nanoseconds represents less than a day's worth of time). +/// +/// ```text +/// ┌───────────────┬─────────────┬─────────────────────────────┐ +/// │ Months │ Days │ Nanos │ +/// │ (32 bits) │ (32 bits) │ (64 bits) │ +/// └───────────────┴─────────────┴─────────────────────────────┘ +/// 0 32 64 128 bit offset +/// ``` +/// Please see the [Arrow Spec](https://github.com/apache/arrow/blob/081b4022fe6f659d8765efc82b3f4787c5039e3c/format/Schema.fbs#L409-L415) for more details +/// +///## Note on Comparing and Ordering for Calendar Types +/// +/// Values of `IntervalMonthDayNano` are compared using their binary +/// representation, which can lead to surprising results. +/// +/// Spans of time measured in calendar units are not fixed in absolute size (e.g. +/// number of seconds) which makes defining comparisons and ordering non trivial. +/// For example `1 month` is 28 days for February but `1 month` is 31 days +/// in December. +/// +/// This makes the seemingly simple operation of comparing two intervals +/// complicated in practice. For example is `1 month` more or less than `30 +/// days`? The answer depends on what month you are talking about. +/// +/// This crate defines comparisons for calendar types using their binary +/// representation which is fast and efficient, but leads +/// to potentially surprising results. 
+/// +/// For example a +/// `IntervalMonthDayNano` of `1 month` will compare as **greater** than a +/// `IntervalMonthDayNano` of `100 days` because the binary representation of `1 month` +/// is larger than the binary representation of 100 days. +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] +#[repr(C)] +pub struct IntervalMonthDayNano { + /// Number of months + pub months: i32, + /// Number of days + pub days: i32, + /// Number of nanoseconds + pub nanoseconds: i64, +} + +impl IntervalMonthDayNano { + /// The additive identity i.e. `0`. + pub const ZERO: Self = Self::new(0, 0, 0); + + /// The multiplicative identity, i.e. `1`. + pub const ONE: Self = Self::new(1, 1, 1); + + /// The multiplicative inverse, i.e. `-1`. + pub const MINUS_ONE: Self = Self::new(-1, -1, -1); + + /// The maximum value that can be represented + pub const MAX: Self = Self::new(i32::MAX, i32::MAX, i64::MAX); + + /// The minimum value that can be represented + pub const MIN: Self = Self::new(i32::MIN, i32::MIN, i64::MIN); + + /// Create a new [`IntervalMonthDayNano`] + #[inline] + pub const fn new(months: i32, days: i32, nanoseconds: i64) -> Self { + Self { + months, + days, + nanoseconds, + } + } + + /// Computes the absolute value + #[inline] + pub fn wrapping_abs(self) -> Self { + Self { + months: self.months.wrapping_abs(), + days: self.days.wrapping_abs(), + nanoseconds: self.nanoseconds.wrapping_abs(), + } + } + + /// Computes the absolute value + #[inline] + pub fn checked_abs(self) -> Option { + Some(Self { + months: self.months.checked_abs()?, + days: self.days.checked_abs()?, + nanoseconds: self.nanoseconds.checked_abs()?, + }) + } + + /// Negates the value + #[inline] + pub fn wrapping_neg(self) -> Self { + Self { + months: self.months.wrapping_neg(), + days: self.days.wrapping_neg(), + nanoseconds: self.nanoseconds.wrapping_neg(), + } + } + + /// Negates the value + #[inline] + pub fn checked_neg(self) -> Option { + Some(Self { + months: 
self.months.checked_neg()?, + days: self.days.checked_neg()?, + nanoseconds: self.nanoseconds.checked_neg()?, + }) + } + + /// Performs wrapping addition + #[inline] + pub fn wrapping_add(self, other: Self) -> Self { + Self { + months: self.months.wrapping_add(other.months), + days: self.days.wrapping_add(other.days), + nanoseconds: self.nanoseconds.wrapping_add(other.nanoseconds), + } + } + + /// Performs checked addition + #[inline] + pub fn checked_add(self, other: Self) -> Option { + Some(Self { + months: self.months.checked_add(other.months)?, + days: self.days.checked_add(other.days)?, + nanoseconds: self.nanoseconds.checked_add(other.nanoseconds)?, + }) + } + + /// Performs wrapping subtraction + #[inline] + pub fn wrapping_sub(self, other: Self) -> Self { + Self { + months: self.months.wrapping_sub(other.months), + days: self.days.wrapping_sub(other.days), + nanoseconds: self.nanoseconds.wrapping_sub(other.nanoseconds), + } + } + + /// Performs checked subtraction + #[inline] + pub fn checked_sub(self, other: Self) -> Option { + Some(Self { + months: self.months.checked_sub(other.months)?, + days: self.days.checked_sub(other.days)?, + nanoseconds: self.nanoseconds.checked_sub(other.nanoseconds)?, + }) + } + + /// Performs wrapping multiplication + #[inline] + pub fn wrapping_mul(self, other: Self) -> Self { + Self { + months: self.months.wrapping_mul(other.months), + days: self.days.wrapping_mul(other.days), + nanoseconds: self.nanoseconds.wrapping_mul(other.nanoseconds), + } + } + + /// Performs checked multiplication + pub fn checked_mul(self, other: Self) -> Option { + Some(Self { + months: self.months.checked_mul(other.months)?, + days: self.days.checked_mul(other.days)?, + nanoseconds: self.nanoseconds.checked_mul(other.nanoseconds)?, + }) + } + + /// Performs wrapping division + #[inline] + pub fn wrapping_div(self, other: Self) -> Self { + Self { + months: self.months.wrapping_div(other.months), + days: self.days.wrapping_div(other.days), + 
nanoseconds: self.nanoseconds.wrapping_div(other.nanoseconds), + } + } + + /// Performs checked division + pub fn checked_div(self, other: Self) -> Option { + Some(Self { + months: self.months.checked_div(other.months)?, + days: self.days.checked_div(other.days)?, + nanoseconds: self.nanoseconds.checked_div(other.nanoseconds)?, + }) + } + + /// Performs wrapping remainder + #[inline] + pub fn wrapping_rem(self, other: Self) -> Self { + Self { + months: self.months.wrapping_rem(other.months), + days: self.days.wrapping_rem(other.days), + nanoseconds: self.nanoseconds.wrapping_rem(other.nanoseconds), + } + } + + /// Performs checked remainder + pub fn checked_rem(self, other: Self) -> Option { + Some(Self { + months: self.months.checked_rem(other.months)?, + days: self.days.checked_rem(other.days)?, + nanoseconds: self.nanoseconds.checked_rem(other.nanoseconds)?, + }) + } + + /// Performs wrapping exponentiation + #[inline] + pub fn wrapping_pow(self, exp: u32) -> Self { + Self { + months: self.months.wrapping_pow(exp), + days: self.days.wrapping_pow(exp), + nanoseconds: self.nanoseconds.wrapping_pow(exp), + } + } + + /// Performs checked exponentiation + #[inline] + pub fn checked_pow(self, exp: u32) -> Option { + Some(Self { + months: self.months.checked_pow(exp)?, + days: self.days.checked_pow(exp)?, + nanoseconds: self.nanoseconds.checked_pow(exp)?, + }) + } +} + +impl Neg for IntervalMonthDayNano { + type Output = Self; + + #[cfg(debug_assertions)] + fn neg(self) -> Self::Output { + self.checked_neg().expect("IntervalMonthDayNano overflow") + } + + #[cfg(not(debug_assertions))] + fn neg(self) -> Self::Output { + self.wrapping_neg() + } +} + +derive_arith!( + IntervalMonthDayNano, + Add, + AddAssign, + add, + add_assign, + wrapping_add, + checked_add +); +derive_arith!( + IntervalMonthDayNano, + Sub, + SubAssign, + sub, + sub_assign, + wrapping_sub, + checked_sub +); +derive_arith!( + IntervalMonthDayNano, + Mul, + MulAssign, + mul, + mul_assign, + wrapping_mul, 
+ checked_mul +); +derive_arith!( + IntervalMonthDayNano, + Div, + DivAssign, + div, + div_assign, + wrapping_div, + checked_div +); +derive_arith!( + IntervalMonthDayNano, + Rem, + RemAssign, + rem, + rem_assign, + wrapping_rem, + checked_rem +); + +/// Value of an IntervalDayTime array +/// +/// ## Representation +/// +/// This type is stored as a single 64 bit integer, interpreted as two i32 +/// fields: +/// +/// 1. the number of elapsed days +/// 2. The number of milliseconds (no leap seconds), +/// +/// ```text +/// ┌──────────────┬──────────────┐ +/// │ Days │ Milliseconds │ +/// │ (32 bits) │ (32 bits) │ +/// └──────────────┴──────────────┘ +/// 0 31 63 bit offset +/// ``` +/// +/// Please see the [Arrow Spec](https://github.com/apache/arrow/blob/081b4022fe6f659d8765efc82b3f4787c5039e3c/format/Schema.fbs#L406-L408) for more details +/// +/// ## Note on Comparing and Ordering for Calendar Types +/// +/// Values of `IntervalDayTime` are compared using their binary representation, +/// which can lead to surprising results. Please see the description of ordering on +/// [`IntervalMonthDayNano`] for more details +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] +#[repr(C)] +pub struct IntervalDayTime { + /// Number of days + pub days: i32, + /// Number of milliseconds + pub milliseconds: i32, +} + +impl IntervalDayTime { + /// The additive identity i.e. `0`. + pub const ZERO: Self = Self::new(0, 0); + + /// The multiplicative identity, i.e. `1`. + pub const ONE: Self = Self::new(1, 1); + + /// The multiplicative inverse, i.e. `-1`. 
+ pub const MINUS_ONE: Self = Self::new(-1, -1); + + /// The maximum value that can be represented + pub const MAX: Self = Self::new(i32::MAX, i32::MAX); + + /// The minimum value that can be represented + pub const MIN: Self = Self::new(i32::MIN, i32::MIN); + + /// Create a new [`IntervalDayTime`] + #[inline] + pub const fn new(days: i32, milliseconds: i32) -> Self { + Self { days, milliseconds } + } + + /// Computes the absolute value + #[inline] + pub fn wrapping_abs(self) -> Self { + Self { + days: self.days.wrapping_abs(), + milliseconds: self.milliseconds.wrapping_abs(), + } + } + + /// Computes the absolute value + #[inline] + pub fn checked_abs(self) -> Option { + Some(Self { + days: self.days.checked_abs()?, + milliseconds: self.milliseconds.checked_abs()?, + }) + } + + /// Negates the value + #[inline] + pub fn wrapping_neg(self) -> Self { + Self { + days: self.days.wrapping_neg(), + milliseconds: self.milliseconds.wrapping_neg(), + } + } + + /// Negates the value + #[inline] + pub fn checked_neg(self) -> Option { + Some(Self { + days: self.days.checked_neg()?, + milliseconds: self.milliseconds.checked_neg()?, + }) + } + + /// Performs wrapping addition + #[inline] + pub fn wrapping_add(self, other: Self) -> Self { + Self { + days: self.days.wrapping_add(other.days), + milliseconds: self.milliseconds.wrapping_add(other.milliseconds), + } + } + + /// Performs checked addition + #[inline] + pub fn checked_add(self, other: Self) -> Option { + Some(Self { + days: self.days.checked_add(other.days)?, + milliseconds: self.milliseconds.checked_add(other.milliseconds)?, + }) + } + + /// Performs wrapping subtraction + #[inline] + pub fn wrapping_sub(self, other: Self) -> Self { + Self { + days: self.days.wrapping_sub(other.days), + milliseconds: self.milliseconds.wrapping_sub(other.milliseconds), + } + } + + /// Performs checked subtraction + #[inline] + pub fn checked_sub(self, other: Self) -> Option { + Some(Self { + days: self.days.checked_sub(other.days)?, + 
milliseconds: self.milliseconds.checked_sub(other.milliseconds)?, + }) + } + + /// Performs wrapping multiplication + #[inline] + pub fn wrapping_mul(self, other: Self) -> Self { + Self { + days: self.days.wrapping_mul(other.days), + milliseconds: self.milliseconds.wrapping_mul(other.milliseconds), + } + } + + /// Performs checked multiplication + pub fn checked_mul(self, other: Self) -> Option { + Some(Self { + days: self.days.checked_mul(other.days)?, + milliseconds: self.milliseconds.checked_mul(other.milliseconds)?, + }) + } + + /// Performs wrapping division + #[inline] + pub fn wrapping_div(self, other: Self) -> Self { + Self { + days: self.days.wrapping_div(other.days), + milliseconds: self.milliseconds.wrapping_div(other.milliseconds), + } + } + + /// Performs checked division + pub fn checked_div(self, other: Self) -> Option { + Some(Self { + days: self.days.checked_div(other.days)?, + milliseconds: self.milliseconds.checked_div(other.milliseconds)?, + }) + } + + /// Performs wrapping remainder + #[inline] + pub fn wrapping_rem(self, other: Self) -> Self { + Self { + days: self.days.wrapping_rem(other.days), + milliseconds: self.milliseconds.wrapping_rem(other.milliseconds), + } + } + + /// Performs checked remainder + pub fn checked_rem(self, other: Self) -> Option { + Some(Self { + days: self.days.checked_rem(other.days)?, + milliseconds: self.milliseconds.checked_rem(other.milliseconds)?, + }) + } + + /// Performs wrapping exponentiation + #[inline] + pub fn wrapping_pow(self, exp: u32) -> Self { + Self { + days: self.days.wrapping_pow(exp), + milliseconds: self.milliseconds.wrapping_pow(exp), + } + } + + /// Performs checked exponentiation + #[inline] + pub fn checked_pow(self, exp: u32) -> Option { + Some(Self { + days: self.days.checked_pow(exp)?, + milliseconds: self.milliseconds.checked_pow(exp)?, + }) + } +} + +impl Neg for IntervalDayTime { + type Output = Self; + + #[cfg(debug_assertions)] + fn neg(self) -> Self::Output { + 
self.checked_neg().expect("IntervalDayMillisecond overflow") + } + + #[cfg(not(debug_assertions))] + fn neg(self) -> Self::Output { + self.wrapping_neg() + } +} + +derive_arith!( + IntervalDayTime, + Add, + AddAssign, + add, + add_assign, + wrapping_add, + checked_add +); +derive_arith!( + IntervalDayTime, + Sub, + SubAssign, + sub, + sub_assign, + wrapping_sub, + checked_sub +); +derive_arith!( + IntervalDayTime, + Mul, + MulAssign, + mul, + mul_assign, + wrapping_mul, + checked_mul +); +derive_arith!( + IntervalDayTime, + Div, + DivAssign, + div, + div_assign, + wrapping_div, + checked_div +); +derive_arith!( + IntervalDayTime, + Rem, + RemAssign, + rem, + rem_assign, + wrapping_rem, + checked_rem +); diff --git a/arrow-buffer/src/lib.rs b/arrow-buffer/src/lib.rs new file mode 100644 index 000000000000..34e432208ada --- /dev/null +++ b/arrow-buffer/src/lib.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Low-level buffer abstractions for [Apache Arrow Rust](https://docs.rs/arrow) + +// used by [`buffer::mutable::dangling_ptr`] +#![cfg_attr(miri, feature(strict_provenance))] +#![warn(missing_docs)] + +pub mod alloc; +pub mod buffer; +pub use buffer::*; + +pub mod builder; +pub use builder::*; + +mod bigint; +pub use bigint::i256; + +mod bytes; + +mod native; +pub use native::*; + +mod util; +pub use util::*; + +mod interval; +pub use interval::*; + +mod arith; diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs new file mode 100644 index 000000000000..c563f73cf5b9 --- /dev/null +++ b/arrow-buffer/src/native.rs @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{i256, IntervalDayTime, IntervalMonthDayNano}; +use half::f16; + +mod private { + pub trait Sealed {} +} + +/// Trait expressing a Rust type that has the same in-memory representation as +/// Arrow. +/// +/// This includes `i16`, `f32`, but excludes `bool` (which in arrow is +/// represented in bits). +/// +/// In little endian machines, types that implement [`ArrowNativeType`] can be +/// memcopied to arrow buffers as is. 
+/// +/// # Transmute Safety +/// +/// A type T implementing this trait means that any arbitrary slice of bytes of length and +/// alignment `size_of::()` can be safely interpreted as a value of that type without +/// being unsound, i.e. potentially resulting in undefined behaviour. +/// +/// Note: in the case of floating point numbers this transmutation can result in a signalling +/// NaN, which, whilst sound, can be unwieldy. In general, whilst it is perfectly sound to +/// reinterpret bytes as different types using this trait, it is likely unwise. For more information +/// see [f32::from_bits] and [f64::from_bits]. +/// +/// Note: `bool` is restricted to `0` or `1`, and so `bool: !ArrowNativeType` +/// +/// # Sealed +/// +/// Due to the above restrictions, this trait is sealed to prevent accidental misuse +pub trait ArrowNativeType: + std::fmt::Debug + Send + Sync + Copy + PartialOrd + Default + private::Sealed + 'static +{ + /// Returns the byte width of this native type. + fn get_byte_width() -> usize { + std::mem::size_of::() + } + + /// Convert native integer type from usize + /// + /// Returns `None` if [`Self`] is not an integer or conversion would result + /// in truncation/overflow + fn from_usize(_: usize) -> Option; + + /// Convert to usize according to the [`as`] operator + /// + /// [`as`]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast + fn as_usize(self) -> usize; + + /// Convert from usize according to the [`as`] operator + /// + /// [`as`]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast + fn usize_as(i: usize) -> Self; + + /// Convert native type to usize. + /// + /// Returns `None` if [`Self`] is not an integer or conversion would result + /// in truncation/overflow + fn to_usize(self) -> Option; + + /// Convert native type to isize. 
+ /// + /// Returns `None` if [`Self`] is not an integer or conversion would result + /// in truncation/overflow + fn to_isize(self) -> Option; + + /// Convert native type to i64. + /// + /// Returns `None` if [`Self`] is not an integer or conversion would result + /// in truncation/overflow + fn to_i64(self) -> Option; + + /// Convert native type from i32. + /// + /// Returns `None` if [`Self`] is not `i32` + #[deprecated(note = "please use `Option::Some` instead")] + fn from_i32(_: i32) -> Option { + None + } + + /// Convert native type from i64. + /// + /// Returns `None` if [`Self`] is not `i64` + #[deprecated(note = "please use `Option::Some` instead")] + fn from_i64(_: i64) -> Option { + None + } + + /// Convert native type from i128. + /// + /// Returns `None` if [`Self`] is not `i128` + #[deprecated(note = "please use `Option::Some` instead")] + fn from_i128(_: i128) -> Option { + None + } +} + +macro_rules! native_integer { + ($t: ty $(, $from:ident)*) => { + impl private::Sealed for $t {} + impl ArrowNativeType for $t { + #[inline] + fn from_usize(v: usize) -> Option { + v.try_into().ok() + } + + #[inline] + fn to_usize(self) -> Option { + self.try_into().ok() + } + + #[inline] + fn to_isize(self) -> Option { + self.try_into().ok() + } + + #[inline] + fn to_i64(self) -> Option { + self.try_into().ok() + } + + #[inline] + fn as_usize(self) -> usize { + self as _ + } + + #[inline] + fn usize_as(i: usize) -> Self { + i as _ + } + + + $( + #[inline] + fn $from(v: $t) -> Option { + Some(v) + } + )* + } + }; +} + +native_integer!(i8); +native_integer!(i16); +native_integer!(i32, from_i32); +native_integer!(i64, from_i64); +native_integer!(i128, from_i128); +native_integer!(u8); +native_integer!(u16); +native_integer!(u32); +native_integer!(u64); +native_integer!(u128); + +macro_rules! 
native_float { + ($t:ty, $s:ident, $as_usize: expr, $i:ident, $usize_as: expr) => { + impl private::Sealed for $t {} + impl ArrowNativeType for $t { + #[inline] + fn from_usize(_: usize) -> Option { + None + } + + #[inline] + fn to_usize(self) -> Option { + None + } + + #[inline] + fn to_isize(self) -> Option { + None + } + + #[inline] + fn to_i64(self) -> Option { + None + } + + #[inline] + fn as_usize($s) -> usize { + $as_usize + } + + #[inline] + fn usize_as($i: usize) -> Self { + $usize_as + } + } + }; +} + +native_float!(f16, self, self.to_f32() as _, i, f16::from_f32(i as _)); +native_float!(f32, self, self as _, i, i as _); +native_float!(f64, self, self as _, i, i as _); + +impl private::Sealed for i256 {} +impl ArrowNativeType for i256 { + fn from_usize(u: usize) -> Option { + Some(Self::from_parts(u as u128, 0)) + } + + fn as_usize(self) -> usize { + self.to_parts().0 as usize + } + + fn usize_as(i: usize) -> Self { + Self::from_parts(i as u128, 0) + } + + fn to_usize(self) -> Option { + let (low, high) = self.to_parts(); + if high != 0 { + return None; + } + low.try_into().ok() + } + + fn to_isize(self) -> Option { + self.to_i128()?.try_into().ok() + } + + fn to_i64(self) -> Option { + self.to_i128()?.try_into().ok() + } +} + +impl private::Sealed for IntervalMonthDayNano {} +impl ArrowNativeType for IntervalMonthDayNano { + fn from_usize(_: usize) -> Option { + None + } + + fn as_usize(self) -> usize { + ((self.months as u64) | ((self.days as u64) << 32)) as usize + } + + fn usize_as(i: usize) -> Self { + Self::new(i as _, ((i as u64) >> 32) as _, 0) + } + + fn to_usize(self) -> Option { + None + } + + fn to_isize(self) -> Option { + None + } + + fn to_i64(self) -> Option { + None + } +} + +impl private::Sealed for IntervalDayTime {} +impl ArrowNativeType for IntervalDayTime { + fn from_usize(_: usize) -> Option { + None + } + + fn as_usize(self) -> usize { + ((self.days as u64) | ((self.milliseconds as u64) << 32)) as usize + } + + fn usize_as(i: 
usize) -> Self { + Self::new(i as _, ((i as u64) >> 32) as _) + } + + fn to_usize(self) -> Option { + None + } + + fn to_isize(self) -> Option { + None + } + + fn to_i64(self) -> Option { + None + } +} + +/// Allows conversion from supported Arrow types to a byte slice. +pub trait ToByteSlice { + /// Converts this instance into a byte slice + fn to_byte_slice(&self) -> &[u8]; +} + +impl ToByteSlice for [T] { + #[inline] + fn to_byte_slice(&self) -> &[u8] { + let raw_ptr = self.as_ptr() as *const u8; + unsafe { std::slice::from_raw_parts(raw_ptr, std::mem::size_of_val(self)) } + } +} + +impl ToByteSlice for T { + #[inline] + fn to_byte_slice(&self) -> &[u8] { + let raw_ptr = self as *const T as *const u8; + unsafe { std::slice::from_raw_parts(raw_ptr, std::mem::size_of::()) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_i256() { + let a = i256::from_parts(0, 0); + assert_eq!(a.as_usize(), 0); + assert_eq!(a.to_usize().unwrap(), 0); + assert_eq!(a.to_isize().unwrap(), 0); + + let a = i256::from_parts(0, -1); + assert_eq!(a.as_usize(), 0); + assert!(a.to_usize().is_none()); + assert!(a.to_usize().is_none()); + + let a = i256::from_parts(u128::MAX, -1); + assert_eq!(a.as_usize(), usize::MAX); + assert!(a.to_usize().is_none()); + assert_eq!(a.to_isize().unwrap(), -1); + } + + #[test] + fn test_interval_usize() { + assert_eq!(IntervalDayTime::new(1, 0).as_usize(), 1); + assert_eq!(IntervalMonthDayNano::new(1, 0, 0).as_usize(), 1); + + let a = IntervalDayTime::new(23, 53); + let b = IntervalDayTime::usize_as(a.as_usize()); + assert_eq!(a, b); + + let a = IntervalMonthDayNano::new(23, 53, 0); + let b = IntervalMonthDayNano::usize_as(a.as_usize()); + assert_eq!(a, b); + } +} diff --git a/arrow/src/util/bit_chunk_iterator.rs b/arrow-buffer/src/util/bit_chunk_iterator.rs similarity index 92% rename from arrow/src/util/bit_chunk_iterator.rs rename to arrow-buffer/src/util/bit_chunk_iterator.rs index f0127ed2267f..54995314c49b 100644 --- 
a/arrow/src/util/bit_chunk_iterator.rs +++ b/arrow-buffer/src/util/bit_chunk_iterator.rs @@ -60,8 +60,7 @@ impl<'a> UnalignedBitChunk<'a> { // If less than 8 bytes, read into prefix if buffer.len() <= 8 { - let (suffix_mask, trailing_padding) = - compute_suffix_mask(len, offset_padding); + let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding); let prefix = read_u64(buffer) & suffix_mask & prefix_mask; return Self { @@ -75,8 +74,7 @@ impl<'a> UnalignedBitChunk<'a> { // If less than 16 bytes, read into prefix and suffix if buffer.len() <= 16 { - let (suffix_mask, trailing_padding) = - compute_suffix_mask(len, offset_padding); + let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding); let prefix = read_u64(&buffer[..8]) & prefix_mask; let suffix = read_u64(&buffer[8..]) & suffix_mask; @@ -133,31 +131,37 @@ impl<'a> UnalignedBitChunk<'a> { } } + /// Returns the number of leading padding bits pub fn lead_padding(&self) -> usize { self.lead_padding } + /// Returns the number of trailing padding bits pub fn trailing_padding(&self) -> usize { self.trailing_padding } + /// Returns the prefix, if any pub fn prefix(&self) -> Option { self.prefix } + /// Returns the suffix, if any pub fn suffix(&self) -> Option { self.suffix } + /// Returns reference to the chunks pub fn chunks(&self) -> &'a [u64] { self.chunks } - pub(crate) fn iter(&self) -> UnalignedBitChunkIterator<'a> { + /// Returns an iterator over the chunks + pub fn iter(&self) -> UnalignedBitChunkIterator<'a> { self.prefix .into_iter() .chain(self.chunks.iter().cloned()) - .chain(self.suffix.into_iter()) + .chain(self.suffix) } /// Counts the number of ones @@ -166,11 +170,9 @@ impl<'a> UnalignedBitChunk<'a> { } } -pub(crate) type UnalignedBitChunkIterator<'a> = std::iter::Chain< - std::iter::Chain< - std::option::IntoIter, - std::iter::Cloned>, - >, +/// Iterator over an [`UnalignedBitChunk`] +pub type UnalignedBitChunkIterator<'a> = std::iter::Chain< + 
std::iter::Chain, std::iter::Cloned>>, std::option::IntoIter, >; @@ -178,7 +180,7 @@ pub(crate) type UnalignedBitChunkIterator<'a> = std::iter::Chain< fn read_u64(input: &[u8]) -> u64 { let len = input.len().min(8); let mut buf = [0_u8; 8]; - (&mut buf[..len]).copy_from_slice(input); + buf[..len].copy_from_slice(input); u64::from_le_bytes(buf) } @@ -217,6 +219,7 @@ pub struct BitChunks<'a> { } impl<'a> BitChunks<'a> { + /// Create a new [`BitChunks`] from a byte array, and an offset and length in bits pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { assert!(ceil(offset + len, 8) <= buffer.len() * 8); @@ -237,6 +240,7 @@ impl<'a> BitChunks<'a> { } } +/// Iterator over chunks of 64 bits represented as an u64 #[derive(Debug)] pub struct BitChunkIterator<'a> { buffer: &'a [u8], @@ -296,6 +300,12 @@ impl<'a> BitChunks<'a> { index: 0, } } + + /// Returns an iterator over chunks of 64 bits, with the remaining bits zero padded to 64-bits + #[inline] + pub fn iter_padded(&self) -> impl Iterator + 'a { + self.iter().chain(std::iter::once(self.remainder_bits())) + } } impl<'a> IntoIterator for BitChunks<'a> { @@ -332,9 +342,8 @@ impl Iterator for BitChunkIterator<'_> { } else { // the constructor ensures that bit_offset is in 0..8 // that means we need to read at most one additional byte to fill in the high bits - let next = unsafe { - std::ptr::read_unaligned(raw_data.add(index + 1) as *const u8) as u64 - }; + let next = + unsafe { std::ptr::read_unaligned(raw_data.add(index + 1) as *const u8) as u64 }; (current >> bit_offset) | (next << (64 - bit_offset)) }; @@ -381,8 +390,8 @@ mod tests { #[test] fn test_iter_unaligned() { let input: &[u8] = &[ - 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, - 0b00100000, 0b01000000, 0b11111111, + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -402,8 +411,8 @@ mod tests { #[test] fn 
test_iter_unaligned_remainder_1_byte() { let input: &[u8] = &[ - 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, - 0b00100000, 0b01000000, 0b11111111, + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -436,8 +445,8 @@ mod tests { #[test] fn test_iter_unaligned_remainder_bits_large() { let input: &[u8] = &[ - 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, - 0b11111111, 0b00000000, 0b11111111, + 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, + 0b00000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -456,7 +465,7 @@ mod tests { const ALLOC_SIZE: usize = 4 * 1024; let input = vec![0xFF_u8; ALLOC_SIZE]; - let buffer: Buffer = Buffer::from(input); + let buffer: Buffer = Buffer::from_vec(input); let bitchunks = buffer.bit_chunks(57, ALLOC_SIZE * 8 - 57); @@ -631,11 +640,8 @@ mod tests { let max_truncate = 128.min(mask_len - offset); let truncate = rng.gen::().checked_rem(max_truncate).unwrap_or(0); - let unaligned = UnalignedBitChunk::new( - buffer.as_slice(), - offset, - mask_len - offset - truncate, - ); + let unaligned = + UnalignedBitChunk::new(buffer.as_slice(), offset, mask_len - offset - truncate); let bool_slice = &bools[offset..mask_len - truncate]; diff --git a/arrow/src/util/bit_iterator.rs b/arrow-buffer/src/util/bit_iterator.rs similarity index 53% rename from arrow/src/util/bit_iterator.rs rename to arrow-buffer/src/util/bit_iterator.rs index bba9dac60a4b..df40a8fbaccb 100644 --- a/arrow/src/util/bit_iterator.rs +++ b/arrow-buffer/src/util/bit_iterator.rs @@ -15,7 +15,73 @@ // specific language governing permissions and limitations // under the License. -use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator}; +//! 
Types for iterating over packed bitmasks + +use crate::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator}; +use crate::bit_util::{ceil, get_bit_raw}; + +/// Iterator over the bits within a packed bitmask +/// +/// To efficiently iterate over just the set bits see [`BitIndexIterator`] and [`BitSliceIterator`] +pub struct BitIterator<'a> { + buffer: &'a [u8], + current_offset: usize, + end_offset: usize, +} + +impl<'a> BitIterator<'a> { + /// Create a new [`BitIterator`] from the provided `buffer`, + /// and `offset` and `len` in bits + /// + /// # Panic + /// + /// Panics if `buffer` is too short for the provided offset and length + pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { + let end_offset = offset.checked_add(len).unwrap(); + let required_len = ceil(end_offset, 8); + assert!( + buffer.len() >= required_len, + "BitIterator buffer too small, expected {required_len} got {}", + buffer.len() + ); + + Self { + buffer, + current_offset: offset, + end_offset, + } + } +} + +impl<'a> Iterator for BitIterator<'a> { + type Item = bool; + + fn next(&mut self) -> Option { + if self.current_offset == self.end_offset { + return None; + } + // Safety: + // offsets in bounds + let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.current_offset) }; + self.current_offset += 1; + Some(v) + } +} + +impl<'a> ExactSizeIterator for BitIterator<'a> {} + +impl<'a> DoubleEndedIterator for BitIterator<'a> { + fn next_back(&mut self) -> Option { + if self.current_offset == self.end_offset { + return None; + } + self.end_offset -= 1; + // Safety: + // offsets in bounds + let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.end_offset) }; + Some(v) + } +} /// Iterator of contiguous ranges of set bits within a provided packed bitmask /// @@ -31,7 +97,7 @@ pub struct BitSliceIterator<'a> { } impl<'a> BitSliceIterator<'a> { - /// Create a new [`BitSliceIterator`] from the provide `buffer`, + /// Create a new [`BitSliceIterator`] from the provided 
`buffer`, /// and `offset` and `len` in bits pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { let chunk = UnalignedBitChunk::new(buffer, offset, len); @@ -157,4 +223,72 @@ impl<'a> Iterator for BitIndexIterator<'a> { } } -// Note: tests located in filter module +/// Calls the provided closure for each index in the provided null mask that is set, +/// using an adaptive strategy based on the null count +/// +/// Ideally this would be encapsulated in an [`Iterator`] that would determine the optimal +/// strategy up front, and then yield indexes based on this. +/// +/// Unfortunately, external iteration based on the resulting [`Iterator`] would match the strategy +/// variant on each call to [`Iterator::next`], and LLVM generally cannot eliminate this. +/// +/// One solution to this might be internal iteration, e.g. [`Iterator::try_fold`], however, +/// it is currently [not possible] to override this for custom iterators in stable Rust. +/// +/// As such this is the next best option +/// +/// [not possible]: https://github.com/rust-lang/rust/issues/69595 +#[inline] +pub fn try_for_each_valid_idx Result<(), E>>( + len: usize, + offset: usize, + null_count: usize, + nulls: Option<&[u8]>, + f: F, +) -> Result<(), E> { + let valid_count = len - null_count; + + if valid_count == len { + (0..len).try_for_each(f) + } else if null_count != len { + BitIndexIterator::new(nulls.unwrap(), offset, len).try_for_each(f) + } else { + Ok(()) + } +} + +// Note: further tests located in arrow_select::filter module + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bit_iterator() { + let mask = &[0b00010010, 0b00100011, 0b00000101, 0b00010001, 0b10010011]; + let actual: Vec<_> = BitIterator::new(mask, 0, 5).collect(); + assert_eq!(actual, &[false, true, false, false, true]); + + let actual: Vec<_> = BitIterator::new(mask, 4, 5).collect(); + assert_eq!(actual, &[true, false, false, false, true]); + + let actual: Vec<_> = BitIterator::new(mask, 12, 
14).collect(); + assert_eq!( + actual, + &[ + false, true, false, false, true, false, true, false, false, false, false, false, + true, false + ] + ); + + assert_eq!(BitIterator::new(mask, 0, 0).count(), 0); + assert_eq!(BitIterator::new(mask, 40, 0).count(), 0); + } + + #[test] + #[should_panic(expected = "BitIterator buffer too small, expected 3 got 2")] + fn test_bit_iterator_bounds() { + let mask = &[223, 23]; + BitIterator::new(mask, 17, 0); + } +} diff --git a/arrow-buffer/src/util/bit_mask.rs b/arrow-buffer/src/util/bit_mask.rs new file mode 100644 index 000000000000..d4c2fa4744e1 --- /dev/null +++ b/arrow-buffer/src/util/bit_mask.rs @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utils for working with packed bit masks + +use crate::bit_util::ceil; + +/// Util function to set bits in a slice of bytes. 
+/// +/// This sets all bits on `write_data` in the range `[offset_write..offset_write+len]` +/// to be equal to the bits in `data` in the range `[offset_read..offset_read+len]`, and +/// returns the number of `0` bits in `data[offset_read..offset_read+len]`. +/// `offset_write`, `offset_read`, and `len` are in terms of bits. +pub fn set_bits( + write_data: &mut [u8], + data: &[u8], + offset_write: usize, + offset_read: usize, + len: usize, +) -> usize { + assert!(offset_write + len <= write_data.len() * 8); + assert!(offset_read + len <= data.len() * 8); + let mut null_count = 0; + let mut acc = 0; + while len > acc { + // SAFETY: the arguments to `set_upto_64bits` are within the valid range because + // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8 + // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8 + let (n, len_set) = unsafe { + set_upto_64bits( + write_data, + data, + offset_write + acc, + offset_read + acc, + len - acc, + ) + }; + null_count += n; + acc += len_set; + } + + null_count +} + +/// Similar to `set_bits` but sets only up to 64 bits; the actual number of bits set may vary. +/// Returns a pair of the number of `0` bits and the number of bits set. +/// +/// # Safety +/// The caller must ensure all arguments are within the valid range. 
+#[inline] +unsafe fn set_upto_64bits( + write_data: &mut [u8], + data: &[u8], + offset_write: usize, + offset_read: usize, + len: usize, +) -> (usize, usize) { + let read_byte = offset_read / 8; + let read_shift = offset_read % 8; + let write_byte = offset_write / 8; + let write_shift = offset_write % 8; + + if len >= 64 { + let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() }; + if read_shift == 0 { + if write_shift == 0 { + // no shifting necessary + let len = 64; + let null_count = chunk.count_zeros() as usize; + unsafe { write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) + } else { + // only write shifting necessary + let len = 64 - write_shift; + let chunk = chunk << write_shift; + let null_count = len - chunk.count_ones() as usize; + unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) + } + } else if write_shift == 0 { + // only read shifting necessary + let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0 + let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask + let null_count = len - chunk.count_ones() as usize; + unsafe { write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) + } else { + let len = 64 - std::cmp::max(read_shift, write_shift); + let chunk = (chunk >> read_shift) << write_shift; + let null_count = len - chunk.count_ones() as usize; + unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) + } + } else if len == 1 { + let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1; + unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift }; + ((byte_chunk ^ 1) as usize, 1) + } else { + let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift)); + let bytes = ceil(len + read_shift, 8); + // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len() + let chunk = unsafe { 
read_bytes_to_u64(data, read_byte, bytes) }; + let mask = u64::MAX >> (64 - len); + let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only + let chunk = chunk << write_shift; // shifting back to align with `write_data` + let null_count = len - chunk.count_ones() as usize; + let bytes = ceil(len + write_shift, 8); + for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) { + unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c }; + } + (null_count, len) + } +} + +/// # Safety +/// The caller must ensure all arguments are within the valid range. +#[inline] +unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 { + debug_assert!(count <= 8); + let mut tmp = std::mem::MaybeUninit::::new(0); + let src = data.as_ptr().add(offset); + unsafe { + std::ptr::copy_nonoverlapping(src, tmp.as_mut_ptr() as *mut u8, count); + tmp.assume_init() + } +} + +/// # Safety +/// The caller must ensure `data` has `offset..(offset + 8)` range +#[inline] +unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { + let ptr = data.as_mut_ptr().add(offset) as *mut u64; + ptr.write_unaligned(chunk); +} + +/// Similar to `write_u64_bytes`, but this method ORs the offset addressed `data` and `chunk` +/// instead of overwriting +/// +/// # Safety +/// The caller must ensure `data` has `offset..(offset + 8)` range +#[inline] +unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { + let ptr = data.as_mut_ptr().add(offset); + let chunk = chunk | (*ptr) as u64; + (ptr as *mut u64).write_unaligned(chunk); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::bit_util::{get_bit, set_bit, unset_bit}; + use rand::prelude::StdRng; + use rand::{Fill, Rng, SeedableRng}; + use std::fmt::Display; + + #[test] + fn test_set_bits_aligned() { + SetBitsTest { + write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + data: vec![ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 
0b10100101, + ], + offset_write: 8, + offset_read: 0, + len: 64, + expected_data: vec![ + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, 0, + ], + expected_null_count: 24, + } + .verify(); + } + + #[test] + fn test_set_bits_unaligned_destination_start() { + SetBitsTest { + write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + data: vec![ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, + ], + offset_write: 3, + offset_read: 0, + len: 64, + expected_data: vec![ + 0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110, + 0b00101111, 0b00000101, 0b00000000, + ], + expected_null_count: 24, + } + .verify(); + } + + #[test] + fn test_set_bits_unaligned_destination_end() { + SetBitsTest { + write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + data: vec![ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, + ], + offset_write: 8, + offset_read: 0, + len: 62, + expected_data: vec![ + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b00100101, 0, + ], + expected_null_count: 23, + } + .verify(); + } + + #[test] + fn test_set_bits_unaligned() { + SetBitsTest { + write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + data: vec![ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101, + 0b10011001, 0b11011011, 0b11101011, 0b11000011, + ], + offset_write: 3, + offset_read: 5, + len: 95, + expected_data: vec![ + 0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001, + 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001, + ], + expected_null_count: 35, + } + .verify(); + } + + #[test] + fn set_bits_fuzz() { + let mut rng = StdRng::seed_from_u64(42); + let mut data = SetBitsTest::new(); + for _ in 0..100 { 
+ data.regen(&mut rng); + data.verify(); + } + } + + #[derive(Debug, Default)] + struct SetBitsTest { + /// target write data + write_data: Vec, + /// source data + data: Vec, + offset_write: usize, + offset_read: usize, + len: usize, + /// the expected contents of write_data after the test + expected_data: Vec, + /// the expected number of nulls copied at the end of the test + expected_null_count: usize, + } + + /// prints a byte slice as a binary string like "01010101 10101010" + struct BinaryFormatter<'a>(&'a [u8]); + impl<'a> Display for BinaryFormatter<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for byte in self.0 { + write!(f, "{:08b} ", byte)?; + } + write!(f, " ")?; + Ok(()) + } + } + + impl Display for SetBitsTest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "SetBitsTest {{")?; + writeln!(f, " write_data: {}", BinaryFormatter(&self.write_data))?; + writeln!(f, " data: {}", BinaryFormatter(&self.data))?; + writeln!( + f, + " expected_data: {}", + BinaryFormatter(&self.expected_data) + )?; + writeln!(f, " offset_write: {}", self.offset_write)?; + writeln!(f, " offset_read: {}", self.offset_read)?; + writeln!(f, " len: {}", self.len)?; + writeln!(f, " expected_null_count: {}", self.expected_null_count)?; + writeln!(f, "}}") + } + } + + impl SetBitsTest { + /// create a new instance of FuzzData + fn new() -> Self { + Self::default() + } + + /// Update this instance's fields with randomly selected values and expected data + fn regen(&mut self, rng: &mut StdRng) { + // (read) data + // ------------------+-----------------+------- + // .. offset_read .. | data | ... + // ------------------+-----------------+------- + + // Write data + // -------------------+-----------------+------- + // .. offset_write .. | (data to write) | ... 
+ // -------------------+-----------------+------- + + // length of data to copy + let len = rng.gen_range(0..=200); + + // randomly pick where we will write to + let offset_write_bits = rng.gen_range(0..=200); + let offset_write_bytes = if offset_write_bits % 8 == 0 { + offset_write_bits / 8 + } else { + (offset_write_bits / 8) + 1 + }; + let extra_write_data_bytes = rng.gen_range(0..=5); // ensure 0 shows up often + + // randomly decide where we will read from + let extra_read_data_bytes = rng.gen_range(0..=5); // make sure 0 shows up often + let offset_read_bits = rng.gen_range(0..=200); + let offset_read_bytes = if offset_read_bits % 8 != 0 { + (offset_read_bits / 8) + 1 + } else { + offset_read_bits / 8 + }; + + // create space for writing + self.write_data.clear(); + self.write_data + .resize(offset_write_bytes + len + extra_write_data_bytes, 0); + + // interestingly set_bits seems to assume the output is already zeroed + // the fuzz tests fail when this is uncommented + //self.write_data.try_fill(rng).unwrap(); + self.offset_write = offset_write_bits; + + // make source data + self.data + .resize(offset_read_bytes + len + extra_read_data_bytes, 0); + // fill source data with random bytes + self.data.try_fill(rng).unwrap(); + self.offset_read = offset_read_bits; + + self.len = len; + + // generated expectated output (not efficient) + self.expected_data.resize(self.write_data.len(), 0); + self.expected_data.copy_from_slice(&self.write_data); + + self.expected_null_count = 0; + for i in 0..self.len { + let bit = get_bit(&self.data, self.offset_read + i); + if bit { + set_bit(&mut self.expected_data, self.offset_write + i); + } else { + unset_bit(&mut self.expected_data, self.offset_write + i); + self.expected_null_count += 1; + } + } + } + + /// call set_bits with the given parameters and compare with the expected output + fn verify(&self) { + // call set_bits and compare + let mut actual = self.write_data.to_vec(); + let null_count = set_bits( + &mut actual, + 
&self.data, + self.offset_write, + self.offset_read, + self.len, + ); + + assert_eq!(actual, self.expected_data, "self: {}", self); + assert_eq!(null_count, self.expected_null_count, "self: {}", self); + } + } + + #[test] + fn test_set_upto_64bits() { + // len >= 64 + let write_data: &mut [u8] = &mut [0; 9]; + let data: &[u8] = &[ + 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, + 0b00000001, 0b00000001, + ]; + let offset_write = 1; + let offset_read = 0; + let len = 65; + let (n, len_set) = + unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) }; + assert_eq!(n, 55); + assert_eq!(len_set, 63); + assert_eq!( + write_data, + &[ + 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, + 0b00000010, 0b00000000 + ] + ); + + // len = 1 + let write_data: &mut [u8] = &mut [0b00000000]; + let data: &[u8] = &[0b00000001]; + let offset_write = 1; + let offset_read = 0; + let len = 1; + let (n, len_set) = + unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) }; + assert_eq!(n, 0); + assert_eq!(len_set, 1); + assert_eq!(write_data, &[0b00000010]); + } +} diff --git a/arrow/src/util/bit_util.rs b/arrow-buffer/src/util/bit_util.rs similarity index 77% rename from arrow/src/util/bit_util.rs rename to arrow-buffer/src/util/bit_util.rs index 5752c5df972e..bf14525bbd6b 100644 --- a/arrow/src/util/bit_util.rs +++ b/arrow-buffer/src/util/bit_util.rs @@ -17,22 +17,6 @@ //! 
Utils for working with bits -use num::Integer; -#[cfg(feature = "simd")] -use packed_simd::u8x64; - -const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; -const UNSET_BIT_MASK: [u8; 8] = [ - 255 - 1, - 255 - 2, - 255 - 4, - 255 - 8, - 255 - 16, - 255 - 32, - 255 - 64, - 255 - 128, -]; - /// Returns the nearest number that is `>=` than `num` and is a multiple of 64 #[inline] pub fn round_upto_multiple_of_64(num: usize) -> usize { @@ -43,13 +27,15 @@ pub fn round_upto_multiple_of_64(num: usize) -> usize { /// be a power of 2. pub fn round_upto_power_of_2(num: usize, factor: usize) -> usize { debug_assert!(factor > 0 && (factor & (factor - 1)) == 0); - (num + (factor - 1)) & !(factor - 1) + num.checked_add(factor - 1) + .expect("failed to round to next highest power of 2") + & !(factor - 1) } /// Returns whether bit at position `i` in `data` is set or not #[inline] pub fn get_bit(data: &[u8], i: usize) -> bool { - (data[i >> 3] & BIT_MASK[i & 7]) != 0 + data[i / 8] & (1 << (i % 8)) != 0 } /// Returns whether bit at position `i` in `data` is set or not. @@ -60,13 +46,13 @@ pub fn get_bit(data: &[u8], i: usize) -> bool { /// responsible to guarantee that `i` is within bounds. #[inline] pub unsafe fn get_bit_raw(data: *const u8, i: usize) -> bool { - (*data.add(i >> 3) & BIT_MASK[i & 7]) != 0 + (*data.add(i / 8) & (1 << (i % 8))) != 0 } /// Sets bit at position `i` for `data` to 1 #[inline] pub fn set_bit(data: &mut [u8], i: usize) { - data[i >> 3] |= BIT_MASK[i & 7]; + data[i / 8] |= 1 << (i % 8); } /// Sets bit at position `i` for `data` @@ -77,13 +63,13 @@ pub fn set_bit(data: &mut [u8], i: usize) { /// responsible to guarantee that `i` is within bounds. 
#[inline] pub unsafe fn set_bit_raw(data: *mut u8, i: usize) { - *data.add(i >> 3) |= BIT_MASK[i & 7]; + *data.add(i / 8) |= 1 << (i % 8); } /// Sets bit at position `i` for `data` to 0 #[inline] pub fn unset_bit(data: &mut [u8], i: usize) { - data[i >> 3] &= UNSET_BIT_MASK[i & 7]; + data[i / 8] &= !(1 << (i % 8)); } /// Sets bit at position `i` for `data` to 0 @@ -94,7 +80,7 @@ pub fn unset_bit(data: &mut [u8], i: usize) { /// responsible to guarantee that `i` is within bounds. #[inline] pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) { - *data.add(i >> 3) &= UNSET_BIT_MASK[i & 7]; + *data.add(i / 8) &= !(1 << (i % 8)); } /// Returns the ceil of `value`/`divisor` @@ -102,34 +88,16 @@ pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) { pub fn ceil(value: usize, divisor: usize) -> usize { // Rewrite as `value.div_ceil(&divisor)` after // https://github.com/rust-lang/rust/issues/88581 is merged. - Integer::div_ceil(&value, &divisor) + value / divisor + (0 != value % divisor) as usize } -/// Performs SIMD bitwise binary operations. -/// -/// # Safety -/// -/// Note that each slice should be 64 bytes and it is the callers responsibility to ensure -/// that this is the case. If passed slices larger than 64 bytes the operation will only -/// be performed on the first 64 bytes. Slices less than 64 bytes will panic. 
-#[cfg(feature = "simd")] -pub unsafe fn bitwise_bin_op_simd(left: &[u8], right: &[u8], result: &mut [u8], op: F) -where - F: Fn(u8x64, u8x64) -> u8x64, -{ - let left_simd = u8x64::from_slice_unaligned_unchecked(left); - let right_simd = u8x64::from_slice_unaligned_unchecked(right); - let simd_result = op(left_simd, right_simd); - simd_result.write_to_slice_unaligned_unchecked(result); -} - -#[cfg(all(test, feature = "test_utils"))] +#[cfg(test)] mod tests { use std::collections::HashSet; use super::*; - use crate::util::test_util::seedable_rng; - use rand::Rng; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; #[test] fn test_round_upto_multiple_of_64() { @@ -141,6 +109,12 @@ mod tests { assert_eq!(192, round_upto_multiple_of_64(129)); } + #[test] + #[should_panic(expected = "failed to round to next highest power of 2")] + fn test_round_upto_panic() { + let _ = round_upto_power_of_2(usize::MAX, 2); + } + #[test] fn test_get_bit() { // 00001101 @@ -168,10 +142,14 @@ mod tests { assert!(!get_bit(&[0b01001001, 0b01010010], 15)); } + pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) + } + #[test] fn test_get_bit_raw() { const NUM_BYTE: usize = 10; - let mut buf = vec![0; NUM_BYTE]; + let mut buf = [0; NUM_BYTE]; let mut expected = vec![]; let mut rng = seedable_rng(); for i in 0..8 * NUM_BYTE { @@ -279,7 +257,6 @@ mod tests { } #[test] - #[cfg(all(any(target_arch = "x86", target_arch = "x86_64")))] fn test_ceil() { assert_eq!(ceil(0, 1), 0); assert_eq!(ceil(1, 1), 1); @@ -293,28 +270,4 @@ mod tests { assert_eq!(ceil(10, 10000000000), 1); assert_eq!(ceil(10000000000, 1000000000), 10); } - - #[test] - #[cfg(feature = "simd")] - fn test_bitwise_and_simd() { - let buf1 = [0b00110011u8; 64]; - let buf2 = [0b11110000u8; 64]; - let mut buf3 = [0b00000000; 64]; - unsafe { bitwise_bin_op_simd(&buf1, &buf2, &mut buf3, |a, b| a & b) }; - for i in buf3.iter() { - assert_eq!(&0b00110000u8, i); - } - } - - #[test] - #[cfg(feature = "simd")] - fn 
test_bitwise_or_simd() { - let buf1 = [0b00110011u8; 64]; - let buf2 = [0b11110000u8; 64]; - let mut buf3 = [0b00000000; 64]; - unsafe { bitwise_bin_op_simd(&buf1, &buf2, &mut buf3, |a, b| a | b) }; - for i in buf3.iter() { - assert_eq!(&0b11110011u8, i); - } - } } diff --git a/arrow-buffer/src/util/mod.rs b/arrow-buffer/src/util/mod.rs new file mode 100644 index 000000000000..9023fe4a035d --- /dev/null +++ b/arrow-buffer/src/util/mod.rs @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod bit_chunk_iterator; +pub mod bit_iterator; +pub mod bit_mask; +pub mod bit_util; diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml new file mode 100644 index 000000000000..4046f5226094 --- /dev/null +++ b/arrow-cast/Cargo.toml @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-cast" +version = { workspace = true } +description = "Cast kernel and utilities for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_cast" +path = "src/lib.rs" +bench = false + +[package.metadata.docs.rs] +features = ["prettyprint"] + +[features] +prettyprint = ["comfy-table"] +force_validate = [] + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +arrow-select = { workspace = true } +chrono = { workspace = true } +half = { version = "2.1", default-features = false } +num = { version = "0.4", default-features = false, features = ["std"] } +lexical-core = { version = "1.0", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } +atoi = "2.0.0" +comfy-table = { version = "7.0", optional = true, default-features = false } +base64 = "0.22" +ryu = "1.0.16" + +[dev-dependencies] +criterion = { version = "0.5", default-features = false } +half = { version = "2.1", default-features = false } +rand = "0.8" + +[build-dependencies] + +[[bench]] +name = "parse_timestamp" +harness = false + +[[bench]] +name = "parse_time" +harness = false + +[[bench]] +name = "parse_date" +harness = false + +[[bench]] +name = "parse_decimal" 
+harness = false diff --git a/arrow-cast/benches/parse_date.rs b/arrow-cast/benches/parse_date.rs new file mode 100644 index 000000000000..e05d38d2f853 --- /dev/null +++ b/arrow-cast/benches/parse_date.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::types::Date32Type; +use arrow_cast::parse::Parser; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let timestamps = ["2020-09-08", "2020-9-8", "2020-09-8", "2020-9-08"]; + + for timestamp in timestamps { + let t = black_box(timestamp); + c.bench_function(t, |b| { + b.iter(|| Date32Type::parse(t).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/benches/parse_decimal.rs b/arrow-cast/benches/parse_decimal.rs new file mode 100644 index 000000000000..5682859dd25a --- /dev/null +++ b/arrow-cast/benches/parse_decimal.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::types::Decimal256Type; +use arrow_cast::parse::parse_decimal; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let decimals = [ + "123.123", + "123.1234", + "123.1", + "123", + "-123.123", + "-123.1234", + "-123.1", + "-123", + "0.0000123", + "12.", + "-12.", + "00.1", + "-00.1", + "12345678912345678.1234", + "-12345678912345678.1234", + "99999999999999999.999", + "-99999999999999999.999", + ".123", + "-.123", + "123.", + "-123.", + ]; + + for decimal in decimals { + let d = black_box(decimal); + c.bench_function(d, |b| { + b.iter(|| parse_decimal::(d, 20, 3).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/benches/parse_time.rs b/arrow-cast/benches/parse_time.rs new file mode 100644 index 000000000000..d28b9c7c613d --- /dev/null +++ b/arrow-cast/benches/parse_time.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_cast::parse::string_to_time_nanoseconds; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let timestamps = [ + "9:50", + "09:50", + "09:50 PM", + "9:50:12 AM", + "09:50:12 PM", + "09:50:12.123456789", + "9:50:12.123456789", + "09:50:12.123456789 PM", + ]; + + for timestamp in timestamps { + let t = black_box(timestamp); + c.bench_function(t, |b| { + b.iter(|| string_to_time_nanoseconds(t).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow/src/util/serialization.rs b/arrow-cast/benches/parse_timestamp.rs similarity index 51% rename from arrow/src/util/serialization.rs rename to arrow-cast/benches/parse_timestamp.rs index 14d67ca117c4..d3ab41863e70 100644 --- a/arrow/src/util/serialization.rs +++ b/arrow-cast/benches/parse_timestamp.rs @@ -15,19 +15,30 @@ // specific language governing permissions and limitations // under the License. -/// Converts numeric type to a `String` -pub fn lexical_to_string(n: N) -> String { - let mut buf = Vec::::with_capacity(N::FORMATTED_SIZE_DECIMAL); - unsafe { - // JUSTIFICATION - // Benefit - // Allows using the faster serializer lexical core and convert to string - // Soundness - // Length of buf is set as written length afterwards. lexical_core - // creates a valid string, so doesn't need to be checked. 
- let slice = std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.capacity()); - let len = lexical_core::write(n, slice).len(); - buf.set_len(len); - String::from_utf8_unchecked(buf) +use arrow_cast::parse::string_to_timestamp_nanos; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let timestamps = [ + "2020-09-08", + "2020-09-08T13:42:29", + "2020-09-08T13:42:29.190", + "2020-09-08T13:42:29.190855", + "2020-09-08T13:42:29.190855999", + "2020-09-08T13:42:29+00:00", + "2020-09-08T13:42:29.190+00:00", + "2020-09-08T13:42:29.190855+00:00", + "2020-09-08T13:42:29.190855999-05:00", + "2020-09-08T13:42:29.190855Z", + ]; + + for timestamp in timestamps { + let t = black_box(timestamp); + c.bench_function(t, |b| { + b.iter(|| string_to_timestamp_nanos(t).unwrap()); + }); } } + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/src/base64.rs b/arrow-cast/src/base64.rs new file mode 100644 index 000000000000..534b21878c56 --- /dev/null +++ b/arrow-cast/src/base64.rs @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Functions for converting data in [`GenericBinaryArray`] such as [`StringArray`] to/from base64 encoded strings +//! +//! 
[`StringArray`]: arrow_array::StringArray + +use arrow_array::{Array, GenericBinaryArray, GenericStringArray, OffsetSizeTrait}; +use arrow_buffer::{Buffer, OffsetBuffer}; +use arrow_schema::ArrowError; +use base64::encoded_len; +use base64::engine::Config; + +pub use base64::prelude::*; + +/// Bas64 encode each element of `array` with the provided [`Engine`] +pub fn b64_encode( + engine: &E, + array: &GenericBinaryArray, +) -> GenericStringArray { + let lengths = array.offsets().windows(2).map(|w| { + let len = w[1].as_usize() - w[0].as_usize(); + encoded_len(len, engine.config().encode_padding()).unwrap() + }); + let offsets = OffsetBuffer::::from_lengths(lengths); + let buffer_len = offsets.last().unwrap().as_usize(); + let mut buffer = vec![0_u8; buffer_len]; + let mut offset = 0; + + for i in 0..array.len() { + let len = engine + .encode_slice(array.value(i), &mut buffer[offset..]) + .unwrap(); + offset += len; + } + assert_eq!(offset, buffer_len); + + // Safety: Base64 is valid UTF-8 + unsafe { + GenericStringArray::new_unchecked(offsets, Buffer::from_vec(buffer), array.nulls().cloned()) + } +} + +/// Base64 decode each element of `array` with the provided [`Engine`] +pub fn b64_decode( + engine: &E, + array: &GenericBinaryArray, +) -> Result, ArrowError> { + let estimated_len = array.values().len(); // This is an overestimate + let mut buffer = vec![0; estimated_len]; + + let mut offsets = Vec::with_capacity(array.len() + 1); + offsets.push(O::usize_as(0)); + let mut offset = 0; + + for v in array.iter() { + if let Some(v) = v { + let len = engine.decode_slice(v, &mut buffer[offset..]).unwrap(); + // This cannot overflow as `len` is less than `v.len()` and `a` is valid + offset += len; + } + offsets.push(O::usize_as(offset)); + } + + // Safety: offsets monotonically increasing by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Ok(GenericBinaryArray::new( + offsets, + Buffer::from_vec(buffer), + array.nulls().cloned(), 
+ )) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::BinaryArray; + use rand::{thread_rng, Rng}; + + fn test_engine(e: &E, a: &BinaryArray) { + let encoded = b64_encode(e, a); + encoded.to_data().validate_full().unwrap(); + + let to_decode = encoded.into(); + let decoded = b64_decode(e, &to_decode).unwrap(); + decoded.to_data().validate_full().unwrap(); + + assert_eq!(&decoded, a); + } + + #[test] + fn test_b64() { + let mut rng = thread_rng(); + let len = rng.gen_range(1024..1050); + let data: BinaryArray = (0..len) + .map(|_| { + let len = rng.gen_range(0..16); + Some((0..len).map(|_| rng.gen()).collect::>()) + }) + .collect(); + + test_engine(&BASE64_STANDARD, &data); + test_engine(&BASE64_STANDARD_NO_PAD, &data); + } +} diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs new file mode 100644 index 000000000000..637cbc417008 --- /dev/null +++ b/arrow-cast/src/cast/decimal.rs @@ -0,0 +1,569 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::cast::*; + +/// A utility trait that provides checked conversions between +/// decimal types inspired by [`NumCast`] +pub(crate) trait DecimalCast: Sized { + fn to_i128(self) -> Option; + + fn to_i256(self) -> Option; + + fn from_decimal(n: T) -> Option; +} + +impl DecimalCast for i128 { + fn to_i128(self) -> Option { + Some(self) + } + + fn to_i256(self) -> Option { + Some(i256::from_i128(self)) + } + + fn from_decimal(n: T) -> Option { + n.to_i128() + } +} + +impl DecimalCast for i256 { + fn to_i128(self) -> Option { + self.to_i128() + } + + fn to_i256(self) -> Option { + Some(self) + } + + fn from_decimal(n: T) -> Option { + n.to_i256() + } +} + +pub(crate) fn cast_decimal_to_decimal_error( + output_precision: u8, + output_scale: i8, +) -> impl Fn(::Native) -> ArrowError +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + move |x: I::Native| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). 
Overflowing on {:?}", + O::PREFIX, + output_precision, + output_scale, + x + )) + } +} + +pub(crate) fn convert_to_smaller_scale_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); + let div = I::Native::from_decimal(10_i128) + .unwrap() + .pow_checked((input_scale - output_scale) as u32)?; + + let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); + let half_neg = half.neg_wrapping(); + + let f = |x: I::Native| { + // div is >= 10 and so this cannot overflow + let d = x.div_wrapping(div); + let r = x.mod_wrapping(div); + + // Round result + let adjusted = match x >= I::Native::ZERO { + true if r >= half => d.add_wrapping(I::Native::ONE), + false if r <= half_neg => d.sub_wrapping(I::Native::ONE), + _ => d, + }; + O::Native::from_decimal(adjusted) + }; + + Ok(match cast_options.safe { + true => array.unary_opt(f), + false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + }) +} + +pub(crate) fn convert_to_bigger_or_equal_scale_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); + let mul = O::Native::from_decimal(10_i128) + .unwrap() + .pow_checked((output_scale - input_scale) as u32)?; + + let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); + + Ok(match cast_options.safe { + true => array.unary_opt(f), + false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + }) +} + +// Only support one 
type of decimal cast operations +pub(crate) fn cast_decimal_to_decimal_same_type( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array: PrimitiveArray = match input_scale.cmp(&output_scale) { + Ordering::Equal => { + // the scale doesn't change, the native value don't need to be changed + array.clone() + } + Ordering::Greater => convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )?, + Ordering::Less => { + // input_scale < output_scale + convert_to_bigger_or_equal_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + } + }; + + Ok(Arc::new(array.with_precision_and_scale( + output_precision, + output_scale, + )?)) +} + +// Support two different types of decimal cast operations +pub(crate) fn cast_decimal_to_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array: PrimitiveArray = if input_scale > output_scale { + convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + } else { + convert_to_bigger_or_equal_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + }; + + Ok(Arc::new(array.with_precision_and_scale( + output_precision, + output_scale, + )?)) +} + +/// Parses given string to specified decimal native (i128/i256) based on given +/// scale. Returns an `Err` if it cannot parse given string. 
+pub(crate) fn parse_string_to_decimal_native( + value_str: &str, + scale: usize, +) -> Result +where + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let value_str = value_str.trim(); + let parts: Vec<&str> = value_str.split('.').collect(); + if parts.len() > 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + let (negative, first_part) = if parts[0].is_empty() { + (false, parts[0]) + } else { + match parts[0].as_bytes()[0] { + b'-' => (true, &parts[0][1..]), + b'+' => (false, &parts[0][1..]), + _ => (false, parts[0]), + } + }; + + let integers = first_part.trim_start_matches('0'); + let decimals = if parts.len() == 2 { parts[1] } else { "" }; + + if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + // Adjust decimal based on scale + let mut number_decimals = if decimals.len() > scale { + let decimal_number = i256::from_string(decimals).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) + })?; + + let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; + + let half = div.div_wrapping(i256::from_i128(2)); + let half_neg = half.neg_wrapping(); + + let d = decimal_number.div_wrapping(div); + let r = decimal_number.mod_wrapping(div); + + // Round result + let adjusted = match decimal_number >= i256::ZERO { + true if r >= half => d.add_wrapping(i256::ONE), + false if r <= half_neg => d.sub_wrapping(i256::ONE), + _ => d, + }; + + let integers = if !integers.is_empty() { + i256::from_string(integers) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Cannot parse decimal format: {value_str}" + )) + }) + 
.map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))? + } else { + i256::ZERO + }; + + format!("{}", integers.add_wrapping(adjusted)) + } else { + let padding = if scale > decimals.len() { scale } else { 0 }; + + let decimals = format!("{decimals:0( + from: &GenericStringArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if cast_options.safe { + let iter = from.iter().map(|v| { + v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) + .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v)) + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + .with_precision_and_scale(precision, scale)? + }) + } else { + let vec = from + .iter() + .map(|v| { + v.map(|v| { + parse_string_to_decimal_native::(v, scale as usize) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + T::DATA_TYPE, + )) + }) + .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) + }) + .transpose() + }) + .collect::, _>>()?; + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + .with_precision_and_scale(precision, scale)? 
+ }) + } +} + +/// Cast Utf8 to decimal +pub(crate) fn cast_string_to_decimal( + from: &dyn Array, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if scale < 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal with negative scale {scale}" + ))); + } + + if scale > T::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal greater than maximum scale {}", + T::MAX_SCALE + ))); + } + + Ok(Arc::new(string_to_decimal_cast::( + from.as_any() + .downcast_ref::>() + .unwrap(), + precision, + scale, + cast_options, + )?)) +} + +pub(crate) fn cast_floating_point_to_decimal128( + array: &PrimitiveArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + if cast_options.safe { + array + .unary_opt::<_, Decimal128Type>(|v| { + (mul * v.as_()) + .round() + .to_i128() + .filter(|v| Decimal128Type::is_valid_decimal_precision(*v, precision)) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal128Type, _>(|v| { + (mul * v.as_()) + .round() + .to_i128() + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal128Type::validate_decimal_precision(v, precision).map(|_| v) + }) + })? 
+ .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } +} + +pub(crate) fn cast_floating_point_to_decimal256( + array: &PrimitiveArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + if cast_options.safe { + array + .unary_opt::<_, Decimal256Type>(|v| { + i256::from_f64((v.as_() * mul).round()) + .filter(|v| Decimal256Type::is_valid_decimal_precision(*v, precision)) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal256Type, _>(|v| { + i256::from_f64((v.as_() * mul).round()) + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal256Type::validate_decimal_precision(v, precision).map(|_| v) + }) + })? + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } +} + +pub(crate) fn cast_decimal_to_integer( + array: &dyn Array, + base: D::Native, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: ArrowPrimitiveType, + ::Native: NumCast, + D: DecimalType + ArrowPrimitiveType, + ::Native: ArrowNativeTypeOp + ToPrimitive, +{ + let array = array.as_primitive::(); + + let div: D::Native = base.pow_checked(scale as u32).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast to {:?}. 
The scale {} causes overflow.", + D::PREFIX, + scale, + )) + })?; + + let mut value_builder = PrimitiveBuilder::::with_capacity(array.len()); + + if cast_options.safe { + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null(); + } else { + let v = array + .value(i) + .div_checked(div) + .ok() + .and_then(::from::); + + value_builder.append_option(v); + } + } + } else { + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null(); + } else { + let v = array.value(i).div_checked(div)?; + + let value = ::from::(v).ok_or_else(|| { + ArrowError::CastError(format!( + "value of {:?} is out of range {}", + v, + T::DATA_TYPE + )) + })?; + + value_builder.append_value(value); + } + } + } + Ok(Arc::new(value_builder.finish())) +} + +// Cast the decimal array to floating-point array +pub(crate) fn cast_decimal_to_float( + array: &dyn Array, + op: F, +) -> Result +where + F: Fn(D::Native) -> T::Native, +{ + let array = array.as_primitive::(); + let array = array.unary::<_, T>(op); + Ok(Arc::new(array)) +} diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs new file mode 100644 index 000000000000..ec0ab346f997 --- /dev/null +++ b/arrow-cast/src/cast/dictionary.rs @@ -0,0 +1,452 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; + +/// Attempts to cast an `ArrayDictionary` with index type K into +/// `to_type` for supported types. +/// +/// K is the key type +pub(crate) fn dictionary_cast( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + + match to_type { + Dictionary(to_index_type, to_value_type) => { + let dict_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), + ) + })?; + + let keys_array: ArrayRef = + Arc::new(PrimitiveArray::::from(dict_array.keys().to_data())); + let values_array = dict_array.values(); + let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; + let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > keys_array.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - keys_array.null_count(), + keys_array.data_type(), + to_index_type + ))); + } + + let data = cast_keys.into_data(); + let builder = data + .into_builder() + .data_type(to_type.clone()) + .child_data(vec![cast_values.into_data()]); + + // Safety + // Cast keys are still valid + let data = unsafe { builder.build_unchecked() }; + + // create the appropriate array type + let new_array: ArrayRef = match **to_index_type { + Int8 => Arc::new(DictionaryArray::::from(data)), + Int16 => Arc::new(DictionaryArray::::from(data)), + Int32 => Arc::new(DictionaryArray::::from(data)), + Int64 => Arc::new(DictionaryArray::::from(data)), + UInt8 => Arc::new(DictionaryArray::::from(data)), + UInt16 => 
Arc::new(DictionaryArray::::from(data)), + UInt32 => Arc::new(DictionaryArray::::from(data)), + UInt64 => Arc::new(DictionaryArray::::from(data)), + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported type {to_index_type:?} for dictionary index" + ))); + } + }; + + Ok(new_array) + } + Utf8View => { + // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. + // we handle it here to avoid the copy. + let dict_array = array + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast Utf8View to StringArray of expected type" + .to_string(), + ) + })?; + + let string_view = view_from_dict_values::>( + dict_array.values(), + dict_array.keys(), + )?; + Ok(Arc::new(string_view)) + } + BinaryView => { + // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. + // we handle it here to avoid the copy. + let dict_array = array + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast BinaryView to BinaryArray of expected type" + .to_string(), + ) + })?; + + let binary_view = view_from_dict_values::( + dict_array.values(), + dict_array.keys(), + )?; + Ok(Arc::new(binary_view)) + } + _ => unpack_dictionary::(array, to_type, cast_options), + } +} + +fn view_from_dict_values( + array: &GenericByteArray, + keys: &PrimitiveArray, +) -> Result, ArrowError> { + let value_buffer = array.values(); + let value_offsets = array.value_offsets(); + let mut builder = GenericByteViewBuilder::::with_capacity(keys.len()); + builder.append_block(value_buffer.clone()); + for i in keys.iter() { + match i { + Some(v) => { + let idx = v.to_usize().ok_or_else(|| { + ArrowError::ComputeError("Invalid dictionary index".to_string()) + })?; + + // Safety + // (1) The index is within bounds as they are offsets + // (2) The append_view is safe + unsafe { + 
let offset = value_offsets.get_unchecked(idx).as_usize(); + let end = value_offsets.get_unchecked(idx + 1).as_usize(); + let length = end - offset; + builder.append_view_unchecked(0, offset as u32, length as u32) + } + } + None => { + builder.append_null(); + } + } + } + Ok(builder.finish()) +} + +// Unpack a dictionary where the keys are of type into a flattened array of type to_type +pub(crate) fn unpack_dictionary( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let dict_array = array.as_dictionary::(); + let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?; + take(cast_dict_values.as_ref(), dict_array.keys(), None) +} + +/// Pack a data type into a dictionary array passing the values through a primitive array +pub(crate) fn pack_array_to_dictionary_via_primitive( + array: &dyn Array, + primitive_type: DataType, + dict_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let primitive = cast_with_options(array, &primitive_type, cast_options)?; + let dict = cast_with_options( + primitive.as_ref(), + &DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(primitive_type)), + cast_options, + )?; + cast_with_options( + dict.as_ref(), + &DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(dict_value_type.clone())), + cast_options, + ) +} + +/// Attempts to encode an array into an `ArrayDictionary` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +pub(crate) fn cast_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + + match *dict_value_type { + Int8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int64 => pack_numeric_to_dictionary::(array, 
dict_value_type, cast_options), + UInt8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Decimal128(p, s) => { + let dict = pack_numeric_to_dictionary::( + array, + dict_value_type, + cast_options, + )?; + let dict = dict + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dict to Decimal128Array".to_string(), + ) + })?; + let value = dict.values().clone(); + // Set correct precision/scale + let value = value.with_precision_and_scale(p, s)?; + Ok(Arc::new(DictionaryArray::::try_new( + dict.keys().clone(), + Arc::new(value), + )?)) + } + Decimal256(p, s) => { + let dict = pack_numeric_to_dictionary::( + array, + dict_value_type, + cast_options, + )?; + let dict = dict + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dict to Decimal256Array".to_string(), + ) + })?; + let value = dict.values().clone(); + // Set correct precision/scale + let value = value.with_precision_and_scale(p, s)?; + Ok(Arc::new(DictionaryArray::::try_new( + dict.keys().clone(), + Arc::new(value), + )?)) + } + Float16 => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Float32 => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Float64 => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Date32 => pack_array_to_dictionary_via_primitive::( + array, + DataType::Int32, + dict_value_type, + cast_options, + ), + Date64 => pack_array_to_dictionary_via_primitive::( + array, + DataType::Int64, + dict_value_type, + cast_options, + ), + Time32(_) => pack_array_to_dictionary_via_primitive::( + array, + DataType::Int32, + 
dict_value_type, + cast_options, + ), + Time64(_) => pack_array_to_dictionary_via_primitive::( + array, + DataType::Int64, + dict_value_type, + cast_options, + ), + Timestamp(_, _) => pack_array_to_dictionary_via_primitive::( + array, + DataType::Int64, + dict_value_type, + cast_options, + ), + Utf8 => { + // If the input is a view type, we can avoid casting (thus copying) the data + if array.data_type() == &DataType::Utf8View { + return string_view_to_dictionary::(array); + } + pack_byte_to_dictionary::>(array, cast_options) + } + LargeUtf8 => { + // If the input is a view type, we can avoid casting (thus copying) the data + if array.data_type() == &DataType::Utf8View { + return string_view_to_dictionary::(array); + } + pack_byte_to_dictionary::>(array, cast_options) + } + Binary => { + // If the input is a view type, we can avoid casting (thus copying) the data + if array.data_type() == &DataType::BinaryView { + return binary_view_to_dictionary::(array); + } + pack_byte_to_dictionary::>(array, cast_options) + } + LargeBinary => { + // If the input is a view type, we can avoid casting (thus copying) the data + if array.data_type() == &DataType::BinaryView { + return binary_view_to_dictionary::(array); + } + pack_byte_to_dictionary::>(array, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported output type for dictionary packing: {dict_value_type:?}" + ))), + } +} + +// Packs the data from the primitive array of type to a +// DictionaryArray with keys of type K and values of value_type V +pub(crate) fn pack_numeric_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, + V: ArrowPrimitiveType, +{ + // attempt to cast the source array values to the target value type (the dictionary values type) + let cast_values = cast_with_options(array, dict_value_type, cast_options)?; + let values = cast_values.as_primitive::(); + + let mut b = 
PrimitiveDictionaryBuilder::::with_capacity(values.len(), values.len()); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null(); + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + +pub(crate) fn string_view_to_dictionary( + array: &dyn Array, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let mut b = GenericByteDictionaryBuilder::>::with_capacity( + array.len(), + 1024, + 1024, + ); + let string_view = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ComputeError("Internal Error: Cannot cast to StringViewArray".to_string()) + })?; + for v in string_view.iter() { + match v { + Some(v) => { + b.append(v)?; + } + None => { + b.append_null(); + } + } + } + + Ok(Arc::new(b.finish())) +} + +pub(crate) fn binary_view_to_dictionary( + array: &dyn Array, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let mut b = GenericByteDictionaryBuilder::>::with_capacity( + array.len(), + 1024, + 1024, + ); + let binary_view = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ComputeError("Internal Error: Cannot cast to BinaryViewArray".to_string()) + })?; + for v in binary_view.iter() { + match v { + Some(v) => { + b.append(v)?; + } + None => { + b.append_null(); + } + } + } + + Ok(Arc::new(b.finish())) +} + +// Packs the data as a GenericByteDictionaryBuilder, if possible, with the +// key types of K +pub(crate) fn pack_byte_to_dictionary( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + let cast_values = cast_with_options(array, &T::DATA_TYPE, cast_options)?; + let values = cast_values + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError("Internal Error: Cannot cast to GenericByteArray".to_string()) + })?; + let mut b = GenericByteDictionaryBuilder::::with_capacity(values.len(), 1024, 1024); + + // copy each element one at a time + for i in 
0..values.len() { + if values.is_null(i) { + b.append_null(); + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} diff --git a/arrow-cast/src/cast/list.rs b/arrow-cast/src/cast/list.rs new file mode 100644 index 000000000000..ec7a5c57d504 --- /dev/null +++ b/arrow-cast/src/cast/list.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; + +/// Helper function that takes a primitive array and casts to a (generic) list array. +pub(crate) fn cast_values_to_list( + array: &dyn Array, + to: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let values = cast_with_options(array, to.data_type(), cast_options)?; + let offsets = OffsetBuffer::from_lengths(std::iter::repeat(1).take(values.len())); + let list = GenericListArray::::new(to.clone(), offsets, values, None); + Ok(Arc::new(list)) +} + +/// Helper function that takes a primitive array and casts to a fixed size list array. 
+pub(crate) fn cast_values_to_fixed_size_list( + array: &dyn Array, + to: &FieldRef, + size: i32, + cast_options: &CastOptions, +) -> Result { + let values = cast_with_options(array, to.data_type(), cast_options)?; + let list = FixedSizeListArray::new(to.clone(), size, values, None); + Ok(Arc::new(list)) +} + +pub(crate) fn cast_single_element_fixed_size_list_to_values( + array: &dyn Array, + to: &DataType, + cast_options: &CastOptions, +) -> Result { + let values = array.as_fixed_size_list().values(); + cast_with_options(values, to, cast_options) +} + +pub(crate) fn cast_fixed_size_list_to_list( + array: &dyn Array, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let fixed_size_list: &FixedSizeListArray = array.as_fixed_size_list(); + let list: GenericListArray = fixed_size_list.clone().into(); + Ok(Arc::new(list)) +} + +pub(crate) fn cast_list_to_fixed_size_list( + array: &GenericListArray, + field: &FieldRef, + size: i32, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let cap = array.len() * size as usize; + + // Whether the resulting array may contain null lists + let nullable = cast_options.safe || array.null_count() != 0; + let mut nulls = nullable.then(|| { + let mut buffer = BooleanBufferBuilder::new(array.len()); + match array.nulls() { + Some(n) => buffer.append_buffer(n.inner()), + None => buffer.append_n(array.len(), true), + } + buffer + }); + + // Nulls in FixedSizeListArray take up space and so we must pad the values + let values = array.values().to_data(); + let mut mutable = MutableArrayData::new(vec![&values], nullable, cap); + // The end position in values of the last incorrectly-sized list slice + let mut last_pos = 0; + for (idx, w) in array.offsets().windows(2).enumerate() { + let start_pos = w[0].as_usize(); + let end_pos = w[1].as_usize(); + let len = end_pos - start_pos; + + if len != size as usize { + if cast_options.safe || array.is_null(idx) { + if last_pos != start_pos { + // Extend with 
valid slices + mutable.extend(0, last_pos, start_pos); + } + // Pad this slice with nulls + mutable.extend_nulls(size as _); + nulls.as_mut().unwrap().set_bit(idx, false); + // Set last_pos to the end of this slice's values + last_pos = end_pos + } else { + return Err(ArrowError::CastError(format!( + "Cannot cast to FixedSizeList({size}): value at index {idx} has length {len}", + ))); + } + } + } + + let values = match last_pos { + 0 => array.values().slice(0, cap), // All slices were the correct length + _ => { + if mutable.len() != cap { + // Remaining slices were all correct length + let remaining = cap - mutable.len(); + mutable.extend(0, last_pos, last_pos + remaining) + } + make_array(mutable.freeze()) + } + }; + + // Cast the inner values if necessary + let values = cast_with_options(values.as_ref(), field.data_type(), cast_options)?; + + // Construct the FixedSizeListArray + let nulls = nulls.map(|mut x| x.finish().into()); + let array = FixedSizeListArray::new(field.clone(), size, values, nulls); + Ok(Arc::new(array)) +} + +/// Helper function that takes an Generic list container and casts the inner datatype. 
+pub(crate) fn cast_list_values( + array: &dyn Array, + to: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list = array.as_list::(); + let values = cast_with_options(list.values(), to.data_type(), cast_options)?; + Ok(Arc::new(GenericListArray::::new( + to.clone(), + list.offsets().clone(), + values, + list.nulls().cloned(), + ))) +} + +/// Cast the container type of List/Largelist array along with the inner datatype +pub(crate) fn cast_list( + array: &dyn Array, + field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list = array.as_list::(); + let values = list.values(); + let offsets = list.offsets(); + let nulls = list.nulls().cloned(); + + if !O::IS_LARGE && values.len() > i32::MAX as usize { + return Err(ArrowError::ComputeError( + "LargeList too large to cast to List".into(), + )); + } + + // Recursively cast values + let values = cast_with_options(values, field.data_type(), cast_options)?; + let offsets: Vec<_> = offsets.iter().map(|x| O::usize_as(x.as_usize())).collect(); + + // Safety: valid offsets and checked for overflow + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Ok(Arc::new(GenericListArray::::new( + field.clone(), + offsets, + values, + nulls, + ))) +} diff --git a/arrow-cast/src/cast/map.rs b/arrow-cast/src/cast/map.rs new file mode 100644 index 000000000000..d62a9519b7b3 --- /dev/null +++ b/arrow-cast/src/cast/map.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; + +/// Helper function that takes a map container and casts the inner datatype. +pub(crate) fn cast_map_values( + from: &MapArray, + to_data_type: &DataType, + cast_options: &CastOptions, + to_ordered: bool, +) -> Result { + let entries_field = if let DataType::Map(entries_field, _) = to_data_type { + entries_field + } else { + return Err(ArrowError::CastError( + "Internal Error: to_data_type is not a map type.".to_string(), + )); + }; + + let key_field = key_field(entries_field).ok_or(ArrowError::CastError( + "map is missing key field".to_string(), + ))?; + let value_field = value_field(entries_field).ok_or(ArrowError::CastError( + "map is missing value field".to_string(), + ))?; + + let key_array = cast_with_options(from.keys(), key_field.data_type(), cast_options)?; + let value_array = cast_with_options(from.values(), value_field.data_type(), cast_options)?; + + Ok(Arc::new(MapArray::new( + entries_field.clone(), + from.offsets().clone(), + StructArray::new( + Fields::from(vec![key_field, value_field]), + vec![key_array, value_array], + from.entries().nulls().cloned(), + ), + from.nulls().cloned(), + to_ordered, + ))) +} + +/// Gets the key field from the entries of a map. For all other types returns None. +pub(crate) fn key_field(entries_field: &FieldRef) -> Option { + if let DataType::Struct(fields) = entries_field.data_type() { + fields.first().cloned() + } else { + None + } +} + +/// Gets the value field from the entries of a map. For all other types returns None. 
+pub(crate) fn value_field(entries_field: &FieldRef) -> Option { + if let DataType::Struct(fields) = entries_field.data_type() { + fields.get(1).cloned() + } else { + None + } +} diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs new file mode 100644 index 000000000000..e3fad3da19f8 --- /dev/null +++ b/arrow-cast/src/cast/mod.rs @@ -0,0 +1,9693 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Cast kernels to convert [`ArrayRef`] between supported datatypes. +//! +//! See [`cast_with_options`] for more information on specific conversions. +//! +//! Example: +//! +//! ``` +//! # use arrow_array::*; +//! # use arrow_cast::cast; +//! # use arrow_schema::DataType; +//! # use std::sync::Arc; +//! # use arrow_array::types::Float64Type; +//! # use arrow_array::cast::AsArray; +//! // int32 to float64 +//! let a = Int32Array::from(vec![5, 6, 7]); +//! let b = cast(&a, &DataType::Float64).unwrap(); +//! let c = b.as_primitive::(); +//! assert_eq!(5.0, c.value(0)); +//! assert_eq!(6.0, c.value(1)); +//! assert_eq!(7.0, c.value(2)); +//! 
``` + +mod decimal; +mod dictionary; +mod list; +mod map; +mod string; +use crate::cast::decimal::*; +use crate::cast::dictionary::*; +use crate::cast::list::*; +use crate::cast::map::*; +use crate::cast::string::*; + +use arrow_buffer::IntervalMonthDayNano; +use arrow_data::ByteView; +use chrono::{NaiveTime, Offset, TimeZone, Utc}; +use std::cmp::Ordering; +use std::sync::Arc; + +use crate::display::{ArrayFormatter, FormatOptions}; +use crate::parse::{ + parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, + string_to_datetime, Parser, +}; +use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *}; +use arrow_buffer::{i256, ArrowNativeType, OffsetBuffer}; +use arrow_data::transform::MutableArrayData; +use arrow_data::ArrayData; +use arrow_schema::*; +use arrow_select::take::take; +use num::cast::AsPrimitive; +use num::{NumCast, ToPrimitive}; + +/// CastOptions provides a way to override the default cast behaviors +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct CastOptions<'a> { + /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) + pub safe: bool, + /// Formatting options when casting from temporal types to string + pub format_options: FormatOptions<'a>, +} + +impl<'a> Default for CastOptions<'a> { + fn default() -> Self { + Self { + safe: true, + format_options: FormatOptions::default(), + } + } +} + +/// Return true if a value of type `from_type` can be cast into a value of `to_type`. 
+/// +/// See [`cast_with_options`] for more information +pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + use self::DataType::*; + use self::IntervalUnit::*; + use self::TimeUnit::*; + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | BinaryView + | Utf8View + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + ) => true, + // Dictionary/List conditions should be put in front of others + (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => { + can_cast_types(from_value_type, to_value_type) + } + (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), + (List(list_from) | LargeList(list_from), List(list_to) | LargeList(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) => { + can_cast_types(list_from.data_type(), to_type) + } + (List(list_from) | LargeList(list_from), FixedSizeList(list_to, _)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (List(_), _) => false, + (FixedSizeList(list_from,_), List(list_to)) | + (FixedSizeList(list_from,_), LargeList(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (FixedSizeList(inner, size), FixedSizeList(inner_to, size_to)) if size == size_to => { + can_cast_types(inner.data_type(), inner_to.data_type()) + } + (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), + (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()), + (_, 
FixedSizeList(list_to,size)) if *size == 1 => { + can_cast_types(from_type, list_to.data_type())}, + (FixedSizeList(list_from,size), _) if *size == 1 => { + can_cast_types(list_from.data_type(), to_type)}, + (Map(from_entries,ordered_from), Map(to_entries, ordered_to)) if ordered_from == ordered_to => + match (key_field(from_entries), key_field(to_entries), value_field(from_entries), value_field(to_entries)) { + (Some(from_key), Some(to_key), Some(from_value), Some(to_value)) => + can_cast_types(from_key.data_type(), to_key.data_type()) && can_cast_types(from_value.data_type(), to_value.data_type()), + _ => false + }, + // cast one decimal type to another decimal type + (Decimal128(_, _), Decimal128(_, _)) => true, + (Decimal256(_, _), Decimal256(_, _)) => true, + (Decimal128(_, _), Decimal256(_, _)) => true, + (Decimal256(_, _), Decimal128(_, _)) => true, + // unsigned integer to decimal + (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) | + (UInt8 | UInt16 | UInt32 | UInt64, Decimal256(_, _)) | + // signed numeric to decimal + (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) | + (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) | + // decimal to unsigned numeric + (Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) | + // decimal to signed numeric + (Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true, + // decimal to Utf8 + (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true, + // Utf8 to decimal + (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, + (Struct(from_fields), Struct(to_fields)) => { + from_fields.len() == to_fields.len() && + from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { + // Assume that nullability between two structs are compatible, if not, + // cast kernel will return error. 
+ can_cast_types(f1.data_type(), f2.data_type()) + }) + } + (Struct(_), _) => false, + (_, Struct(_)) => false, + (_, Boolean) => { + DataType::is_integer(from_type) || + DataType::is_floating(from_type) + || from_type == &Utf8 + || from_type == &LargeUtf8 + } + (Boolean, _) => { + DataType::is_integer(to_type) || DataType::is_floating(to_type) || to_type == &Utf8 || to_type == &LargeUtf8 + } + + (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView) => true, + (LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView) => true, + (FixedSizeBinary(_), Binary | LargeBinary) => true, + ( + Utf8 | LargeUtf8 | Utf8View, + Binary + | LargeBinary + | Utf8 + | LargeUtf8 + | Date32 + | Date64 + | Time32(Second) + | Time32(Millisecond) + | Time64(Microsecond) + | Time64(Nanosecond) + | Timestamp(Second, _) + | Timestamp(Millisecond, _) + | Timestamp(Microsecond, _) + | Timestamp(Nanosecond, _) + | Interval(_) + | BinaryView, + ) => true, + (Utf8 | LargeUtf8, Utf8View) => true, + (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View ) => true, + (Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16, + (_, Utf8 | LargeUtf8) => from_type.is_primitive(), + + (_, Binary | LargeBinary) => from_type.is_integer(), + + // start numeric casts + ( + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, + ) => true, + // end numeric casts + + // temporal casts + (Int32, Date32 | Date64 | Time32(_)) => true, + (Date32, Int32 | Int64) => true, + (Time32(_), Int32) => true, + (Int64, Date64 | Date32 | Time64(_)) => true, + (Date64, Int64 | Int32) => true, + (Time64(_), Int64) => true, + (Date32 | Date64, Date32 | Date64) => true, + // time casts + (Time32(_), Time32(_)) => true, + (Time32(_), Time64(_)) => true, + (Time64(_), Time64(_)) => true, + (Time64(_), Time32(to_unit)) => { + 
matches!(to_unit, Second | Millisecond) + } + (Timestamp(_, _), _) if to_type.is_numeric() => true, + (_, Timestamp(_, _)) if from_type.is_numeric() => true, + (Date64, Timestamp(_, None)) => true, + (Date32, Timestamp(_, None)) => true, + ( + Timestamp(_, _), + Timestamp(_, _) + | Date32 + | Date64 + | Time32(Second) + | Time32(Millisecond) + | Time64(Microsecond) + | Time64(Nanosecond), + ) => true, + (_, Duration(_)) if from_type.is_numeric() => true, + (Duration(_), _) if to_type.is_numeric() => true, + (Duration(_), Duration(_)) => true, + (Interval(from_type), Int64) => { + match from_type { + YearMonth => true, + DayTime => true, + MonthDayNano => false, // Native type is i128 + } + } + (Int32, Interval(to_type)) => match to_type { + YearMonth => true, + DayTime => false, + MonthDayNano => false, + }, + (Duration(_), Interval(MonthDayNano)) => true, + (Interval(MonthDayNano), Duration(_)) => true, + (Interval(YearMonth), Interval(MonthDayNano)) => true, + (Interval(DayTime), Interval(MonthDayNano)) => true, + (_, _) => false, + } +} + +/// Cast `array` to the provided data type and return a new Array with type `to_type`, if possible. +/// +/// See [`cast_with_options`] for more information +pub fn cast(array: &dyn Array, to_type: &DataType) -> Result { + cast_with_options(array, to_type, &CastOptions::default()) +} + +fn cast_integer_to_decimal< + T: ArrowPrimitiveType, + D: DecimalType + ArrowPrimitiveType, + M, +>( + array: &PrimitiveArray, + precision: u8, + scale: i8, + base: M, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, + M: ArrowNativeTypeOp, +{ + let scale_factor = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). 
The scale causes overflow.", + D::PREFIX, + precision, + scale, + )) + })?; + + let array = if scale < 0 { + match cast_options.safe { + true => array.unary_opt::<_, D>(|v| { + v.as_() + .div_checked(scale_factor) + .ok() + .and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v)) + }), + false => array.try_unary::<_, D, _>(|v| { + v.as_() + .div_checked(scale_factor) + .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) + })?, + } + } else { + match cast_options.safe { + true => array.unary_opt::<_, D>(|v| { + v.as_() + .mul_checked(scale_factor) + .ok() + .and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v)) + }), + false => array.try_unary::<_, D, _>(|v| { + v.as_() + .mul_checked(scale_factor) + .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) + })?, + } + }; + + Ok(Arc::new(array.with_precision_and_scale(precision, scale)?)) +} + +/// Cast the array from interval year month to month day nano +fn cast_interval_year_month_to_interval_month_day_nano( + array: &dyn Array, + _cast_options: &CastOptions, +) -> Result { + let array = array.as_primitive::(); + + Ok(Arc::new(array.unary::<_, IntervalMonthDayNanoType>(|v| { + let months = IntervalYearMonthType::to_months(v); + IntervalMonthDayNanoType::make_value(months, 0, 0) + }))) +} + +/// Cast the array from interval day time to month day nano +fn cast_interval_day_time_to_interval_month_day_nano( + array: &dyn Array, + _cast_options: &CastOptions, +) -> Result { + let array = array.as_primitive::(); + let mul = 1_000_000; + + Ok(Arc::new(array.unary::<_, IntervalMonthDayNanoType>(|v| { + let (days, ms) = IntervalDayTimeType::to_parts(v); + IntervalMonthDayNanoType::make_value(0, days, ms as i64 * mul) + }))) +} + +/// Cast the array from interval to duration +fn cast_month_day_nano_to_duration>( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array.as_primitive::(); + let scale = match D::DATA_TYPE { + 
DataType::Duration(TimeUnit::Second) => 1_000_000_000, + DataType::Duration(TimeUnit::Millisecond) => 1_000_000, + DataType::Duration(TimeUnit::Microsecond) => 1_000, + DataType::Duration(TimeUnit::Nanosecond) => 1, + _ => unreachable!(), + }; + + if cast_options.safe { + let iter = array.iter().map(|v| { + v.and_then(|v| (v.days == 0 && v.months == 0).then_some(v.nanoseconds / scale)) + }); + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + })) + } else { + let vec = array + .iter() + .map(|v| { + v.map(|v| match v.days == 0 && v.months == 0 { + true => Ok((v.nanoseconds) / scale), + _ => Err(ArrowError::ComputeError( + "Cannot convert interval containing non-zero months or days to duration" + .to_string(), + )), + }) + .transpose() + }) + .collect::, _>>()?; + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + })) + } +} + +/// Cast the array from duration and interval +fn cast_duration_to_interval>( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast duration to DurationArray of expected type" + .to_string(), + ) + })?; + + let scale = match array.data_type() { + DataType::Duration(TimeUnit::Second) => 1_000_000_000, + DataType::Duration(TimeUnit::Millisecond) => 1_000_000, + DataType::Duration(TimeUnit::Microsecond) => 1_000, + DataType::Duration(TimeUnit::Nanosecond) => 1, + _ => unreachable!(), + }; + + if cast_options.safe { + let iter = array.iter().map(|v| { + v.and_then(|v| { + v.checked_mul(scale) + .map(|v| IntervalMonthDayNano::new(0, 0, v)) + }) + }); + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + })) + } else { + let vec = array + .iter() + .map(|v| { + v.map(|v| { + if let Ok(v) = v.mul_checked(scale) { + Ok(IntervalMonthDayNano::new(0, 0, v)) + } else { + Err(ArrowError::ComputeError(format!( + "Cannot cast to {:?}. 
Overflowing on {:?}", + IntervalMonthDayNanoType::DATA_TYPE, + v + ))) + } + }) + .transpose() + }) + .collect::, _>>()?; + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + })) + } +} + +/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`] +fn cast_reinterpret_arrays>( + array: &dyn Array, +) -> Result { + Ok(Arc::new(array.as_primitive::().reinterpret_cast::())) +} + +fn make_timestamp_array( + array: &PrimitiveArray, + unit: TimeUnit, + tz: Option>, +) -> ArrayRef { + match unit { + TimeUnit::Second => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + TimeUnit::Millisecond => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + TimeUnit::Microsecond => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + TimeUnit::Nanosecond => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + } +} + +fn make_duration_array(array: &PrimitiveArray, unit: TimeUnit) -> ArrayRef { + match unit { + TimeUnit::Second => Arc::new(array.reinterpret_cast::()), + TimeUnit::Millisecond => Arc::new(array.reinterpret_cast::()), + TimeUnit::Microsecond => Arc::new(array.reinterpret_cast::()), + TimeUnit::Nanosecond => Arc::new(array.reinterpret_cast::()), + } +} + +fn as_time_res_with_timezone( + v: i64, + tz: Option, +) -> Result { + let time = match tz { + Some(tz) => as_datetime_with_timezone::(v, tz).map(|d| d.time()), + None => as_datetime::(v).map(|d| d.time()), + }; + + time.ok_or_else(|| { + ArrowError::CastError(format!( + "Failed to create naive time with {} {}", + std::any::type_name::(), + v + )) + }) +} + +fn timestamp_to_date32( + array: &PrimitiveArray, +) -> Result { + let err = |x: i64| { + ArrowError::CastError(format!( + "Cannot convert {} {x} to datetime", + std::any::type_name::() + )) + }; + + let array: Date32Array = match array.timezone() { + Some(tz) => { + let tz: Tz = tz.parse()?; + array.try_unary(|x| { + 
as_datetime_with_timezone::(x, tz) + .ok_or_else(|| err(x)) + .map(|d| Date32Type::from_naive_date(d.date_naive())) + })? + } + None => array.try_unary(|x| { + as_datetime::(x) + .ok_or_else(|| err(x)) + .map(|d| Date32Type::from_naive_date(d.date())) + })?, + }; + Ok(Arc::new(array)) +} + +/// Try to cast `array` to `to_type` if possible. +/// +/// Returns a new Array with type `to_type` if possible. +/// +/// Accepts [`CastOptions`] to specify cast behavior. See also [`cast()`]. +/// +/// # Behavior +/// * `Boolean` to `Utf8`: `true` => '1', `false` => `0` +/// * `Utf8` to `Boolean`: `true`, `yes`, `on`, `1` => `true`, `false`, `no`, `off`, `0` => `false`, +/// short variants are accepted, other strings return null or error +/// * `Utf8` to Numeric: strings that can't be parsed to numbers return null, float strings +/// in integer casts return null +/// * Numeric to `Boolean`: 0 returns `false`, any other value returns `true` +/// * `List` to `List`: the underlying data type is cast +/// * `List` to `FixedSizeList`: the underlying data type is cast. If safe is true and a list element +/// has the wrong length it will be replaced with NULL, otherwise an error will be returned +/// * Primitive to `List`: a list array with 1 value per slot is created +/// * `Date32` and `Date64`: precision lost when going to higher interval +/// * `Time32 and `Time64`: precision lost when going to higher interval +/// * `Timestamp` and `Date{32|64}`: precision lost when going to higher interval +/// * Temporal to/from backing Primitive: zero-copy with data type change +/// * `Float32/Float64` to `Decimal(precision, scale)` rounds to the `scale` decimals +/// (i.e. casting `6.4999` to `Decimal(10, 1)` becomes `6.5`). 
+/// +/// Unsupported Casts (check with `can_cast_types` before calling): +/// * To or from `StructArray` +/// * `List` to `Primitive` +/// * `Interval` and `Duration` +/// +/// # Timestamps and Timezones +/// +/// Timestamps are stored with an optional timezone in Arrow. +/// +/// ## Casting timestamps to a timestamp without timezone / UTC +/// ``` +/// # use arrow_array::Int64Array; +/// # use arrow_array::types::TimestampSecondType; +/// # use arrow_cast::{cast, display}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_schema::{DataType, TimeUnit}; +/// // can use "UTC" if chrono-tz feature is enabled, here use offset based timezone +/// let data_type = DataType::Timestamp(TimeUnit::Second, None); +/// let a = Int64Array::from(vec![1_000_000_000, 2_000_000_000, 3_000_000_000]); +/// let b = cast(&a, &data_type).unwrap(); +/// let b = b.as_primitive::(); // downcast to result type +/// assert_eq!(2_000_000_000, b.value(1)); // values are the same as the type has no timezone +/// // use display to show them (note has no trailing Z) +/// assert_eq!("2033-05-18T03:33:20", display::array_value_to_string(&b, 1).unwrap()); +/// ``` +/// +/// ## Casting timestamps to a timestamp with timezone +/// +/// Similarly to the previous example, if you cast numeric values to a timestamp +/// with timezone, the cast kernel will not change the underlying values +/// but display and other functions will interpret them as being in the provided timezone. 
+/// +/// ``` +/// # use arrow_array::Int64Array; +/// # use arrow_array::types::TimestampSecondType; +/// # use arrow_cast::{cast, display}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_schema::{DataType, TimeUnit}; +/// // can use "Americas/New_York" if chrono-tz feature is enabled, here use offset based timezone +/// let data_type = DataType::Timestamp(TimeUnit::Second, Some("-05:00".into())); +/// let a = Int64Array::from(vec![1_000_000_000, 2_000_000_000, 3_000_000_000]); +/// let b = cast(&a, &data_type).unwrap(); +/// let b = b.as_primitive::(); // downcast to result type +/// assert_eq!(2_000_000_000, b.value(1)); // values are still the same +/// // displayed in the target timezone (note the offset -05:00) +/// assert_eq!("2033-05-17T22:33:20-05:00", display::array_value_to_string(&b, 1).unwrap()); +/// ``` +/// # Casting timestamps without timezone to timestamps with timezone +/// +/// When casting from a timestamp without timezone to a timestamp with +/// timezone, the cast kernel interprets the timestamp values as being in +/// the destination timezone and then adjusts the underlying value to UTC as required +/// +/// However, note that when casting from a timestamp with timezone BACK to a +/// timestamp without timezone the cast kernel does not adjust the values. +/// +/// Thus round trip casting a timestamp without timezone to a timestamp with +/// timezone and back to a timestamp without timezone results in different +/// values than the starting values. 
+/// +/// ``` +/// # use arrow_array::Int64Array; +/// # use arrow_array::types::{TimestampSecondType}; +/// # use arrow_cast::{cast, display}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_schema::{DataType, TimeUnit}; +/// let data_type = DataType::Timestamp(TimeUnit::Second, None); +/// let data_type_tz = DataType::Timestamp(TimeUnit::Second, Some("-05:00".into())); +/// let a = Int64Array::from(vec![1_000_000_000, 2_000_000_000, 3_000_000_000]); +/// let b = cast(&a, &data_type).unwrap(); // cast to timestamp without timezone +/// let b = b.as_primitive::(); // downcast to result type +/// assert_eq!(2_000_000_000, b.value(1)); // values are still the same +/// // displayed without a timezone (note lack of offset or Z) +/// assert_eq!("2033-05-18T03:33:20", display::array_value_to_string(&b, 1).unwrap()); +/// +/// // Convert timestamps without a timezone to timestamps with a timezone +/// let c = cast(&b, &data_type_tz).unwrap(); +/// let c = c.as_primitive::(); // downcast to result type +/// assert_eq!(2_000_018_000, c.value(1)); // value has been adjusted by offset +/// // displayed with the target timezone offset (-05:00) +/// assert_eq!("2033-05-18T03:33:20-05:00", display::array_value_to_string(&c, 1).unwrap()); +/// +/// // Convert from timestamp with timezone back to timestamp without timezone +/// let d = cast(&c, &data_type).unwrap(); +/// let d = d.as_primitive::(); // downcast to result type +/// assert_eq!(2_000_018_000, d.value(1)); // value has not been adjusted +/// // NOTE: the timestamp is adjusted (08:33:20 instead of 03:33:20 as in previous example) +/// assert_eq!("2033-05-18T08:33:20", display::array_value_to_string(&d, 1).unwrap()); +/// ``` +pub fn cast_with_options( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + let from_type = array.data_type(); + // clone array if types are the same + if from_type == to_type { + return Ok(make_array(array.to_data())); + } + match 
(from_type, to_type) { + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | BinaryView + | Utf8View + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + ) => Ok(new_null_array(to_type, array.len())), + (Dictionary(index_type, _), _) => match **index_type { + Int8 => dictionary_cast::(array, to_type, cast_options), + Int16 => dictionary_cast::(array, to_type, cast_options), + Int32 => dictionary_cast::(array, to_type, cast_options), + Int64 => dictionary_cast::(array, to_type, cast_options), + UInt8 => dictionary_cast::(array, to_type, cast_options), + UInt16 => dictionary_cast::(array, to_type, cast_options), + UInt32 => dictionary_cast::(array, to_type, cast_options), + UInt64 => dictionary_cast::(array, to_type, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from dictionary type {from_type:?} to {to_type:?} not supported", + ))), + }, + (_, Dictionary(index_type, value_type)) => match **index_type { + Int8 => cast_to_dictionary::(array, value_type, cast_options), + Int16 => cast_to_dictionary::(array, value_type, cast_options), + Int32 => cast_to_dictionary::(array, value_type, cast_options), + Int64 => cast_to_dictionary::(array, value_type, cast_options), + UInt8 => cast_to_dictionary::(array, value_type, cast_options), + UInt16 => cast_to_dictionary::(array, value_type, cast_options), + UInt32 => cast_to_dictionary::(array, value_type, cast_options), + UInt64 => cast_to_dictionary::(array, value_type, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from type {from_type:?} to dictionary type {to_type:?} not supported", + ))), + }, + (List(_), List(to)) => cast_list_values::(array, to, cast_options), + (LargeList(_), 
LargeList(to)) => cast_list_values::(array, to, cast_options), + (List(_), LargeList(list_to)) => cast_list::(array, list_to, cast_options), + (LargeList(_), List(list_to)) => cast_list::(array, list_to, cast_options), + (List(_), FixedSizeList(field, size)) => { + let array = array.as_list::(); + cast_list_to_fixed_size_list::(array, field, *size, cast_options) + } + (LargeList(_), FixedSizeList(field, size)) => { + let array = array.as_list::(); + cast_list_to_fixed_size_list::(array, field, *size, cast_options) + } + (List(_) | LargeList(_), _) => match to_type { + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + _ => Err(ArrowError::CastError( + "Cannot cast list to non-list data types".to_string(), + )), + }, + (FixedSizeList(list_from, size), List(list_to)) => { + if list_to.data_type() != list_from.data_type() { + // To transform inner type, can first cast to FSL with new inner type. + let fsl_to = DataType::FixedSizeList(list_to.clone(), *size); + let array = cast_with_options(array, &fsl_to, cast_options)?; + cast_fixed_size_list_to_list::(array.as_ref()) + } else { + cast_fixed_size_list_to_list::(array) + } + } + (FixedSizeList(list_from, size), LargeList(list_to)) => { + if list_to.data_type() != list_from.data_type() { + // To transform inner type, can first cast to FSL with new inner type. 
+ let fsl_to = DataType::FixedSizeList(list_to.clone(), *size); + let array = cast_with_options(array, &fsl_to, cast_options)?; + cast_fixed_size_list_to_list::(array.as_ref()) + } else { + cast_fixed_size_list_to_list::(array) + } + } + (FixedSizeList(_, size_from), FixedSizeList(list_to, size_to)) => { + if size_from != size_to { + return Err(ArrowError::CastError( + "cannot cast fixed-size-list to fixed-size-list with different size".into(), + )); + } + let array = array.as_any().downcast_ref::().unwrap(); + let values = cast_with_options(array.values(), list_to.data_type(), cast_options)?; + Ok(Arc::new(FixedSizeListArray::try_new( + list_to.clone(), + *size_from, + values, + array.nulls().cloned(), + )?)) + } + (_, List(ref to)) => cast_values_to_list::(array, to, cast_options), + (_, LargeList(ref to)) => cast_values_to_list::(array, to, cast_options), + (_, FixedSizeList(ref to, size)) if *size == 1 => { + cast_values_to_fixed_size_list(array, to, *size, cast_options) + } + (FixedSizeList(_, size), _) if *size == 1 => { + cast_single_element_fixed_size_list_to_values(array, to_type, cast_options) + } + (Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => { + cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned()) + } + (Decimal128(_, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal256(_, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(_, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal256(_, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(_, scale), _) if !to_type.is_temporal() => { + // cast decimal to other type + match to_type { + UInt8 => 
cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + UInt16 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + UInt32 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + UInt64 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int8 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int16 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int32 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int64 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Float32 => cast_decimal_to_float::(array, |x| { + (x as f64 / 10_f64.powi(*scale as i32)) as f32 + }), + Float64 => cast_decimal_to_float::(array, |x| { + x as f64 / 10_f64.powi(*scale as i32) + }), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (Decimal256(_, scale), _) if !to_type.is_temporal() => { + // cast decimal to other type + match to_type { + UInt8 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + UInt16 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + UInt32 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + UInt64 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Int8 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Int16 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Int32 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), 
+ *scale, + cast_options, + ), + Int64 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Float32 => cast_decimal_to_float::(array, |x| { + (x.to_f64().unwrap() / 10_f64.powi(*scale as i32)) as f32 + }), + Float64 => cast_decimal_to_float::(array, |x| { + x.to_f64().unwrap() / 10_f64.powi(*scale as i32) + }), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (_, Decimal128(precision, scale)) if !from_type.is_temporal() => { + // cast data to decimal + match from_type { + UInt8 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + UInt16 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + UInt32 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + UInt64 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Float32 => cast_floating_point_to_decimal128( + array.as_primitive::(), + *precision, + *scale, + 
cast_options, + ), + Float64 => cast_floating_point_to_decimal128( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Utf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + LargeUtf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (_, Decimal256(precision, scale)) if !from_type.is_temporal() => { + // cast data to decimal + match from_type { + UInt8 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + UInt16 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + UInt32 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + UInt64 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Float32 => cast_floating_point_to_decimal256( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), 
+ Float64 => cast_floating_point_to_decimal256( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Utf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + LargeUtf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (Struct(_), Struct(to_fields)) => { + let array = array.as_struct(); + let fields = array + .columns() + .iter() + .zip(to_fields.iter()) + .map(|(l, field)| cast_with_options(l, field.data_type(), cast_options)) + .collect::, ArrowError>>()?; + let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?; + Ok(Arc::new(array) as ArrayRef) + } + (Struct(_), _) => Err(ArrowError::CastError( + "Cannot cast from struct to other types except struct".to_string(), + )), + (_, Struct(_)) => Err(ArrowError::CastError( + "Cannot cast to struct from other types except struct".to_string(), + )), + (_, Boolean) => match from_type { + UInt8 => cast_numeric_to_bool::(array), + UInt16 => cast_numeric_to_bool::(array), + UInt32 => cast_numeric_to_bool::(array), + UInt64 => cast_numeric_to_bool::(array), + Int8 => cast_numeric_to_bool::(array), + Int16 => cast_numeric_to_bool::(array), + Int32 => cast_numeric_to_bool::(array), + Int64 => cast_numeric_to_bool::(array), + Float16 => cast_numeric_to_bool::(array), + Float32 => cast_numeric_to_bool::(array), + Float64 => cast_numeric_to_bool::(array), + Utf8 => cast_utf8_to_boolean::(array, cast_options), + LargeUtf8 => cast_utf8_to_boolean::(array, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Boolean, _) => match to_type { + UInt8 => cast_bool_to_numeric::(array, cast_options), + UInt16 => cast_bool_to_numeric::(array, cast_options), + UInt32 => 
cast_bool_to_numeric::(array, cast_options), + UInt64 => cast_bool_to_numeric::(array, cast_options), + Int8 => cast_bool_to_numeric::(array, cast_options), + Int16 => cast_bool_to_numeric::(array, cast_options), + Int32 => cast_bool_to_numeric::(array, cast_options), + Int64 => cast_bool_to_numeric::(array, cast_options), + Float16 => cast_bool_to_numeric::(array, cast_options), + Float32 => cast_bool_to_numeric::(array, cast_options), + Float64 => cast_bool_to_numeric::(array, cast_options), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Utf8, _) => match to_type { + UInt8 => parse_string::(array, cast_options), + UInt16 => parse_string::(array, cast_options), + UInt32 => parse_string::(array, cast_options), + UInt64 => parse_string::(array, cast_options), + Int8 => parse_string::(array, cast_options), + Int16 => parse_string::(array, cast_options), + Int32 => parse_string::(array, cast_options), + Int64 => parse_string::(array, cast_options), + Float32 => parse_string::(array, cast_options), + Float64 => parse_string::(array, cast_options), + Date32 => parse_string::(array, cast_options), + Date64 => parse_string::(array, cast_options), + Binary => Ok(Arc::new(BinaryArray::from( + array.as_string::().clone(), + ))), + LargeBinary => { + let binary = BinaryArray::from(array.as_string::().clone()); + cast_byte_container::(&binary) + } + Utf8View => Ok(Arc::new(StringViewArray::from(array.as_string::()))), + BinaryView => Ok(Arc::new( + StringViewArray::from(array.as_string::()).to_binary_view(), + )), + LargeUtf8 => cast_byte_container::(array), + Time32(TimeUnit::Second) => parse_string::(array, cast_options), + Time32(TimeUnit::Millisecond) => { + parse_string::(array, cast_options) + } + Time64(TimeUnit::Microsecond) => { + parse_string::(array, cast_options) + } + 
Time64(TimeUnit::Nanosecond) => { + parse_string::(array, cast_options) + } + Timestamp(TimeUnit::Second, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Millisecond, to_tz) => cast_string_to_timestamp::< + i32, + TimestampMillisecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Microsecond, to_tz) => cast_string_to_timestamp::< + i32, + TimestampMicrosecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Nanosecond, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Interval(IntervalUnit::YearMonth) => { + cast_string_to_year_month_interval::(array, cast_options) + } + Interval(IntervalUnit::DayTime) => { + cast_string_to_day_time_interval::(array, cast_options) + } + Interval(IntervalUnit::MonthDayNano) => { + cast_string_to_month_day_nano_interval::(array, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Utf8View, _) => match to_type { + UInt8 => parse_string_view::(array, cast_options), + UInt16 => parse_string_view::(array, cast_options), + UInt32 => parse_string_view::(array, cast_options), + UInt64 => parse_string_view::(array, cast_options), + Int8 => parse_string_view::(array, cast_options), + Int16 => parse_string_view::(array, cast_options), + Int32 => parse_string_view::(array, cast_options), + Int64 => parse_string_view::(array, cast_options), + Float32 => parse_string_view::(array, cast_options), + Float64 => parse_string_view::(array, cast_options), + Date32 => parse_string_view::(array, cast_options), + Date64 => parse_string_view::(array, cast_options), + Binary => cast_view_to_byte::>(array), + LargeBinary => cast_view_to_byte::>(array), + BinaryView => Ok(Arc::new(array.as_string_view().clone().to_binary_view())), + Utf8 => cast_view_to_byte::>(array), + LargeUtf8 => cast_view_to_byte::>(array), + Time32(TimeUnit::Second) => parse_string_view::(array, 
cast_options), + Time32(TimeUnit::Millisecond) => { + parse_string_view::(array, cast_options) + } + Time64(TimeUnit::Microsecond) => { + parse_string_view::(array, cast_options) + } + Time64(TimeUnit::Nanosecond) => { + parse_string_view::(array, cast_options) + } + Timestamp(TimeUnit::Second, to_tz) => { + cast_view_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Millisecond, to_tz) => { + cast_view_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Microsecond, to_tz) => { + cast_view_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Nanosecond, to_tz) => { + cast_view_to_timestamp::(array, to_tz, cast_options) + } + Interval(IntervalUnit::YearMonth) => { + cast_view_to_year_month_interval(array, cast_options) + } + Interval(IntervalUnit::DayTime) => cast_view_to_day_time_interval(array, cast_options), + Interval(IntervalUnit::MonthDayNano) => { + cast_view_to_month_day_nano_interval(array, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (LargeUtf8, _) => match to_type { + UInt8 => parse_string::(array, cast_options), + UInt16 => parse_string::(array, cast_options), + UInt32 => parse_string::(array, cast_options), + UInt64 => parse_string::(array, cast_options), + Int8 => parse_string::(array, cast_options), + Int16 => parse_string::(array, cast_options), + Int32 => parse_string::(array, cast_options), + Int64 => parse_string::(array, cast_options), + Float32 => parse_string::(array, cast_options), + Float64 => parse_string::(array, cast_options), + Date32 => parse_string::(array, cast_options), + Date64 => parse_string::(array, cast_options), + Utf8 => cast_byte_container::(array), + Binary => { + let large_binary = LargeBinaryArray::from(array.as_string::().clone()); + cast_byte_container::(&large_binary) + } + LargeBinary => Ok(Arc::new(LargeBinaryArray::from( + array.as_string::().clone(), + ))), + Utf8View => 
Ok(Arc::new(StringViewArray::from(array.as_string::()))), + BinaryView => Ok(Arc::new(BinaryViewArray::from( + array + .as_string::() + .into_iter() + .map(|x| x.map(|x| x.as_bytes())) + .collect::>(), + ))), + Time32(TimeUnit::Second) => parse_string::(array, cast_options), + Time32(TimeUnit::Millisecond) => { + parse_string::(array, cast_options) + } + Time64(TimeUnit::Microsecond) => { + parse_string::(array, cast_options) + } + Time64(TimeUnit::Nanosecond) => { + parse_string::(array, cast_options) + } + Timestamp(TimeUnit::Second, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Millisecond, to_tz) => cast_string_to_timestamp::< + i64, + TimestampMillisecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Microsecond, to_tz) => cast_string_to_timestamp::< + i64, + TimestampMicrosecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Nanosecond, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Interval(IntervalUnit::YearMonth) => { + cast_string_to_year_month_interval::(array, cast_options) + } + Interval(IntervalUnit::DayTime) => { + cast_string_to_day_time_interval::(array, cast_options) + } + Interval(IntervalUnit::MonthDayNano) => { + cast_string_to_month_day_nano_interval::(array, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Binary, _) => match to_type { + Utf8 => cast_binary_to_string::(array, cast_options), + LargeUtf8 => { + let array = cast_binary_to_string::(array, cast_options)?; + cast_byte_container::(array.as_ref()) + } + LargeBinary => cast_byte_container::(array), + FixedSizeBinary(size) => { + cast_binary_to_fixed_size_binary::(array, *size, cast_options) + } + BinaryView => Ok(Arc::new(BinaryViewArray::from(array.as_binary::()))), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (LargeBinary, 
_) => match to_type { + Utf8 => { + let array = cast_binary_to_string::(array, cast_options)?; + cast_byte_container::(array.as_ref()) + } + LargeUtf8 => cast_binary_to_string::(array, cast_options), + Binary => cast_byte_container::(array), + FixedSizeBinary(size) => { + cast_binary_to_fixed_size_binary::(array, *size, cast_options) + } + BinaryView => Ok(Arc::new(BinaryViewArray::from(array.as_binary::()))), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (FixedSizeBinary(size), _) => match to_type { + Binary => cast_fixed_size_binary_to_binary::(array, *size), + LargeBinary => cast_fixed_size_binary_to_binary::(array, *size), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (BinaryView, Binary) => cast_view_to_byte::>(array), + (BinaryView, LargeBinary) => { + cast_view_to_byte::>(array) + } + (BinaryView, Utf8) => { + let binary_arr = cast_view_to_byte::>(array)?; + cast_binary_to_string::(&binary_arr, cast_options) + } + (BinaryView, LargeUtf8) => { + let binary_arr = cast_view_to_byte::>(array)?; + cast_binary_to_string::(&binary_arr, cast_options) + } + (BinaryView, Utf8View) => { + Ok(Arc::new(array.as_binary_view().clone().to_string_view()?) 
as ArrayRef) + } + (BinaryView, _) => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + (from_type, LargeUtf8) if from_type.is_primitive() => { + value_to_string::(array, cast_options) + } + (from_type, Utf8) if from_type.is_primitive() => { + value_to_string::(array, cast_options) + } + (from_type, Binary) if from_type.is_integer() => match from_type { + UInt8 => cast_numeric_to_binary::(array), + UInt16 => cast_numeric_to_binary::(array), + UInt32 => cast_numeric_to_binary::(array), + UInt64 => cast_numeric_to_binary::(array), + Int8 => cast_numeric_to_binary::(array), + Int16 => cast_numeric_to_binary::(array), + Int32 => cast_numeric_to_binary::(array), + Int64 => cast_numeric_to_binary::(array), + _ => unreachable!(), + }, + (from_type, LargeBinary) if from_type.is_integer() => match from_type { + UInt8 => cast_numeric_to_binary::(array), + UInt16 => cast_numeric_to_binary::(array), + UInt32 => cast_numeric_to_binary::(array), + UInt64 => cast_numeric_to_binary::(array), + Int8 => cast_numeric_to_binary::(array), + Int16 => cast_numeric_to_binary::(array), + Int32 => cast_numeric_to_binary::(array), + Int64 => cast_numeric_to_binary::(array), + _ => unreachable!(), + }, + // start numeric casts + (UInt8, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt8, UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt8, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt16, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt16, 
UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt16, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt32, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt32, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt32, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt64, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt64, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt64, UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int8, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int8, UInt16) => 
cast_numeric_arrays::(array, cast_options), + (Int8, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int8, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int8, Int16) => cast_numeric_arrays::(array, cast_options), + (Int8, Int32) => cast_numeric_arrays::(array, cast_options), + (Int8, Int64) => cast_numeric_arrays::(array, cast_options), + (Int8, Float16) => cast_numeric_arrays::(array, cast_options), + (Int8, Float32) => cast_numeric_arrays::(array, cast_options), + (Int8, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int16, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int16, Int8) => cast_numeric_arrays::(array, cast_options), + (Int16, Int32) => cast_numeric_arrays::(array, cast_options), + (Int16, Int64) => cast_numeric_arrays::(array, cast_options), + (Int16, Float16) => cast_numeric_arrays::(array, cast_options), + (Int16, Float32) => cast_numeric_arrays::(array, cast_options), + (Int16, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int32, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int32, Int8) => cast_numeric_arrays::(array, cast_options), + (Int32, Int16) => cast_numeric_arrays::(array, cast_options), + (Int32, Int64) => cast_numeric_arrays::(array, cast_options), + (Int32, Float16) => cast_numeric_arrays::(array, cast_options), + (Int32, Float32) => cast_numeric_arrays::(array, cast_options), + (Int32, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int64, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int64, 
UInt32) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int64, Int8) => cast_numeric_arrays::(array, cast_options), + (Int64, Int16) => cast_numeric_arrays::(array, cast_options), + (Int64, Int32) => cast_numeric_arrays::(array, cast_options), + (Int64, Float16) => cast_numeric_arrays::(array, cast_options), + (Int64, Float32) => cast_numeric_arrays::(array, cast_options), + (Int64, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float16, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float16, Int8) => cast_numeric_arrays::(array, cast_options), + (Float16, Int16) => cast_numeric_arrays::(array, cast_options), + (Float16, Int32) => cast_numeric_arrays::(array, cast_options), + (Float16, Int64) => cast_numeric_arrays::(array, cast_options), + (Float16, Float32) => cast_numeric_arrays::(array, cast_options), + (Float16, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float32, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float32, Int8) => cast_numeric_arrays::(array, cast_options), + (Float32, Int16) => cast_numeric_arrays::(array, cast_options), + (Float32, Int32) => cast_numeric_arrays::(array, cast_options), + (Float32, Int64) => cast_numeric_arrays::(array, cast_options), + (Float32, Float16) => cast_numeric_arrays::(array, cast_options), + (Float32, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float64, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt32) 
=> cast_numeric_arrays::(array, cast_options), + (Float64, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float64, Int8) => cast_numeric_arrays::(array, cast_options), + (Float64, Int16) => cast_numeric_arrays::(array, cast_options), + (Float64, Int32) => cast_numeric_arrays::(array, cast_options), + (Float64, Int64) => cast_numeric_arrays::(array, cast_options), + (Float64, Float16) => cast_numeric_arrays::(array, cast_options), + (Float64, Float32) => cast_numeric_arrays::(array, cast_options), + // end numeric casts + + // temporal casts + (Int32, Date32) => cast_reinterpret_arrays::(array), + (Int32, Date64) => cast_with_options( + &cast_with_options(array, &Date32, cast_options)?, + &Date64, + cast_options, + ), + (Int32, Time32(TimeUnit::Second)) => { + cast_reinterpret_arrays::(array) + } + (Int32, Time32(TimeUnit::Millisecond)) => { + cast_reinterpret_arrays::(array) + } + // No support for microsecond/nanosecond with i32 + (Date32, Int32) => cast_reinterpret_arrays::(array), + (Date32, Int64) => cast_with_options( + &cast_with_options(array, &Int32, cast_options)?, + &Int64, + cast_options, + ), + (Time32(TimeUnit::Second), Int32) => { + cast_reinterpret_arrays::(array) + } + (Time32(TimeUnit::Millisecond), Int32) => { + cast_reinterpret_arrays::(array) + } + (Int64, Date64) => cast_reinterpret_arrays::(array), + (Int64, Date32) => cast_with_options( + &cast_with_options(array, &Int32, cast_options)?, + &Date32, + cast_options, + ), + // No support for second/milliseconds with i64 + (Int64, Time64(TimeUnit::Microsecond)) => { + cast_reinterpret_arrays::(array) + } + (Int64, Time64(TimeUnit::Nanosecond)) => { + cast_reinterpret_arrays::(array) + } + + (Date64, Int64) => cast_reinterpret_arrays::(array), + (Date64, Int32) => cast_with_options( + &cast_with_options(array, &Int64, cast_options)?, + &Int32, + cast_options, + ), + (Time64(TimeUnit::Microsecond), Int64) => { + cast_reinterpret_arrays::(array) + } + (Time64(TimeUnit::Nanosecond), Int64) 
=> { + cast_reinterpret_arrays::(array) + } + (Date32, Date64) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date64Type>(|x| x as i64 * MILLISECONDS_IN_DAY), + )), + (Date64, Date32) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date32Type>(|x| (x / MILLISECONDS_IN_DAY) as i32), + )), + + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| x * MILLISECONDS as i32), + )), + (Time32(TimeUnit::Second), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64MicrosecondType>(|x| x as i64 * MICROSECONDS), + )), + (Time32(TimeUnit::Second), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64NanosecondType>(|x| x as i64 * NANOSECONDS), + )), + + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32SecondType>(|x| x / MILLISECONDS as i32), + )), + (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64MicrosecondType>(|x| x as i64 * (MICROSECONDS / MILLISECONDS)), + )), + (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64NanosecondType>(|x| x as i64 * (MICROSECONDS / NANOSECONDS)), + )), + + (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Second)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32SecondType>(|x| (x / MICROSECONDS) as i32), + )), + (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| (x / (MICROSECONDS / MILLISECONDS)) as i32), + )), + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64NanosecondType>(|x| x * (NANOSECONDS / MICROSECONDS)), + )), + + 
(Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Second)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32SecondType>(|x| (x / NANOSECONDS) as i32), + )), + (Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| (x / (NANOSECONDS / MILLISECONDS)) as i32), + )), + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64MicrosecondType>(|x| x / (NANOSECONDS / MICROSECONDS)), + )), + + // Timestamp to integer/floating/decimals + (Timestamp(TimeUnit::Second, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Timestamp(TimeUnit::Millisecond, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Timestamp(TimeUnit::Microsecond, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Timestamp(TimeUnit::Nanosecond, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + + (_, Timestamp(unit, tz)) if from_type.is_numeric() => { + let array = cast_with_options(array, &Int64, cast_options)?; + Ok(make_timestamp_array( + array.as_primitive(), + *unit, + tz.clone(), + )) + } + + (Timestamp(from_unit, from_tz), Timestamp(to_unit, to_tz)) => { + let array = cast_with_options(array, &Int64, cast_options)?; + let time_array = array.as_primitive::(); + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + // we either divide or multiply, depending on size of each unit + // units are never the same when the types are the same + let converted = match from_size.cmp(&to_size) { + Ordering::Greater => { + let divisor = 
from_size / to_size; + time_array.unary::<_, Int64Type>(|o| o / divisor) + } + Ordering::Equal => time_array.clone(), + Ordering::Less => { + let mul = to_size / from_size; + if cast_options.safe { + time_array.unary_opt::<_, Int64Type>(|o| o.checked_mul(mul)) + } else { + time_array.try_unary::<_, Int64Type, _>(|o| o.mul_checked(mul))? + } + } + }; + // Normalize timezone + let adjusted = match (from_tz, to_tz) { + // Only this case needs to be adjusted because we're casting from + // unknown time offset to some time offset, we want the time to be + // unchanged. + // + // i.e. Timestamp('2001-01-01T00:00', None) -> Timestamp('2001-01-01T00:00', '+0700') + (None, Some(to_tz)) => { + let to_tz: Tz = to_tz.parse()?; + match to_unit { + TimeUnit::Second => adjust_timestamp_to_timezone::( + converted, + &to_tz, + cast_options, + )?, + TimeUnit::Millisecond => adjust_timestamp_to_timezone::< + TimestampMillisecondType, + >( + converted, &to_tz, cast_options + )?, + TimeUnit::Microsecond => adjust_timestamp_to_timezone::< + TimestampMicrosecondType, + >( + converted, &to_tz, cast_options + )?, + TimeUnit::Nanosecond => adjust_timestamp_to_timezone::< + TimestampNanosecondType, + >( + converted, &to_tz, cast_options + )?, + } + } + _ => converted, + }; + Ok(make_timestamp_array(&adjusted, *to_unit, to_tz.clone())) + } + (Timestamp(TimeUnit::Microsecond, _), Date32) => { + timestamp_to_date32(array.as_primitive::()) + } + (Timestamp(TimeUnit::Millisecond, _), Date32) => { + timestamp_to_date32(array.as_primitive::()) + } + (Timestamp(TimeUnit::Second, _), Date32) => { + timestamp_to_date32(array.as_primitive::()) + } + (Timestamp(TimeUnit::Nanosecond, _), Date32) => { + timestamp_to_date32(array.as_primitive::()) + } + (Timestamp(TimeUnit::Second, _), Date64) => Ok(Arc::new(match cast_options.safe { + true => { + // change error to None + array + .as_primitive::() + .unary_opt::<_, Date64Type>(|x| x.checked_mul(MILLISECONDS)) + } + false => array + .as_primitive::() + 
.try_unary::<_, Date64Type, _>(|x| x.mul_checked(MILLISECONDS))?, + })), + (Timestamp(TimeUnit::Millisecond, _), Date64) => { + cast_reinterpret_arrays::(array) + } + (Timestamp(TimeUnit::Microsecond, _), Date64) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date64Type>(|x| x / (MICROSECONDS / MILLISECONDS)), + )), + (Timestamp(TimeUnit::Nanosecond, _), Date64) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date64Type>(|x| x / (NANOSECONDS / MILLISECONDS)), + )), + (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + 
Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Second)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Second)) => { + let tz = 
tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Second)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Second)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + (Date64, Timestamp(TimeUnit::Second, 
None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampSecondType>(|x| x / MILLISECONDS), + )), + (Date64, Timestamp(TimeUnit::Millisecond, None)) => { + cast_reinterpret_arrays::(array) + } + (Date64, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampMicrosecondType>(|x| x * (MICROSECONDS / MILLISECONDS)), + )), + (Date64, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampNanosecondType>(|x| x * (NANOSECONDS / MILLISECONDS)), + )), + (Date32, Timestamp(TimeUnit::Second, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampSecondType>(|x| (x as i64) * SECONDS_IN_DAY), + )), + (Date32, Timestamp(TimeUnit::Millisecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampMillisecondType>(|x| (x as i64) * MILLISECONDS_IN_DAY), + )), + (Date32, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampMicrosecondType>(|x| (x as i64) * MICROSECONDS_IN_DAY), + )), + (Date32, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampNanosecondType>(|x| (x as i64) * NANOSECONDS_IN_DAY), + )), + + (_, Duration(unit)) if from_type.is_numeric() => { + let array = cast_with_options(array, &Int64, cast_options)?; + Ok(make_duration_array(array.as_primitive(), *unit)) + } + (Duration(TimeUnit::Second), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Duration(TimeUnit::Millisecond), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Duration(TimeUnit::Microsecond), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Duration(TimeUnit::Nanosecond), _) if 
to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + + (Duration(from_unit), Duration(to_unit)) => { + let array = cast_with_options(array, &Int64, cast_options)?; + let time_array = array.as_primitive::(); + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + // we either divide or multiply, depending on size of each unit + // units are never the same when the types are the same + let converted = match from_size.cmp(&to_size) { + Ordering::Greater => { + let divisor = from_size / to_size; + time_array.unary::<_, Int64Type>(|o| o / divisor) + } + Ordering::Equal => time_array.clone(), + Ordering::Less => { + let mul = to_size / from_size; + if cast_options.safe { + time_array.unary_opt::<_, Int64Type>(|o| o.checked_mul(mul)) + } else { + time_array.try_unary::<_, Int64Type, _>(|o| o.mul_checked(mul))? + } + } + }; + Ok(make_duration_array(&converted, *to_unit)) + } + + (Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Millisecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Microsecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Nanosecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Second)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Millisecond)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Microsecond)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), 
Duration(TimeUnit::Nanosecond)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::MonthDayNano)) => { + cast_interval_year_month_to_interval_month_day_nano(array, cast_options) + } + (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::MonthDayNano)) => { + cast_interval_day_time_to_interval_month_day_nano(array, cast_options) + } + (Int32, Interval(IntervalUnit::YearMonth)) => { + cast_reinterpret_arrays::(array) + } + (_, _) => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + } +} + +/// Get the time unit as a multiple of a second +const fn time_unit_multiple(unit: &TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + +/// Convert Array into a PrimitiveArray of type, and apply numeric cast +fn cast_numeric_arrays( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + FROM: ArrowPrimitiveType, + TO: ArrowPrimitiveType, + FROM::Native: NumCast, + TO::Native: NumCast, +{ + if cast_options.safe { + // If the value can't be casted to the `TO::Native`, return null + Ok(Arc::new(numeric_cast::( + from.as_primitive::(), + ))) + } else { + // If the value can't be casted to the `TO::Native`, return error + Ok(Arc::new(try_numeric_cast::( + from.as_primitive::(), + )?)) + } +} + +// Natural cast between numeric types +// If the value of T can't be casted to R, will throw error +fn try_numeric_cast(from: &PrimitiveArray) -> Result, ArrowError> +where + T: ArrowPrimitiveType, + R: ArrowPrimitiveType, + T::Native: NumCast, + R::Native: NumCast, +{ + from.try_unary(|value| { + num::cast::cast::(value).ok_or_else(|| { + ArrowError::CastError(format!( + "Can't cast value {:?} to type {}", + value, + R::DATA_TYPE + )) + }) + }) +} + +// Natural cast between numeric types +// If the value of T 
can't be casted to R, it will be converted to null +fn numeric_cast(from: &PrimitiveArray) -> PrimitiveArray +where + T: ArrowPrimitiveType, + R: ArrowPrimitiveType, + T::Native: NumCast, + R::Native: NumCast, +{ + from.unary_opt::<_, R>(num::cast::cast::) +} + +fn cast_numeric_to_binary( + array: &dyn Array, +) -> Result { + let array = array.as_primitive::(); + let size = std::mem::size_of::(); + let offsets = OffsetBuffer::from_lengths(std::iter::repeat(size).take(array.len())); + Ok(Arc::new(GenericBinaryArray::::new( + offsets, + array.values().inner().clone(), + array.nulls().cloned(), + ))) +} + +fn adjust_timestamp_to_timezone( + array: PrimitiveArray, + to_tz: &Tz, + cast_options: &CastOptions, +) -> Result, ArrowError> { + let adjust = |o| { + let local = as_datetime::(o)?; + let offset = to_tz.offset_from_local_datetime(&local).single()?; + T::make_value(local - offset.fix()) + }; + let adjusted = if cast_options.safe { + array.unary_opt::<_, Int64Type>(adjust) + } else { + array.try_unary::<_, Int64Type, _>(|o| { + adjust(o).ok_or_else(|| { + ArrowError::CastError("Cannot cast timezone to different timezone".to_string()) + }) + })? 
+ }; + Ok(adjusted) +} + +/// Cast numeric types to Boolean +/// +/// Any zero value returns `false` while non-zero returns `true` +fn cast_numeric_to_bool(from: &dyn Array) -> Result +where + FROM: ArrowPrimitiveType, +{ + numeric_to_bool_cast::(from.as_primitive::()).map(|to| Arc::new(to) as ArrayRef) +} + +fn numeric_to_bool_cast(from: &PrimitiveArray) -> Result +where + T: ArrowPrimitiveType + ArrowPrimitiveType, +{ + let mut b = BooleanBuilder::with_capacity(from.len()); + + for i in 0..from.len() { + if from.is_null(i) { + b.append_null(); + } else if from.value(i) != T::default_value() { + b.append_value(true); + } else { + b.append_value(false); + } + } + + Ok(b.finish()) +} + +/// Cast Boolean types to numeric +/// +/// `false` returns 0 while `true` returns 1 +fn cast_bool_to_numeric( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + TO: ArrowPrimitiveType, + TO::Native: num::cast::NumCast, +{ + Ok(Arc::new(bool_to_numeric_cast::( + from.as_any().downcast_ref::().unwrap(), + cast_options, + ))) +} + +fn bool_to_numeric_cast(from: &BooleanArray, _cast_options: &CastOptions) -> PrimitiveArray +where + T: ArrowPrimitiveType, + T::Native: num::NumCast, +{ + let iter = (0..from.len()).map(|i| { + if from.is_null(i) { + None + } else if from.value(i) { + // a workaround to cast a primitive to T::Native, infallible + num::cast::cast(1) + } else { + Some(T::default_value()) + } + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from a Range + unsafe { PrimitiveArray::::from_trusted_len_iter(iter) } +} + +/// Helper function to cast from one `BinaryArray` or 'LargeBinaryArray' to 'FixedSizeBinaryArray'. 
+fn cast_binary_to_fixed_size_binary( + array: &dyn Array, + byte_width: i32, + cast_options: &CastOptions, +) -> Result { + let array = array.as_binary::(); + let mut builder = FixedSizeBinaryBuilder::with_capacity(array.len(), byte_width); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + match builder.append_value(array.value(i)) { + Ok(_) => {} + Err(e) => match cast_options.safe { + true => builder.append_null(), + false => return Err(e), + }, + } + } + } + + Ok(Arc::new(builder.finish())) +} + +/// Helper function to cast from 'FixedSizeBinaryArray' to one `BinaryArray` or 'LargeBinaryArray'. +/// If the target one is too large for the source array it will return an Error. +fn cast_fixed_size_binary_to_binary( + array: &dyn Array, + byte_width: i32, +) -> Result { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + let offsets: i128 = byte_width as i128 * array.len() as i128; + + let is_binary = matches!(GenericBinaryType::::DATA_TYPE, DataType::Binary); + if is_binary && offsets > i32::MAX as i128 { + return Err(ArrowError::ComputeError( + "FixedSizeBinary array too large to cast to Binary array".to_string(), + )); + } else if !is_binary && offsets > i64::MAX as i128 { + return Err(ArrowError::ComputeError( + "FixedSizeBinary array too large to cast to LargeBinary array".to_string(), + )); + } + + let mut builder = GenericBinaryBuilder::::with_capacity(array.len(), array.len()); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(array.value(i)); + } + } + + Ok(Arc::new(builder.finish())) +} + +/// Helper function to cast from one `ByteArrayType` to another and vice versa. +/// If the target one (e.g., `LargeUtf8`) is too large for the source array it will return an Error. 
+fn cast_byte_container(array: &dyn Array) -> Result +where + FROM: ByteArrayType, + TO: ByteArrayType, + FROM::Offset: OffsetSizeTrait + ToPrimitive, + TO::Offset: OffsetSizeTrait + NumCast, +{ + let data = array.to_data(); + assert_eq!(data.data_type(), &FROM::DATA_TYPE); + let str_values_buf = data.buffers()[1].clone(); + let offsets = data.buffers()[0].typed_data::(); + + let mut offset_builder = BufferBuilder::::new(offsets.len()); + offsets + .iter() + .try_for_each::<_, Result<_, ArrowError>>(|offset| { + let offset = + <::Offset as NumCast>::from(*offset).ok_or_else(|| { + ArrowError::ComputeError(format!( + "{}{} array too large to cast to {}{} array", + FROM::Offset::PREFIX, + FROM::PREFIX, + TO::Offset::PREFIX, + TO::PREFIX + )) + })?; + offset_builder.append(offset); + Ok(()) + })?; + + let offset_buffer = offset_builder.finish(); + + let dtype = TO::DATA_TYPE; + + let builder = ArrayData::builder(dtype) + .offset(array.offset()) + .len(array.len()) + .add_buffer(offset_buffer) + .add_buffer(str_values_buf) + .nulls(data.nulls().cloned()); + + let array_data = unsafe { builder.build_unchecked() }; + + Ok(Arc::new(GenericByteArray::::from(array_data))) +} + +/// Helper function to cast from one `ByteViewType` array to `ByteArrayType` array. +fn cast_view_to_byte(array: &dyn Array) -> Result +where + FROM: ByteViewType, + TO: ByteArrayType, + FROM::Native: AsRef, +{ + let data = array.to_data(); + let view_array = GenericByteViewArray::::from(data); + + let len = view_array.len(); + let bytes = view_array + .views() + .iter() + .map(|v| ByteView::from(*v).length as usize) + .sum::(); + + let mut byte_array_builder = GenericByteBuilder::::with_capacity(len, bytes); + + for val in view_array.iter() { + byte_array_builder.append_option(val); + } + + Ok(Arc::new(byte_array_builder.finish())) +} + +#[cfg(test)] +mod tests { + use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer}; + use chrono::NaiveDate; + use half::f16; + + use super::*; + + macro_rules! 
generate_cast_test_case { + ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { + let output = + $OUTPUT_TYPE_ARRAY::from($OUTPUT_VALUES).with_data_type($OUTPUT_TYPE.clone()); + + // assert cast type + let input_array_type = $INPUT_ARRAY.data_type(); + assert!(can_cast_types(input_array_type, $OUTPUT_TYPE)); + let result = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap(); + assert_eq!($OUTPUT_TYPE, result.data_type()); + assert_eq!(result.as_ref(), &output); + + let cast_option = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let result = cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); + assert_eq!($OUTPUT_TYPE, result.data_type()); + assert_eq!(result.as_ref(), &output); + }; + } + + fn create_decimal_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + fn create_decimal256_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + #[test] + #[cfg(not(feature = "force_validate"))] + #[should_panic( + expected = "Cannot cast to Decimal128(20, 3). 
Overflowing on 57896044618658097711785492504343953926634992332820282019728792003956564819967" + )] + fn test_cast_decimal_to_decimal_round_with_error() { + // decimal256 to decimal128 overflow + let array = vec![ + Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + Some(i256::MAX), + Some(i256::MIN), + ]; + let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + let input_type = DataType::Decimal256(76, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None, + None, + None, + ] + ); + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_cast_decimal_to_decimal_round() { + let array = vec![ + Some(1123454), + Some(2123456), + Some(-3123453), + Some(-3123456), + None, + ]; + let array = create_decimal_array(array, 20, 4).unwrap(); + // decimal128 to decimal128 + let input_type = DataType::Decimal128(20, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None + ] + ); + + // decimal128 to decimal256 + let input_type = DataType::Decimal128(20, 4); + let output_type = DataType::Decimal256(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(112345_i128)), + Some(i256::from_i128(212346_i128)), + Some(i256::from_i128(-312345_i128)), + Some(i256::from_i128(-312346_i128)), + None + ] + ); + + // decimal256 + let array = vec![ + 
Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + ]; + let array = create_decimal256_array(array, 20, 4).unwrap(); + + // decimal256 to decimal256 + let input_type = DataType::Decimal256(20, 4); + let output_type = DataType::Decimal256(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(112345_i128)), + Some(i256::from_i128(212346_i128)), + Some(i256::from_i128(-312345_i128)), + Some(i256::from_i128(-312346_i128)), + None + ] + ); + // decimal256 to decimal128 + let input_type = DataType::Decimal256(20, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None + ] + ); + } + + #[test] + fn test_cast_decimal128_to_decimal128() { + let input_type = DataType::Decimal128(20, 3); + let output_type = DataType::Decimal128(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(11234560_i128), + Some(21234560_i128), + Some(31234560_i128), + None + ] + ); + // negative test + let array = vec![Some(123456), None]; + let array = create_decimal_array(array, 10, 0).unwrap(); + let result = cast(&array, &DataType::Decimal128(2, 2)); + assert!(result.is_ok()); + let array = result.unwrap(); + let array: &Decimal128Array = array.as_primitive(); + let err = array.validate_decimal_precision(2); + assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. 
Max is 99", + err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal128_to_decimal128_dict() { + let p = 20; + let s = 3; + let input_type = DataType::Decimal128(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal128(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); + } + + #[test] + fn test_cast_decimal256_to_decimal256_dict() { + let p = 20; + let s = 3; + let input_type = DataType::Decimal256(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal256(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); + } + + #[test] + fn test_cast_decimal128_to_decimal128_overflow() { + let input_type = DataType::Decimal128(38, 3); + let output_type = DataType::Decimal128(38, 38); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i128::MAX)]; + let array = create_decimal_array(array, 38, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal128(38, 38). 
Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal128_to_decimal256_overflow() { + let input_type = DataType::Decimal128(38, 3); + let output_type = DataType::Decimal256(76, 76); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i128::MAX)]; + let array = create_decimal_array(array, 38, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal256(76, 76). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal128_to_decimal256() { + let input_type = DataType::Decimal128(20, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } + + #[test] + fn test_cast_decimal256_to_decimal128_overflow() { + let input_type = DataType::Decimal256(76, 5); + let output_type = DataType::Decimal128(38, 7); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(i256::from_i128(i128::MAX))]; + let array = create_decimal256_array(array, 76, 5).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal128(38, 7). 
Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal256_to_decimal256_overflow() { + let input_type = DataType::Decimal256(76, 5); + let output_type = DataType::Decimal256(76, 55); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(i256::from_i128(i128::MAX))]; + let array = create_decimal256_array(array, 76, 5).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal256(76, 55). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal256_to_decimal128() { + let input_type = DataType::Decimal256(20, 3); + let output_type = DataType::Decimal128(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![ + Some(i256::from_i128(1123456)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(3123456)), + None, + ]; + let array = create_decimal256_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(11234560_i128), + Some(21234560_i128), + Some(31234560_i128), + None + ] + ); + } + + #[test] + fn test_cast_decimal256_to_decimal256() { + let input_type = DataType::Decimal256(20, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![ + Some(i256::from_i128(1123456)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(3123456)), + None, + ]; + let array = create_decimal256_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } + + #[test] + fn test_cast_decimal_to_numeric() { + let 
value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + // u8 + generate_cast_test_case!( + &array, + UInt8Array, + &DataType::UInt8, + vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] + ); + // u16 + generate_cast_test_case!( + &array, + UInt16Array, + &DataType::UInt16, + vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] + ); + // u32 + generate_cast_test_case!( + &array, + UInt32Array, + &DataType::UInt32, + vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] + ); + // u64 + generate_cast_test_case!( + &array, + UInt64Array, + &DataType::UInt64, + vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] + ); + // i8 + generate_cast_test_case!( + &array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + &array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + &array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32) + ] + ); + // f64 + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64) + ] + ); + + // overflow test: out of range of max u8 + let value_array: Vec> = vec![Some(51300)]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + let casted_array = cast_with_options( + &array, + &DataType::UInt8, + &CastOptions { + safe: false, + format_options: 
FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: value of 513 is out of range UInt8".to_string(), + casted_array.unwrap_err().to_string() + ); + + let casted_array = cast_with_options( + &array, + &DataType::UInt8, + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + // overflow test: out of range of max i8 + let value_array: Vec> = vec![Some(24400)]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: value of 244 is out of range Int8".to_string(), + casted_array.unwrap_err().to_string() + ); + + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + // loss the precision: convert decimal to f32、f64 + // f32 + // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision. + let value_array: Vec> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678), + Some(112345679), + ]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32), + Some(1_123_456.7_f32), + Some(1_123_456.7_f32) + ] + ); + + // f64 + // 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision. 
+ let value_array: Vec> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678901234568), + Some(112345678901234560), + ]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64), + Some(1_123_456_789_012_345.6_f64), + Some(1_123_456_789_012_345.6_f64), + ] + ); + } + + #[test] + fn test_cast_decimal256_to_numeric() { + let value_array: Vec> = vec![ + Some(i256::from_i128(125)), + Some(i256::from_i128(225)), + Some(i256::from_i128(325)), + None, + Some(i256::from_i128(525)), + ]; + let array = create_decimal256_array(value_array, 38, 2).unwrap(); + // u8 + generate_cast_test_case!( + &array, + UInt8Array, + &DataType::UInt8, + vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] + ); + // u16 + generate_cast_test_case!( + &array, + UInt16Array, + &DataType::UInt16, + vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] + ); + // u32 + generate_cast_test_case!( + &array, + UInt32Array, + &DataType::UInt32, + vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] + ); + // u64 + generate_cast_test_case!( + &array, + UInt64Array, + &DataType::UInt64, + vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] + ); + // i8 + generate_cast_test_case!( + &array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + &array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + &array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + 
&array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32) + ] + ); + // f64 + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64) + ] + ); + + // overflow test: out of range of max i8 + let value_array: Vec> = vec![Some(i256::from_i128(24400))]; + let array = create_decimal256_array(value_array, 38, 2).unwrap(); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: value of 244 is out of range Int8".to_string(), + casted_array.unwrap_err().to_string() + ); + + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + // loss the precision: convert decimal to f32、f64 + // f32 + // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision. + let value_array: Vec> = vec![ + Some(i256::from_i128(125)), + Some(i256::from_i128(225)), + Some(i256::from_i128(325)), + None, + Some(i256::from_i128(525)), + Some(i256::from_i128(112345678)), + Some(i256::from_i128(112345679)), + ]; + let array = create_decimal256_array(value_array, 76, 2).unwrap(); + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32), + Some(1_123_456.7_f32), + Some(1_123_456.7_f32) + ] + ); + + // f64 + // 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision. 
+ let value_array: Vec> = vec![ + Some(i256::from_i128(125)), + Some(i256::from_i128(225)), + Some(i256::from_i128(325)), + None, + Some(i256::from_i128(525)), + Some(i256::from_i128(112345678901234568)), + Some(i256::from_i128(112345678901234560)), + ]; + let array = create_decimal256_array(value_array, 76, 2).unwrap(); + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64), + Some(1_123_456_789_012_345.6_f64), + Some(1_123_456_789_012_345.6_f64), + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal128() { + let decimal_type = DataType::Decimal128(38, 6); + // u8, u16, u32, u64 + let input_datas = vec![ + Arc::new(UInt8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u8 + Arc::new(UInt16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u16 + Arc::new(UInt32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u32 + Arc::new(UInt64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u64 + ]; + + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(1000000_i128), + Some(2000000_i128), + Some(3000000_i128), + None, + Some(5000000_i128) + ] + ); + } + + // i8, i16, i32, i64 + let input_datas = vec![ + Arc::new(Int8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i8 + Arc::new(Int16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i16 + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i32 + Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i64 + ]; + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal128Array, + 
&decimal_type, + vec![ + Some(1000000_i128), + Some(2000000_i128), + Some(3000000_i128), + None, + Some(5000000_i128) + ] + ); + } + + // test u8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. + let array = UInt8Array::from(vec![1, 2, 3, 4, 100]); + let casted_array = cast(&array, &DataType::Decimal128(3, 1)); + assert!(casted_array.is_ok()); + let array = casted_array.unwrap(); + let array: &Decimal128Array = array.as_primitive(); + assert!(array.is_null(4)); + + // test i8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. + let array = Int8Array::from(vec![1, 2, 3, 4, 100]); + let casted_array = cast(&array, &DataType::Decimal128(3, 1)); + assert!(casted_array.is_ok()); + let array = casted_array.unwrap(); + let array: &Decimal128Array = array.as_primitive(); + assert!(array.is_null(4)); + + // test f32 to decimal type + let array = Float32Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(1100000_i128), + Some(2200000_i128), + Some(4400000_i128), + None, + Some(1123456_i128), // round down + Some(1123457_i128), // round up + ] + ); + + // test f64 to decimal type + let array = Float64Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_489_123_4), // round down + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up + ]); + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(1100000_i128), + Some(2200000_i128), + Some(4400000_i128), + None, + Some(1123456_i128), // round down + Some(1123457_i128), // round up + 
Some(1123456_i128), // round down + Some(1123457_i128), // round up + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal256() { + let decimal_type = DataType::Decimal256(76, 6); + // u8, u16, u32, u64 + let input_datas = vec![ + Arc::new(UInt8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u8 + Arc::new(UInt16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u16 + Arc::new(UInt32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u32 + Arc::new(UInt64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u64 + ]; + + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1000000_i128)), + Some(i256::from_i128(2000000_i128)), + Some(i256::from_i128(3000000_i128)), + None, + Some(i256::from_i128(5000000_i128)) + ] + ); + } + + // i8, i16, i32, i64 + let input_datas = vec![ + Arc::new(Int8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i8 + Arc::new(Int16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i16 + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i32 + Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i64 + ]; + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1000000_i128)), + Some(i256::from_i128(2000000_i128)), + Some(i256::from_i128(3000000_i128)), + None, + Some(i256::from_i128(5000000_i128)) + ] + ); + } + + // test i8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. 
+ let array = Int8Array::from(vec![1, 2, 3, 4, 100]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &DataType::Decimal256(3, 1)); + assert!(casted_array.is_ok()); + let array = casted_array.unwrap(); + let array: &Decimal256Array = array.as_primitive(); + assert!(array.is_null(4)); + + // test f32 to decimal type + let array = Float32Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up + ]); + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1100000_i128)), + Some(i256::from_i128(2200000_i128)), + Some(i256::from_i128(4400000_i128)), + None, + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + ] + ); + + // test f64 to decimal type + let array = Float64Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_489_123_4), // round down + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up + ]); + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1100000_i128)), + Some(i256::from_i128(2200000_i128)), + Some(i256::from_i128(4400000_i128)), + None, + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + ] + ); + } + + #[test] + fn test_cast_i32_to_f64() { + let array = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = cast(&array, &DataType::Float64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(5.0, c.value(0)); + assert_eq!(6.0, c.value(1)); + assert_eq!(7.0, c.value(2)); + assert_eq!(8.0, c.value(3)); + assert_eq!(9.0, c.value(4)); + } + + #[test] + fn test_cast_i32_to_u8() { + let array = Int32Array::from(vec![-5, 6, -7, 8, 100000000]); + 
let b = cast(&array, &DataType::UInt8).unwrap(); + let c = b.as_primitive::(); + assert!(!c.is_valid(0)); + assert_eq!(6, c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(8, c.value(3)); + // overflows return None + assert!(!c.is_valid(4)); + } + + #[test] + #[should_panic(expected = "Can't cast value -5 to type UInt8")] + fn test_cast_int32_to_u8_with_error() { + let array = Int32Array::from(vec![-5, 6, -7, 8, 100000000]); + // overflow with the error + let cast_option = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let result = cast_with_options(&array, &DataType::UInt8, &cast_option); + assert!(result.is_err()); + result.unwrap(); + } + + #[test] + fn test_cast_i32_to_u8_sliced() { + let array = Int32Array::from(vec![-5, 6, -7, 8, 100000000]); + assert_eq!(0, array.offset()); + let array = array.slice(2, 3); + let b = cast(&array, &DataType::UInt8).unwrap(); + assert_eq!(3, b.len()); + let c = b.as_primitive::(); + assert!(!c.is_valid(0)); + assert_eq!(8, c.value(1)); + // overflows return None + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_i32_to_i32() { + let array = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(5, c.value(0)); + assert_eq!(6, c.value(1)); + assert_eq!(7, c.value(2)); + assert_eq!(8, c.value(3)); + assert_eq!(9, c.value(4)); + } + + #[test] + fn test_cast_i32_to_list_i32() { + let array = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = cast( + &array, + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + ) + .unwrap(); + assert_eq!(5, b.len()); + let arr = b.as_list::(); + assert_eq!(&[0, 1, 2, 3, 4, 5], arr.value_offsets()); + assert_eq!(1, arr.value_length(0)); + assert_eq!(1, arr.value_length(1)); + assert_eq!(1, arr.value_length(2)); + assert_eq!(1, arr.value_length(3)); + assert_eq!(1, arr.value_length(4)); + let c = arr.values().as_primitive::(); + assert_eq!(5, c.value(0)); + 
assert_eq!(6, c.value(1)); + assert_eq!(7, c.value(2)); + assert_eq!(8, c.value(3)); + assert_eq!(9, c.value(4)); + } + + #[test] + fn test_cast_i32_to_list_i32_nullable() { + let array = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]); + let b = cast( + &array, + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + ) + .unwrap(); + assert_eq!(5, b.len()); + assert_eq!(0, b.null_count()); + let arr = b.as_list::(); + assert_eq!(&[0, 1, 2, 3, 4, 5], arr.value_offsets()); + assert_eq!(1, arr.value_length(0)); + assert_eq!(1, arr.value_length(1)); + assert_eq!(1, arr.value_length(2)); + assert_eq!(1, arr.value_length(3)); + assert_eq!(1, arr.value_length(4)); + + let c = arr.values().as_primitive::(); + assert_eq!(1, c.null_count()); + assert_eq!(5, c.value(0)); + assert!(!c.is_valid(1)); + assert_eq!(7, c.value(2)); + assert_eq!(8, c.value(3)); + assert_eq!(9, c.value(4)); + } + + #[test] + fn test_cast_i32_to_list_f64_nullable_sliced() { + let array = Int32Array::from(vec![Some(5), None, Some(7), Some(8), None, Some(10)]); + let array = array.slice(2, 4); + let b = cast( + &array, + &DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + ) + .unwrap(); + assert_eq!(4, b.len()); + assert_eq!(0, b.null_count()); + let arr = b.as_list::(); + assert_eq!(&[0, 1, 2, 3, 4], arr.value_offsets()); + assert_eq!(1, arr.value_length(0)); + assert_eq!(1, arr.value_length(1)); + assert_eq!(1, arr.value_length(2)); + assert_eq!(1, arr.value_length(3)); + let c = arr.values().as_primitive::(); + assert_eq!(1, c.null_count()); + assert_eq!(7.0, c.value(0)); + assert_eq!(8.0, c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(10.0, c.value(3)); + } + + #[test] + fn test_cast_utf8_to_i32() { + let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(5, c.value(0)); + assert_eq!(6, c.value(1)); + assert!(!c.is_valid(2)); + 
assert_eq!(8, c.value(3)); + assert!(!c.is_valid(4)); + } + + #[test] + fn test_cast_with_options_utf8_to_i32() { + let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]); + let result = cast_with_options( + &array, + &DataType::Int32, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + match result { + Ok(_) => panic!("expected error"), + Err(e) => { + assert!( + e.to_string() + .contains("Cast error: Cannot cast string 'seven' to value of Int32 type",), + "Error: {e}" + ) + } + } + } + + #[test] + fn test_cast_utf8_to_bool() { + let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); + let casted = cast(&strings, &DataType::Boolean).unwrap(); + let expected = BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]); + assert_eq!(*as_boolean_array(&casted), expected); + } + + #[test] + fn test_cast_with_options_utf8_to_bool() { + let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); + let casted = cast_with_options( + &strings, + &DataType::Boolean, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + match casted { + Ok(_) => panic!("expected error"), + Err(e) => { + assert!(e + .to_string() + .contains("Cast error: Cannot cast value 'invalid' to value of Boolean type")) + } + } + } + + #[test] + fn test_cast_bool_to_i32() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1, c.value(0)); + assert_eq!(0, c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_bool_to_utf8() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_bool_to_large_utf8() { + 
let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::LargeUtf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_bool_to_f64() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Float64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1.0, c.value(0)); + assert_eq!(0.0, c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_integer_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Int8Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Int16Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Int32Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt8Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt16Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt32Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt64Array::from(vec![Some(2), Some(10), None]); + let actual = 
cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_integer() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast(&cast(&array, &DataType::Int8).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Int16).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Int32).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt8).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt16).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt32).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt64).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_floating_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Float16Array::from(vec![ + Some(f16::from_f32(2.0)), + Some(f16::from_f32(10.6)), + None, + ]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Float32Array::from(vec![Some(2.0), Some(10.6), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Float64Array::from(vec![Some(2.1), Some(10.2), None]); + let actual = cast(&array, 
&DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_floating() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast(&cast(&array, &DataType::Float16).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Float32).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Float64).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_decimal_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Decimal128Array::from(vec![Some(200), Some(1000), None]) + .with_precision_and_scale(4, 2) + .unwrap(); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Decimal256Array::from(vec![ + Some(i256::from_i128(2000)), + Some(i256::from_i128(10000)), + None, + ]) + .with_precision_and_scale(5, 3) + .unwrap(); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_decimal() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast( + &cast(&array, &DataType::Decimal128(5, 2)).unwrap(), + &DataType::Int64, + ) + .unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast( + &cast(&array, &DataType::Decimal256(10, 5)).unwrap(), + &DataType::Int64, + ) + .unwrap(); + assert_eq!(&actual, &expected); + } + + #[test] + 
fn test_cast_list_i32_to_list_u16() { + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).into_data(); + + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + + // Construct a list array from the above two + // [[0,0,0], [-1, -2, -1], [2, 100000000]] + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + let list_array = ListArray::from(list_data); + + let cast_array = cast( + &list_array, + &DataType::List(Arc::new(Field::new("item", DataType::UInt16, true))), + ) + .unwrap(); + + // For the ListArray itself, there are no null values (as there were no nulls when they went in) + // + // 3 negative values should get lost when casting to unsigned, + // 1 value should overflow + assert_eq!(0, cast_array.null_count()); + + // offsets should be the same + let array = cast_array.as_list::(); + assert_eq!(list_array.value_offsets(), array.value_offsets()); + + assert_eq!(DataType::UInt16, array.value_type()); + assert_eq!(3, array.value_length(0)); + assert_eq!(3, array.value_length(1)); + assert_eq!(2, array.value_length(2)); + + // expect 4 nulls: negative numbers and overflow + let u16arr = array.values().as_primitive::(); + assert_eq!(4, u16arr.null_count()); + + // expect 4 nulls: negative numbers and overflow + let expected: UInt16Array = + vec![Some(0), Some(0), Some(0), None, None, None, Some(2), None] + .into_iter() + .collect(); + + assert_eq!(u16arr, &expected); + } + + #[test] + fn test_cast_list_i32_to_list_timestamp() { + // Construct a value array + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 8, 100000000]).into_data(); + + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 9]); + + // Construct a list array from the above two + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let 
list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + let actual = cast( + &list_array, + &DataType::List(Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ))), + ) + .unwrap(); + + let expected = cast( + &cast( + &list_array, + &DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + ) + .unwrap(), + &DataType::List(Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ))), + ) + .unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_date32_to_date64() { + let a = Date32Array::from(vec![10000, 17890]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000000, c.value(0)); + assert_eq!(1545696000000, c.value(1)); + } + + #[test] + fn test_cast_date64_to_date32() { + let a = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_string_to_integral_overflow() { + let str = Arc::new(StringArray::from(vec![ + Some("123"), + Some("-123"), + Some("86374"), + None, + ])) as ArrayRef; + + let options = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + let res = cast_with_options(&str, &DataType::Int16, &options).expect("should cast to i16"); + let expected = + Arc::new(Int16Array::from(vec![Some(123), Some(-123), None, None])) as ArrayRef; + assert_eq!(&res, &expected); + } + + #[test] + fn test_cast_string_to_timestamp() { + let a0 = Arc::new(StringViewArray::from(vec![ + Some("2020-09-08T12:00:00.123456789+00:00"), + 
Some("Not a valid date"), + None, + ])) as ArrayRef; + let a1 = Arc::new(StringArray::from(vec![ + Some("2020-09-08T12:00:00.123456789+00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2020-09-08T12:00:00.123456789+00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a0, a1, a2] { + for time_unit in &[ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + let to_type = DataType::Timestamp(*time_unit, None); + let b = cast(array, &to_type).unwrap(); + + match time_unit { + TimeUnit::Second => { + let c = b.as_primitive::(); + assert_eq!(1599566400, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + TimeUnit::Millisecond => { + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1599566400123, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + TimeUnit::Microsecond => { + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1599566400123456, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + TimeUnit::Nanosecond => { + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1599566400123456789, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + } + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!( + err.to_string(), + "Parser error: Error parsing timestamp from 'Not a valid date': error parsing date" + ); + } + } + } + + #[test] + fn test_cast_string_to_timestamp_overflow() { + let array = StringArray::from(vec!["9800-09-08T12:00:00.123456789"]); + let result = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.values(), &[247112596800]); + } + + #[test] + fn test_cast_string_to_date32() { + let a0 = 
Arc::new(StringViewArray::from(vec![ + Some("2018-12-25"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a1 = Arc::new(StringArray::from(vec![ + Some("2018-12-25"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2018-12-25"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a0, a1, a2] { + let to_type = DataType::Date32; + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(17890, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid date' to value of Date32 type" + ); + } + } + + #[test] + fn test_cast_string_format_yyyymmdd_to_date32() { + let a0 = Arc::new(StringViewArray::from(vec![ + Some("2020-12-25"), + Some("20201117"), + ])) as ArrayRef; + let a1 = Arc::new(StringArray::from(vec![ + Some("2020-12-25"), + Some("20201117"), + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2020-12-25"), + Some("20201117"), + ])) as ArrayRef; + + for array in &[a0, a1, a2] { + let to_type = DataType::Date32; + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let result = cast_with_options(&array, &to_type, &options).unwrap(); + let c = result.as_primitive::(); + assert_eq!( + chrono::NaiveDate::from_ymd_opt(2020, 12, 25), + c.value_as_date(0) + ); + assert_eq!( + chrono::NaiveDate::from_ymd_opt(2020, 11, 17), + c.value_as_date(1) + ); + } + } + + #[test] + fn test_cast_string_to_time32second() { + let a0 = Arc::new(StringViewArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a1 = 
Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a0, a1, a2] { + let to_type = DataType::Time32(TimeUnit::Second); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315, c.value(0)); + assert_eq!(29340, c.value(1)); + assert!(c.is_null(2)); + assert!(c.is_null(3)); + assert!(c.is_null(4)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Second) type"); + } + } + + #[test] + fn test_cast_string_to_time32millisecond() { + let a0 = Arc::new(StringViewArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a0, a1, a2] { + let to_type = DataType::Time32(TimeUnit::Millisecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315091, c.value(0)); + assert_eq!(29340091, c.value(1)); + assert!(c.is_null(2)); + 
assert!(c.is_null(3)); + assert!(c.is_null(4)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Millisecond) type"); + } + } + + #[test] + fn test_cast_string_to_time64microsecond() { + let a0 = Arc::new(StringViewArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a0, a1, a2] { + let to_type = DataType::Time64(TimeUnit::Microsecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315091323, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Microsecond) type"); + } + } + + #[test] + fn test_cast_string_to_time64nanosecond() { + let a0 = Arc::new(StringViewArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a0, a1, a2] { + let to_type = DataType::Time64(TimeUnit::Nanosecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315091323414, 
c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Nanosecond) type"); + } + } + + #[test] + fn test_cast_string_to_date64() { + let a0 = Arc::new(StringViewArray::from(vec![ + Some("2020-09-08T12:00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a1 = Arc::new(StringArray::from(vec![ + Some("2020-09-08T12:00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2020-09-08T12:00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a0, a1, a2] { + let to_type = DataType::Date64; + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1599566400000, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid date' to value of Date64 type" + ); + } + } + + macro_rules! 
test_safe_string_to_interval { + ($data_vec:expr, $interval_unit:expr, $array_ty:ty, $expect_vec:expr) => { + let source_string_array = Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; + + let options = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + + let target_interval_array = cast_with_options( + &source_string_array.clone(), + &DataType::Interval($interval_unit), + &options, + ) + .unwrap() + .as_any() + .downcast_ref::<$array_ty>() + .unwrap() + .clone() as $array_ty; + + let target_string_array = + cast_with_options(&target_interval_array, &DataType::Utf8, &options) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + + let expect_string_array = StringArray::from($expect_vec); + + assert_eq!(target_string_array, expect_string_array); + + let target_large_string_array = + cast_with_options(&target_interval_array, &DataType::LargeUtf8, &options) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + + let expect_large_string_array = LargeStringArray::from($expect_vec); + + assert_eq!(target_large_string_array, expect_large_string_array); + }; + } + + #[test] + fn test_cast_string_to_interval_year_month() { + test_safe_string_to_interval!( + vec![ + Some("1 year 1 month"), + Some("1.5 years 13 month"), + Some("30 days"), + Some("31 days"), + Some("2 months 31 days"), + Some("2 months 31 days 1 second"), + Some("foobar"), + ], + IntervalUnit::YearMonth, + IntervalYearMonthArray, + vec![ + Some("1 years 1 mons"), + Some("2 years 7 mons"), + None, + None, + None, + None, + None, + ] + ); + } + + #[test] + fn test_cast_string_to_interval_day_time() { + test_safe_string_to_interval!( + vec![ + Some("1 year 1 month"), + Some("1.5 years 13 month"), + Some("30 days"), + Some("1 day 2 second 3.5 milliseconds"), + Some("foobar"), + ], + IntervalUnit::DayTime, + IntervalDayTimeArray, + vec![ + Some("390 days"), + Some("930 days"), + Some("30 days"), + None, + None, + ] + ); + } + + #[test] + fn 
test_cast_string_to_interval_month_day_nano() { + test_safe_string_to_interval!( + vec![ + Some("1 year 1 month 1 day"), + None, + Some("1.5 years 13 month 35 days 1.4 milliseconds"), + Some("3 days"), + Some("8 seconds"), + None, + Some("1 day 29800 milliseconds"), + Some("3 months 1 second"), + Some("6 minutes 120 second"), + Some("2 years 39 months 9 days 19 hours 1 minute 83 seconds 399222 milliseconds"), + Some("foobar"), + ], + IntervalUnit::MonthDayNano, + IntervalMonthDayNanoArray, + vec![ + Some("13 mons 1 days"), + None, + Some("31 mons 35 days 0.001400000 secs"), + Some("3 days"), + Some("8.000000000 secs"), + None, + Some("1 days 29.800000000 secs"), + Some("3 mons 1.000000000 secs"), + Some("8 mins"), + Some("63 mons 9 days 19 hours 9 mins 2.222000000 secs"), + None, + ] + ); + } + + macro_rules! test_unsafe_string_to_interval_err { + ($data_vec:expr, $interval_unit:expr, $error_msg:expr) => { + let string_array = Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let arrow_err = cast_with_options( + &string_array.clone(), + &DataType::Interval($interval_unit), + &options, + ) + .unwrap_err(); + assert_eq!($error_msg, arrow_err.to_string()); + }; + } + + #[test] + fn test_cast_string_to_interval_err() { + test_unsafe_string_to_interval_err!( + vec![Some("foobar")], + IntervalUnit::YearMonth, + r#"Parser error: Invalid input syntax for type interval: "foobar""# + ); + test_unsafe_string_to_interval_err!( + vec![Some("foobar")], + IntervalUnit::DayTime, + r#"Parser error: Invalid input syntax for type interval: "foobar""# + ); + test_unsafe_string_to_interval_err!( + vec![Some("foobar")], + IntervalUnit::MonthDayNano, + r#"Parser error: Invalid input syntax for type interval: "foobar""# + ); + test_unsafe_string_to_interval_err!( + vec![Some("2 months 31 days 1 second")], + IntervalUnit::YearMonth, + r#"Cast error: Cannot cast 2 months 31 days 1 
second to IntervalYearMonth. Only year and month fields are allowed."# + ); + test_unsafe_string_to_interval_err!( + vec![Some("1 day 1.5 milliseconds")], + IntervalUnit::DayTime, + r#"Cast error: Cannot cast 1 day 1.5 milliseconds to IntervalDayTime because the nanos part isn't multiple of milliseconds"# + ); + + // overflow + test_unsafe_string_to_interval_err!( + vec![Some(format!( + "{} century {} year {} month", + i64::MAX - 2, + i64::MAX - 2, + i64::MAX - 2 + ))], + IntervalUnit::DayTime, + format!( + "Arithmetic overflow: Overflow happened on: {} * 100", + i64::MAX - 2 + ) + ); + test_unsafe_string_to_interval_err!( + vec![Some(format!( + "{} year {} month {} day", + i64::MAX - 2, + i64::MAX - 2, + i64::MAX - 2 + ))], + IntervalUnit::MonthDayNano, + format!( + "Arithmetic overflow: Overflow happened on: {} * 12", + i64::MAX - 2 + ) + ); + } + + #[test] + fn test_cast_binary_to_fixed_size_binary() { + let bytes_1 = "Hiiii".as_bytes(); + let bytes_2 = "Hello".as_bytes(); + + let binary_data = vec![Some(bytes_1), Some(bytes_2), None]; + let a1 = Arc::new(BinaryArray::from(binary_data.clone())) as ArrayRef; + let a2 = Arc::new(LargeBinaryArray::from(binary_data)) as ArrayRef; + + let array_ref = cast(&a1, &DataType::FixedSizeBinary(5)).unwrap(); + let down_cast = array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let array_ref = cast(&a2, &DataType::FixedSizeBinary(5)).unwrap(); + let down_cast = array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + + // test error cases when the length of binary are not same + let bytes_1 = "Hi".as_bytes(); + let bytes_2 = "Hello".as_bytes(); + + let binary_data = vec![Some(bytes_1), Some(bytes_2), None]; + let a1 = Arc::new(BinaryArray::from(binary_data.clone())) as ArrayRef; + let 
a2 = Arc::new(LargeBinaryArray::from(binary_data)) as ArrayRef; + + let array_ref = cast_with_options( + &a1, + &DataType::FixedSizeBinary(5), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(array_ref.is_err()); + + let array_ref = cast_with_options( + &a2, + &DataType::FixedSizeBinary(5), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(array_ref.is_err()); + } + + #[test] + fn test_fixed_size_binary_to_binary() { + let bytes_1 = "Hiiii".as_bytes(); + let bytes_2 = "Hello".as_bytes(); + + let binary_data = vec![Some(bytes_1), Some(bytes_2), None]; + let a1 = Arc::new(FixedSizeBinaryArray::from(binary_data.clone())) as ArrayRef; + + let array_ref = cast(&a1, &DataType::Binary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let array_ref = cast(&a1, &DataType::LargeBinary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + } + + #[test] + fn test_numeric_to_binary() { + let a = Int16Array::from(vec![Some(1), Some(511), None]); + + let array_ref = cast(&a, &DataType::Binary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&1_i16.to_le_bytes(), down_cast.value(0)); + assert_eq!(&511_i16.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let a = Int64Array::from(vec![Some(-1), Some(123456789), None]); + + let array_ref = cast(&a, &DataType::Binary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&(-1_i64).to_le_bytes(), down_cast.value(0)); + assert_eq!(&123456789_i64.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + } + + #[test] + fn test_numeric_to_large_binary() { + let a = Int16Array::from(vec![Some(1), Some(511), None]); + + let array_ref = 
cast(&a, &DataType::LargeBinary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&1_i16.to_le_bytes(), down_cast.value(0)); + assert_eq!(&511_i16.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let a = Int64Array::from(vec![Some(-1), Some(123456789), None]); + + let array_ref = cast(&a, &DataType::LargeBinary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&(-1_i64).to_le_bytes(), down_cast.value(0)); + assert_eq!(&123456789_i64.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + } + + #[test] + fn test_cast_date32_to_int32() { + let array = Date32Array::from(vec![10000, 17890]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + } + + #[test] + fn test_cast_int32_to_date32() { + let array = Int32Array::from(vec![10000, 17890]); + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + } + + #[test] + fn test_cast_timestamp_to_date32() { + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]) + .with_timezone("+00:00".to_string()); + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + assert!(c.is_null(2)); + } + #[test] + fn test_cast_timestamp_to_date32_zone() { + let strings = StringArray::from_iter([ + Some("1970-01-01T00:00:01"), + Some("1970-01-01T23:59:59"), + None, + Some("2020-03-01T02:00:23+00:00"), + ]); + let dt = DataType::Timestamp(TimeUnit::Millisecond, Some("-07:00".into())); + let timestamps = cast(&strings, &dt).unwrap(); + let dates = cast(timestamps.as_ref(), &DataType::Date32).unwrap(); + + let c = dates.as_primitive::(); + let expected = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + assert_eq!(c.value_as_date(0).unwrap(), expected); + 
assert_eq!(c.value_as_date(1).unwrap(), expected); + assert!(c.is_null(2)); + let expected = NaiveDate::from_ymd_opt(2020, 2, 29).unwrap(); + assert_eq!(c.value_as_date(3).unwrap(), expected); + } + #[test] + fn test_cast_timestamp_to_date64() { + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + + let array = TimestampSecondArray::from(vec![Some(864000000005), Some(1545696000001)]); + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000005000, c.value(0)); + assert_eq!(1545696000001000, c.value(1)); + + // test overflow, safe cast + let array = TimestampSecondArray::from(vec![Some(i64::MAX)]); + let b = cast(&array, &DataType::Date64).unwrap(); + assert!(b.is_null(0)); + // test overflow, unsafe cast + let array = TimestampSecondArray::from(vec![Some(i64::MAX)]); + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let b = cast_with_options(&array, &DataType::Date64, &options); + assert!(b.is_err()); + } + + #[test] + fn test_cast_timestamp_to_time64() { + // test timestamp secs + let array = TimestampSecondArray::from(vec![Some(86405), Some(1), None]) + .with_timezone("+01:00".to_string()); + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp milliseconds + let a = TimestampMillisecondArray::from(vec![Some(86405000), Some(1000), None]) + 
.with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp microseconds + let a = TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp nanoseconds + let a = TimestampNanosecondArray::from(vec![Some(86405000000000), Some(1000000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test overflow + let a = + TimestampSecondArray::from(vec![Some(i64::MAX)]).with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)); + assert!(b.is_err()); + let b = 
cast(&array, &DataType::Time64(TimeUnit::Nanosecond)); + assert!(b.is_err()); + let b = cast(&array, &DataType::Time64(TimeUnit::Millisecond)); + assert!(b.is_err()); + } + + #[test] + fn test_cast_timestamp_to_time32() { + // test timestamp secs + let a = TimestampSecondArray::from(vec![Some(86405), Some(1), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp milliseconds + let a = TimestampMillisecondArray::from(vec![Some(86405000), Some(1000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp microseconds + let a = TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp nanoseconds + 
let a = TimestampNanosecondArray::from(vec![Some(86405000000000), Some(1000000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test overflow + let a = + TimestampSecondArray::from(vec![Some(i64::MAX)]).with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)); + assert!(b.is_err()); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)); + assert!(b.is_err()); + } + + // Cast Timestamp(_, None) -> Timestamp(_, Some(timezone)) + #[test] + fn test_cast_timestamp_with_timezone_1() { + let string_array: Arc = Arc::new(StringArray::from(vec![ + Some("2000-01-01T00:00:00.123456789"), + Some("2010-01-01T00:00:00.123456789"), + None, + ])); + let to_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + let timestamp_array = cast(&string_array, &to_type).unwrap(); + + let to_type = DataType::Timestamp(TimeUnit::Microsecond, Some("+0700".into())); + let timestamp_array = cast(×tamp_array, &to_type).unwrap(); + + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T00:00:00.123456+07:00", result.value(0)); + assert_eq!("2010-01-01T00:00:00.123456+07:00", result.value(1)); + assert!(result.is_null(2)); + } + + // Cast Timestamp(_, Some(timezone)) -> Timestamp(_, None) + #[test] + fn test_cast_timestamp_with_timezone_2() { + let string_array: Arc = Arc::new(StringArray::from(vec![ + Some("2000-01-01T07:00:00.123456789"), + Some("2010-01-01T07:00:00.123456789"), + None, + ])); + 
let to_type = DataType::Timestamp(TimeUnit::Millisecond, Some("+0700".into())); + let timestamp_array = cast(&string_array, &to_type).unwrap(); + + // Check intermediate representation is correct + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T07:00:00.123+07:00", result.value(0)); + assert_eq!("2010-01-01T07:00:00.123+07:00", result.value(1)); + assert!(result.is_null(2)); + + let to_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + let timestamp_array = cast(×tamp_array, &to_type).unwrap(); + + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T00:00:00.123", result.value(0)); + assert_eq!("2010-01-01T00:00:00.123", result.value(1)); + assert!(result.is_null(2)); + } + + // Cast Timestamp(_, Some(timezone)) -> Timestamp(_, Some(timezone)) + #[test] + fn test_cast_timestamp_with_timezone_3() { + let string_array: Arc = Arc::new(StringArray::from(vec![ + Some("2000-01-01T07:00:00.123456789"), + Some("2010-01-01T07:00:00.123456789"), + None, + ])); + let to_type = DataType::Timestamp(TimeUnit::Microsecond, Some("+0700".into())); + let timestamp_array = cast(&string_array, &to_type).unwrap(); + + // Check intermediate representation is correct + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T07:00:00.123456+07:00", result.value(0)); + assert_eq!("2010-01-01T07:00:00.123456+07:00", result.value(1)); + assert!(result.is_null(2)); + + let to_type = DataType::Timestamp(TimeUnit::Second, Some("-08:00".into())); + let timestamp_array = cast(×tamp_array, &to_type).unwrap(); + + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("1999-12-31T16:00:00-08:00", result.value(0)); + assert_eq!("2009-12-31T16:00:00-08:00", result.value(1)); + 
assert!(result.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000, c.value(0)); + assert_eq!(1545696000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp_ms() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Millisecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp_us() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(864000000005000, c.value(0)); + assert_eq!(1545696000001000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp_ns() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(864000000005000000, c.value(0)); + assert_eq!(1545696000001000000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_timestamp_to_i64() { + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]) + .with_timezone("UTC".to_string()); + let b = cast(&array, &DataType::Int64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(&DataType::Int64, c.data_type()); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + } + + 
#[test] + fn test_cast_date32_to_string() { + let array = Date32Array::from(vec![10000, 17890]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(&DataType::Utf8, c.data_type()); + assert_eq!("1997-05-19", c.value(0)); + assert_eq!("2018-12-25", c.value(1)); + } + + #[test] + fn test_cast_date64_to_string() { + let array = Date64Array::from(vec![10000 * 86400000, 17890 * 86400000]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(&DataType::Utf8, c.data_type()); + assert_eq!("1997-05-19T00:00:00", c.value(0)); + assert_eq!("2018-12-25T00:00:00", c.value(1)); + } + + #[test] + fn test_cast_timestamp_to_strings() { + // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None + let array = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); + let out = cast(&array, &DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19T00:00:03.005"), + Some("2018-12-25T00:00:02.001"), + None + ] + ); + let out = cast(&array, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19T00:00:03.005"), + Some("2018-12-25T00:00:02.001"), + None + ] + ); + } + + #[test] + fn test_cast_timestamp_to_strings_opt() { + let ts_format = "%Y-%m-%d %H:%M:%S%.6f"; + let tz = "+0545"; // UTC + 0545 is Asia/Kathmandu + let cast_options = CastOptions { + safe: true, + format_options: FormatOptions::default() + .with_timestamp_format(Some(ts_format)) + .with_timestamp_tz_format(Some(ts_format)), + }; + // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None + let array_without_tz = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); + let out = 
cast_with_options(&array_without_tz, &DataType::Utf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 00:00:03.005000"), + Some("2018-12-25 00:00:02.001000"), + None + ] + ); + let out = + cast_with_options(&array_without_tz, &DataType::LargeUtf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 00:00:03.005000"), + Some("2018-12-25 00:00:02.001000"), + None + ] + ); + + let array_with_tz = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]) + .with_timezone(tz.to_string()); + let out = cast_with_options(&array_with_tz, &DataType::Utf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 05:45:03.005000"), + Some("2018-12-25 05:45:02.001000"), + None + ] + ); + let out = cast_with_options(&array_with_tz, &DataType::LargeUtf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 05:45:03.005000"), + Some("2018-12-25 05:45:02.001000"), + None + ] + ); + } + + #[test] + fn test_cast_between_timestamps() { + let array = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000003, c.value(0)); + assert_eq!(1545696002, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_duration_to_i64() { + let base = vec![5, 6, 7, 8, 100000000]; + + let duration_arrays = vec![ + Arc::new(DurationNanosecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationMicrosecondArray::from(base.clone())) as ArrayRef, + 
Arc::new(DurationMillisecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationSecondArray::from(base.clone())) as ArrayRef, + ]; + + for arr in duration_arrays { + assert!(can_cast_types(arr.data_type(), &DataType::Int64)); + let result = cast(&arr, &DataType::Int64).unwrap(); + let result = result.as_primitive::(); + assert_eq!(base.as_slice(), result.values()); + } + } + + #[test] + fn test_cast_between_durations_and_numerics() { + fn test_cast_between_durations() + where + FromType: ArrowPrimitiveType, + ToType: ArrowPrimitiveType, + PrimitiveArray: From>>, + { + let from_unit = match FromType::DATA_TYPE { + DataType::Duration(unit) => unit, + _ => panic!("Expected a duration type"), + }; + let to_unit = match ToType::DATA_TYPE { + DataType::Duration(unit) => unit, + _ => panic!("Expected a duration type"), + }; + let from_size = time_unit_multiple(&from_unit); + let to_size = time_unit_multiple(&to_unit); + + let (v1_before, v2_before) = (8640003005, 1696002001); + let (v1_after, v2_after) = if from_size >= to_size { + ( + v1_before / (from_size / to_size), + v2_before / (from_size / to_size), + ) + } else { + ( + v1_before * (to_size / from_size), + v2_before * (to_size / from_size), + ) + }; + + let array = + PrimitiveArray::::from(vec![Some(v1_before), Some(v2_before), None]); + let b = cast(&array, &ToType::DATA_TYPE).unwrap(); + let c = b.as_primitive::(); + assert_eq!(v1_after, c.value(0)); + assert_eq!(v2_after, c.value(1)); + assert!(c.is_null(2)); + } + + // between each individual duration type + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + + // cast failed + let array = 
DurationSecondArray::from(vec![ + Some(i64::MAX), + Some(8640203410378005), + Some(10241096), + None, + ]); + let b = cast(&array, &DataType::Duration(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert!(c.is_null(0)); + assert!(c.is_null(1)); + assert_eq!(10241096000000000, c.value(2)); + assert!(c.is_null(3)); + + // durations to numerics + let array = DurationSecondArray::from(vec![ + Some(i64::MAX), + Some(8640203410378005), + Some(10241096), + None, + ]); + let b = cast(&array, &DataType::Int64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(i64::MAX, c.value(0)); + assert_eq!(8640203410378005, c.value(1)); + assert_eq!(10241096, c.value(2)); + assert!(c.is_null(3)); + + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(0, c.value(0)); + assert_eq!(0, c.value(1)); + assert_eq!(10241096, c.value(2)); + assert!(c.is_null(3)); + + // numerics to durations + let array = Int32Array::from(vec![Some(i32::MAX), Some(802034103), Some(10241096), None]); + let b = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(i32::MAX as i64, c.value(0)); + assert_eq!(802034103, c.value(1)); + assert_eq!(10241096, c.value(2)); + assert!(c.is_null(3)); + } + + #[test] + fn test_cast_to_strings() { + let a = Int32Array::from(vec![1, 2, 3]); + let out = cast(&a, &DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(out, vec![Some("1"), Some("2"), Some("3")]); + let out = cast(&a, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(out, vec![Some("1"), Some("2"), Some("3")]); + } + + #[test] + fn test_str_to_str_casts() { + for data in [ + vec![Some("foo"), Some("bar"), Some("ham")], + vec![Some("foo"), None, Some("bar")], + ] { + let a = LargeStringArray::from(data.clone()); + let to = 
cast(&a, &DataType::Utf8).unwrap(); + let expect = a + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + let out = to + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(expect, out); + + let a = StringArray::from(data); + let to = cast(&a, &DataType::LargeUtf8).unwrap(); + let expect = a + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + let out = to + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(expect, out); + } + } + + const VIEW_TEST_DATA: [Option<&str>; 5] = [ + Some("hello"), + Some("repeated"), + None, + Some("large payload over 12 bytes"), + Some("repeated"), + ]; + + #[test] + fn test_string_view_to_binary_view() { + let string_view_array = StringViewArray::from_iter(VIEW_TEST_DATA); + + assert!(can_cast_types( + string_view_array.data_type(), + &DataType::BinaryView + )); + + let binary_view_array = cast(&string_view_array, &DataType::BinaryView).unwrap(); + assert_eq!(binary_view_array.data_type(), &DataType::BinaryView); + + let expect_binary_view_array = BinaryViewArray::from_iter(VIEW_TEST_DATA); + assert_eq!(binary_view_array.as_ref(), &expect_binary_view_array); + } + + #[test] + fn test_binary_view_to_string_view() { + let binary_view_array = BinaryViewArray::from_iter(VIEW_TEST_DATA); + + assert!(can_cast_types( + binary_view_array.data_type(), + &DataType::Utf8View + )); + + let string_view_array = cast(&binary_view_array, &DataType::Utf8View).unwrap(); + assert_eq!(string_view_array.data_type(), &DataType::Utf8View); + + let expect_string_view_array = StringViewArray::from_iter(VIEW_TEST_DATA); + assert_eq!(string_view_array.as_ref(), &expect_string_view_array); + } + + #[test] + fn test_string_to_view() { + _test_string_to_view::(); + _test_string_to_view::(); + } + + fn _test_string_to_view() + where + O: OffsetSizeTrait, + { + let string_array = GenericStringArray::::from_iter(VIEW_TEST_DATA); + + 
assert!(can_cast_types( + string_array.data_type(), + &DataType::Utf8View + )); + + assert!(can_cast_types( + string_array.data_type(), + &DataType::BinaryView + )); + + let string_view_array = cast(&string_array, &DataType::Utf8View).unwrap(); + assert_eq!(string_view_array.data_type(), &DataType::Utf8View); + + let binary_view_array = cast(&string_array, &DataType::BinaryView).unwrap(); + assert_eq!(binary_view_array.data_type(), &DataType::BinaryView); + + let expect_string_view_array = StringViewArray::from_iter(VIEW_TEST_DATA); + assert_eq!(string_view_array.as_ref(), &expect_string_view_array); + + let expect_binary_view_array = BinaryViewArray::from_iter(VIEW_TEST_DATA); + assert_eq!(binary_view_array.as_ref(), &expect_binary_view_array); + } + + #[test] + fn test_bianry_to_view() { + _test_binary_to_view::(); + _test_binary_to_view::(); + } + + fn _test_binary_to_view() + where + O: OffsetSizeTrait, + { + let binary_array = GenericBinaryArray::::from_iter(VIEW_TEST_DATA); + + assert!(can_cast_types( + binary_array.data_type(), + &DataType::BinaryView + )); + + let binary_view_array = cast(&binary_array, &DataType::BinaryView).unwrap(); + assert_eq!(binary_view_array.data_type(), &DataType::BinaryView); + + let expect_binary_view_array = BinaryViewArray::from_iter(VIEW_TEST_DATA); + assert_eq!(binary_view_array.as_ref(), &expect_binary_view_array); + } + + #[test] + fn test_dict_to_view() { + let values = StringArray::from_iter(VIEW_TEST_DATA); + let keys = Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]); + let string_dict_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let typed_dict = string_dict_array.downcast_dict::().unwrap(); + + let string_view_array = { + let mut builder = StringViewBuilder::new().with_fixed_block_size(8); // multiple buffers. 
+ for v in typed_dict.into_iter() { + builder.append_option(v); + } + builder.finish() + }; + let expected_string_array_type = string_view_array.data_type(); + let casted_string_array = cast(&string_dict_array, expected_string_array_type).unwrap(); + assert_eq!(casted_string_array.data_type(), expected_string_array_type); + assert_eq!(casted_string_array.as_ref(), &string_view_array); + + let binary_buffer = cast(&typed_dict.values(), &DataType::Binary).unwrap(); + let binary_dict_array = + DictionaryArray::::new(typed_dict.keys().clone(), binary_buffer); + let typed_binary_dict = binary_dict_array.downcast_dict::().unwrap(); + + let binary_view_array = { + let mut builder = BinaryViewBuilder::new().with_fixed_block_size(8); // multiple buffers. + for v in typed_binary_dict.into_iter() { + builder.append_option(v); + } + builder.finish() + }; + let expected_binary_array_type = binary_view_array.data_type(); + let casted_binary_array = cast(&binary_dict_array, expected_binary_array_type).unwrap(); + assert_eq!(casted_binary_array.data_type(), expected_binary_array_type); + assert_eq!(casted_binary_array.as_ref(), &binary_view_array); + } + + #[test] + fn test_view_to_dict() { + let string_view_array = StringViewArray::from_iter(VIEW_TEST_DATA); + let string_dict_array: DictionaryArray = VIEW_TEST_DATA.into_iter().collect(); + let casted_type = string_dict_array.data_type(); + let casted_dict_array = cast(&string_view_array, casted_type).unwrap(); + assert_eq!(casted_dict_array.data_type(), casted_type); + assert_eq!(casted_dict_array.as_ref(), &string_dict_array); + + let binary_view_array = BinaryViewArray::from_iter(VIEW_TEST_DATA); + let binary_dict_array = string_dict_array.downcast_dict::().unwrap(); + let binary_buffer = cast(&binary_dict_array.values(), &DataType::Binary).unwrap(); + let binary_dict_array = + DictionaryArray::::new(binary_dict_array.keys().clone(), binary_buffer); + let casted_type = binary_dict_array.data_type(); + let casted_binary_array = 
cast(&binary_view_array, casted_type).unwrap(); + assert_eq!(casted_binary_array.data_type(), casted_type); + assert_eq!(casted_binary_array.as_ref(), &binary_dict_array); + } + + #[test] + fn test_view_to_string() { + _test_view_to_string::(); + _test_view_to_string::(); + } + + fn _test_view_to_string() + where + O: OffsetSizeTrait, + { + let string_view_array = { + let mut builder = StringViewBuilder::new().with_fixed_block_size(8); // multiple buffers. + for s in VIEW_TEST_DATA.iter() { + builder.append_option(*s); + } + builder.finish() + }; + + let binary_view_array = BinaryViewArray::from_iter(VIEW_TEST_DATA); + + let expected_string_array = GenericStringArray::::from_iter(VIEW_TEST_DATA); + let expected_type = expected_string_array.data_type(); + + assert!(can_cast_types(string_view_array.data_type(), expected_type)); + assert!(can_cast_types(binary_view_array.data_type(), expected_type)); + + let string_view_casted_array = cast(&string_view_array, expected_type).unwrap(); + assert_eq!(string_view_casted_array.data_type(), expected_type); + assert_eq!(string_view_casted_array.as_ref(), &expected_string_array); + + let binary_view_casted_array = cast(&binary_view_array, expected_type).unwrap(); + assert_eq!(binary_view_casted_array.data_type(), expected_type); + assert_eq!(binary_view_casted_array.as_ref(), &expected_string_array); + } + + #[test] + fn test_view_to_binary() { + _test_view_to_binary::(); + _test_view_to_binary::(); + } + + fn _test_view_to_binary() + where + O: OffsetSizeTrait, + { + let view_array = { + let mut builder = BinaryViewBuilder::new().with_fixed_block_size(8); // multiple buffers. 
+ for s in VIEW_TEST_DATA.iter() { + builder.append_option(*s); + } + builder.finish() + }; + + let expected_binary_array = GenericBinaryArray::::from_iter(VIEW_TEST_DATA); + let expected_type = expected_binary_array.data_type(); + + assert!(can_cast_types(view_array.data_type(), expected_type)); + + let binary_array = cast(&view_array, expected_type).unwrap(); + assert_eq!(binary_array.data_type(), expected_type); + + assert_eq!(binary_array.as_ref(), &expected_binary_array); + } + + #[test] + fn test_cast_from_f64() { + let f64_values: Vec = vec![ + i64::MIN as f64, + i32::MIN as f64, + i16::MIN as f64, + i8::MIN as f64, + 0_f64, + u8::MAX as f64, + u16::MAX as f64, + u32::MAX as f64, + u64::MAX as f64, + ]; + let f64_array: ArrayRef = Arc::new(Float64Array::from(f64_values)); + + let f64_expected = vec![ + -9223372036854776000.0, + -2147483648.0, + -32768.0, + -128.0, + 0.0, + 255.0, + 65535.0, + 4294967295.0, + 18446744073709552000.0, + ]; + assert_eq!( + f64_expected, + get_cast_values::(&f64_array, &DataType::Float64) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f32_expected = vec![ + -9223372000000000000.0, + -2147483600.0, + -32768.0, + -128.0, + 0.0, + 255.0, + 65535.0, + 4294967300.0, + 18446744000000000000.0, + ]; + assert_eq!( + f32_expected, + get_cast_values::(&f64_array, &DataType::Float32) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f16_expected = vec![ + f16::from_f64(-9223372000000000000.0), + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(255.0), + f16::from_f64(65535.0), + f16::from_f64(4294967300.0), + f16::from_f64(18446744000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&f64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec![ + "-9223372036854775808", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + 
"4294967295", + "null", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&f64_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "null", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "null", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&f64_array, &DataType::Int32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "255", "null", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&f64_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "null", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&f64_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&f64_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&f64_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&f64_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "255", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&f64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_f32() { + let f32_values: Vec = vec![ + i32::MIN as f32, + i32::MIN as f32, + i16::MIN as f32, + i8::MIN as f32, + 0_f32, + u8::MAX as f32, + u16::MAX as f32, + u32::MAX as f32, + u32::MAX as f32, + ]; + let f32_array: ArrayRef = Arc::new(Float32Array::from(f32_values)); + + let f64_expected = vec![ + "-2147483648.0", + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967296.0", + "4294967296.0", + ]; + assert_eq!( 
+ f64_expected, + get_cast_values::(&f32_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-2147483600.0", + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967300.0", + "4294967300.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&f32_array, &DataType::Float32) + ); + + let f16_expected = vec![ + "-inf", "-inf", "-32768.0", "-128.0", "0.0", "255.0", "inf", "inf", "inf", + ]; + assert_eq!( + f16_expected, + get_cast_values::(&f32_array, &DataType::Float16) + ); + + let i64_expected = vec![ + "-2147483648", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "4294967296", + "4294967296", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&f32_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "-2147483648", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "null", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&f32_array, &DataType::Int32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "255", "null", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&f32_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "null", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&f32_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967296", + "4294967296", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&f32_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&f32_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&f32_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", 
"null", "null", "null", "0", "255", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&f32_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint64() { + let u64_values: Vec = vec![ + 0, + u8::MAX as u64, + u16::MAX as u64, + u32::MAX as u64, + u64::MAX, + ]; + let u64_array: ArrayRef = Arc::new(UInt64Array::from(u64_values)); + + let f64_expected = vec![0.0, 255.0, 65535.0, 4294967295.0, 18446744073709552000.0]; + assert_eq!( + f64_expected, + get_cast_values::(&u64_array, &DataType::Float64) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f32_expected = vec![0.0, 255.0, 65535.0, 4294967300.0, 18446744000000000000.0]; + assert_eq!( + f32_expected, + get_cast_values::(&u64_array, &DataType::Float32) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f16_expected = vec![ + f16::from_f64(0.0), + f16::from_f64(255.0), + f16::from_f64(65535.0), + f16::from_f64(4294967300.0), + f16::from_f64(18446744000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&u64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec!["0", "255", "65535", "4294967295", "null"]; + assert_eq!( + i64_expected, + get_cast_values::(&u64_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535", "null", "null"]; + assert_eq!( + i32_expected, + get_cast_values::(&u64_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null", "null", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u64_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u64_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535", "4294967295", "18446744073709551615"]; + assert_eq!( + u64_expected, + get_cast_values::(&u64_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535", 
"4294967295", "null"]; + assert_eq!( + u32_expected, + get_cast_values::(&u64_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535", "null", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&u64_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint32() { + let u32_values: Vec = vec![0, u8::MAX as u32, u16::MAX as u32, u32::MAX]; + let u32_array: ArrayRef = Arc::new(UInt32Array::from(u32_values)); + + let f64_expected = vec!["0.0", "255.0", "65535.0", "4294967295.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u32_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0", "65535.0", "4294967300.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u32_array, &DataType::Float32) + ); + + let f16_expected = vec!["0.0", "255.0", "inf", "inf"]; + assert_eq!( + f16_expected, + get_cast_values::(&u32_array, &DataType::Float16) + ); + + let i64_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + i64_expected, + get_cast_values::(&u32_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535", "null"]; + assert_eq!( + i32_expected, + get_cast_values::(&u32_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u32_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u32_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + u64_expected, + get_cast_values::(&u32_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + u32_expected, + get_cast_values::(&u32_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535", "null"]; + 
assert_eq!( + u16_expected, + get_cast_values::(&u32_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u32_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint16() { + let u16_values: Vec = vec![0, u8::MAX as u16, u16::MAX]; + let u16_array: ArrayRef = Arc::new(UInt16Array::from(u16_values)); + + let f64_expected = vec!["0.0", "255.0", "65535.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u16_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0", "65535.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u16_array, &DataType::Float32) + ); + + let f16_expected = vec!["0.0", "255.0", "inf"]; + assert_eq!( + f16_expected, + get_cast_values::(&u16_array, &DataType::Float16) + ); + + let i64_expected = vec!["0", "255", "65535"]; + assert_eq!( + i64_expected, + get_cast_values::(&u16_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535"]; + assert_eq!( + i32_expected, + get_cast_values::(&u16_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u16_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u16_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535"]; + assert_eq!( + u64_expected, + get_cast_values::(&u16_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535"]; + assert_eq!( + u32_expected, + get_cast_values::(&u16_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535"]; + assert_eq!( + u16_expected, + get_cast_values::(&u16_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u16_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint8() { + let u8_values: Vec = vec![0, u8::MAX]; + let 
u8_array: ArrayRef = Arc::new(UInt8Array::from(u8_values)); + + let f64_expected = vec!["0.0", "255.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u8_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u8_array, &DataType::Float32) + ); + + let f16_expected = vec!["0.0", "255.0"]; + assert_eq!( + f16_expected, + get_cast_values::(&u8_array, &DataType::Float16) + ); + + let i64_expected = vec!["0", "255"]; + assert_eq!( + i64_expected, + get_cast_values::(&u8_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255"]; + assert_eq!( + i32_expected, + get_cast_values::(&u8_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255"]; + assert_eq!( + i16_expected, + get_cast_values::(&u8_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u8_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255"]; + assert_eq!( + u64_expected, + get_cast_values::(&u8_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255"]; + assert_eq!( + u32_expected, + get_cast_values::(&u8_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255"]; + assert_eq!( + u16_expected, + get_cast_values::(&u8_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255"]; + assert_eq!( + u8_expected, + get_cast_values::(&u8_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int64() { + let i64_values: Vec = vec![ + i64::MIN, + i32::MIN as i64, + i16::MIN as i64, + i8::MIN as i64, + 0, + i8::MAX as i64, + i16::MAX as i64, + i32::MAX as i64, + i64::MAX, + ]; + let i64_array: ArrayRef = Arc::new(Int64Array::from(i64_values)); + + let f64_expected = vec![ + -9223372036854776000.0, + -2147483648.0, + -32768.0, + -128.0, + 0.0, + 127.0, + 32767.0, + 2147483647.0, + 9223372036854776000.0, + ]; + assert_eq!( + f64_expected, + get_cast_values::(&i64_array, &DataType::Float64) + 
.iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f32_expected = vec![ + -9223372000000000000.0, + -2147483600.0, + -32768.0, + -128.0, + 0.0, + 127.0, + 32767.0, + 2147483600.0, + 9223372000000000000.0, + ]; + assert_eq!( + f32_expected, + get_cast_values::(&i64_array, &DataType::Float32) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f16_expected = vec![ + f16::from_f64(-9223372000000000000.0), + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + f16::from_f64(2147483600.0), + f16::from_f64(9223372000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec![ + "-9223372036854775808", + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + "9223372036854775807", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&i64_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "null", + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&i64_array, &DataType::Int32) + ); + + assert_eq!( + i32_expected, + get_cast_values::(&i64_array, &DataType::Date32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "127", "32767", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&i64_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "127", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&i64_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "127", + "32767", + "2147483647", + "9223372036854775807", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&i64_array, &DataType::UInt64) + ); + + let 
u32_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "127", + "32767", + "2147483647", + "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&i64_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "127", "32767", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&i64_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "127", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&i64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int32() { + let i32_values: Vec = vec![ + i32::MIN, + i16::MIN as i32, + i8::MIN as i32, + 0, + i8::MAX as i32, + i16::MAX as i32, + i32::MAX, + ]; + let i32_array: ArrayRef = Arc::new(Int32Array::from(i32_values)); + + let f64_expected = vec![ + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483647.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&i32_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483600.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&i32_array, &DataType::Float32) + ); + + let f16_expected = vec![ + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + f16::from_f64(2147483600.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i32_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i16_expected = vec!["null", "-32768", "-128", "0", "127", "32767", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&i32_array, &DataType::Int16) + ); + + let i8_expected = vec!["null", "null", "-128", "0", "127", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&i32_array, &DataType::Int8) + ); + + let u64_expected = 
vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; + assert_eq!( + u64_expected, + get_cast_values::(&i32_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; + assert_eq!( + u32_expected, + get_cast_values::(&i32_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "null", "null", "0", "127", "32767", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&i32_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "null", "null", "0", "127", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&i32_array, &DataType::UInt8) + ); + + // The date32 to date64 cast increases the numerical values in order to keep the same dates. + let i64_expected = vec![ + "-185542587187200000", + "-2831155200000", + "-11059200000", + "0", + "10972800000", + "2831068800000", + "185542587100800000", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&i32_array, &DataType::Date64) + ); + } + + #[test] + fn test_cast_from_int16() { + let i16_values: Vec = vec![i16::MIN, i8::MIN as i16, 0, i8::MAX as i16, i16::MAX]; + let i16_array: ArrayRef = Arc::new(Int16Array::from(i16_values)); + + let f64_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&i16_array, &DataType::Float64) + ); + + let f32_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&i16_array, &DataType::Float32) + ); + + let f16_expected = vec![ + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i16_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i64_expected, + get_cast_values::(&i16_array, &DataType::Int64) + ); + + let 
i32_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i32_expected, + get_cast_values::(&i16_array, &DataType::Int32) + ); + + let i16_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i16_expected, + get_cast_values::(&i16_array, &DataType::Int16) + ); + + let i8_expected = vec!["null", "-128", "0", "127", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&i16_array, &DataType::Int8) + ); + + let u64_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u64_expected, + get_cast_values::(&i16_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u32_expected, + get_cast_values::(&i16_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u16_expected, + get_cast_values::(&i16_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "null", "0", "127", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&i16_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_date32() { + let i32_values: Vec = vec![ + i32::MIN, + i16::MIN as i32, + i8::MIN as i32, + 0, + i8::MAX as i32, + i16::MAX as i32, + i32::MAX, + ]; + let date32_array: ArrayRef = Arc::new(Date32Array::from(i32_values)); + + let i64_expected = vec![ + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&date32_array, &DataType::Int64) + ); + } + + #[test] + fn test_cast_from_int8() { + let i8_values: Vec = vec![i8::MIN, 0, i8::MAX]; + let i8_array = Int8Array::from(i8_values); + + let f64_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&i8_array, &DataType::Float64) + ); + + let f32_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&i8_array, &DataType::Float32) + ); + + let f16_expected = vec!["-128.0", "0.0", "127.0"]; + 
assert_eq!( + f16_expected, + get_cast_values::(&i8_array, &DataType::Float16) + ); + + let i64_expected = vec!["-128", "0", "127"]; + assert_eq!( + i64_expected, + get_cast_values::(&i8_array, &DataType::Int64) + ); + + let i32_expected = vec!["-128", "0", "127"]; + assert_eq!( + i32_expected, + get_cast_values::(&i8_array, &DataType::Int32) + ); + + let i16_expected = vec!["-128", "0", "127"]; + assert_eq!( + i16_expected, + get_cast_values::(&i8_array, &DataType::Int16) + ); + + let i8_expected = vec!["-128", "0", "127"]; + assert_eq!( + i8_expected, + get_cast_values::(&i8_array, &DataType::Int8) + ); + + let u64_expected = vec!["null", "0", "127"]; + assert_eq!( + u64_expected, + get_cast_values::(&i8_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "0", "127"]; + assert_eq!( + u32_expected, + get_cast_values::(&i8_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "0", "127"]; + assert_eq!( + u16_expected, + get_cast_values::(&i8_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "0", "127"]; + assert_eq!( + u8_expected, + get_cast_values::(&i8_array, &DataType::UInt8) + ); + } + + /// Convert `array` into a vector of strings by casting to data type dt + fn get_cast_values(array: &dyn Array, dt: &DataType) -> Vec + where + T: ArrowPrimitiveType, + { + let c = cast(array, dt).unwrap(); + let a = c.as_primitive::(); + let mut v: Vec = vec![]; + for i in 0..array.len() { + if a.is_null(i) { + v.push("null".to_string()) + } else { + v.push(format!("{:?}", a.value(i))); + } + } + v + } + + #[test] + fn test_cast_utf8_dict() { + // FROM a dictionary with of Utf8 values + use DataType::*; + + let mut builder = StringDictionaryBuilder::::new(); + builder.append("one").unwrap(); + builder.append_null(); + builder.append("three").unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["one", "null", "three"]; + + // Test casting TO StringArray + let cast_type = Utf8; + let cast_array = 
cast(&array, &cast_type).expect("cast to UTF-8 failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Test casting TO Dictionary (with different index sizes) + + let cast_type = Dictionary(Box::new(Int16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_primitive() { + use DataType::*; + // test 
converting from an array that has indexes of a type + // that are out of bounds for a particular other kind of + // index. + + let mut builder = PrimitiveDictionaryBuilder::::new(); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + builder.append(i).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{res:?}"); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{actual_error}' in actual error '{expected_error}'" + ); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_utf8() { + use DataType::*; + // Same test as test_cast_dict_to_dict_bad_index_value but use + // string values (and encode the expected behavior here); + + let mut builder = StringDictionaryBuilder::::new(); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + let val = format!("val{i}"); + builder.append(&val).unwrap(); + } + let array = builder.finish(); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{res:?}"); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{actual_error}' in actual error '{expected_error}'" + ); + } + + #[test] + fn test_cast_primitive_dict() { + // FROM a dictionary with of INT32 values + use DataType::*; + + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(1).unwrap(); + builder.append_null(); + builder.append(3).unwrap(); + let 
array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Test casting TO PrimitiveArray, different dictionary type + let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Utf8); + + let cast_array = cast(&array, &Int64).expect("cast to int64 failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Int64); + } + + #[test] + fn test_cast_primitive_array_to_dict() { + use DataType::*; + + let mut builder = PrimitiveBuilder::::new(); + builder.append_value(1); + builder.append_null(); + builder.append_value(3); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Cast to a dictionary (same value type, Int32) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Cast to a dictionary (different value type, Int8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_time_array_to_dict() { + use DataType::*; + + let array = Arc::new(Date32Array::from(vec![Some(1000), None, Some(2000)])) as ArrayRef; + + let expected = vec!["1972-09-27", "null", "1975-06-24"]; + + let cast_type = Dictionary(Box::new(UInt8), Box::new(Date32)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_timestamp_array_to_dict() { + use DataType::*; + + let array = Arc::new( + TimestampSecondArray::from(vec![Some(1000), 
None, Some(2000)]).with_timezone_utc(), + ) as ArrayRef; + + let expected = vec!["1970-01-01T00:16:40", "null", "1970-01-01T00:33:20"]; + + let cast_type = Dictionary(Box::new(UInt8), Box::new(Timestamp(TimeUnit::Second, None))); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_string_array_to_dict() { + use DataType::*; + + let array = Arc::new(StringArray::from(vec![Some("one"), None, Some("three")])) as ArrayRef; + + let expected = vec!["one", "null", "three"]; + + // Cast to a dictionary (same value type, Utf8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_null_array_to_from_decimal_array() { + let data_type = DataType::Decimal128(12, 4); + let array = new_null_array(&DataType::Null, 4); + assert_eq!(array.data_type(), &DataType::Null); + let cast_array = cast(&array, &data_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &data_type); + for i in 0..4 { + assert!(cast_array.is_null(i)); + } + + let array = new_null_array(&data_type, 4); + assert_eq!(array.data_type(), &data_type); + let cast_array = cast(&array, &DataType::Null).expect("cast failed"); + assert_eq!(cast_array.data_type(), &DataType::Null); + assert_eq!(cast_array.len(), 4); + assert_eq!(cast_array.logical_nulls().unwrap().null_count(), 4); + } + + #[test] + fn test_cast_null_array_from_and_to_primitive_array() { + macro_rules! 
typed_test { + ($ARR_TYPE:ident, $DATATYPE:ident, $TYPE:tt) => {{ + { + let array = Arc::new(NullArray::new(6)) as ArrayRef; + let expected = $ARR_TYPE::from(vec![None; 6]); + let cast_type = DataType::$DATATYPE; + let cast_array = cast(&array, &cast_type).expect("cast failed"); + let cast_array = cast_array.as_primitive::<$TYPE>(); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(cast_array, &expected); + } + }}; + } + + typed_test!(Int16Array, Int16, Int16Type); + typed_test!(Int32Array, Int32, Int32Type); + typed_test!(Int64Array, Int64, Int64Type); + + typed_test!(UInt16Array, UInt16, UInt16Type); + typed_test!(UInt32Array, UInt32, UInt32Type); + typed_test!(UInt64Array, UInt64, UInt64Type); + + typed_test!(Float32Array, Float32, Float32Type); + typed_test!(Float64Array, Float64, Float64Type); + + typed_test!(Date32Array, Date32, Date32Type); + typed_test!(Date64Array, Date64, Date64Type); + } + + fn cast_from_null_to_other(data_type: &DataType) { + // Cast from null to data_type + { + let array = new_null_array(&DataType::Null, 4); + assert_eq!(array.data_type(), &DataType::Null); + let cast_array = cast(&array, data_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), data_type); + for i in 0..4 { + assert!(cast_array.is_null(i)); + } + } + } + + #[test] + fn test_cast_null_from_and_to_variable_sized() { + cast_from_null_to_other(&DataType::Utf8); + cast_from_null_to_other(&DataType::LargeUtf8); + cast_from_null_to_other(&DataType::Binary); + cast_from_null_to_other(&DataType::LargeBinary); + } + + #[test] + fn test_cast_null_from_and_to_nested_type() { + // Cast null from and to map + let data_type = DataType::Map( + Arc::new(Field::new_struct( + "entry", + vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ], + false, + )), + false, + ); + cast_from_null_to_other(&data_type); + + // Cast null from and to list + let data_type = DataType::List(Arc::new(Field::new("item", 
DataType::Int32, true))); + cast_from_null_to_other(&data_type); + let data_type = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + cast_from_null_to_other(&data_type); + let data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); + cast_from_null_to_other(&data_type); + + // Cast null from and to dictionary + let values = vec![None, None, None, None] as Vec>; + let array: DictionaryArray = values.into_iter().collect(); + let array = Arc::new(array) as ArrayRef; + let data_type = array.data_type().to_owned(); + cast_from_null_to_other(&data_type); + + // Cast null from and to struct + let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); + cast_from_null_to_other(&data_type); + } + + /// Print the `DictionaryArray` `array` as a vector of strings + fn array_to_strings(array: &ArrayRef) -> Vec { + let options = FormatOptions::new().with_null("null"); + let formatter = ArrayFormatter::try_new(array.as_ref(), &options).unwrap(); + (0..array.len()) + .map(|i| formatter.value(i).to_string()) + .collect() + } + + #[test] + fn test_cast_utf8_to_date32() { + use chrono::NaiveDate; + let from_ymd = chrono::NaiveDate::from_ymd_opt; + let since = chrono::NaiveDate::signed_duration_since; + + let a = StringArray::from(vec![ + "2000-01-01", // valid date with leading 0s + "2000-01-01T12:00:00", // valid datetime, will throw away the time part + "2000-2-2", // valid date without leading 0s + "2000-00-00", // invalid month and day + "2000", // just a year is invalid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + + // test valid inputs + let date_value = since( + NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), + from_ymd(1970, 1, 1).unwrap(), + ) + .num_days() as i32; + assert!(c.is_valid(0)); // "2000-01-01" + assert_eq!(date_value, c.value(0)); + + assert!(c.is_valid(1)); // "2000-01-01T12:00:00" + 
assert_eq!(date_value, c.value(1)); + + let date_value = since( + NaiveDate::from_ymd_opt(2000, 2, 2).unwrap(), + from_ymd(1970, 1, 1).unwrap(), + ) + .num_days() as i32; + assert!(c.is_valid(2)); // "2000-2-2" + assert_eq!(date_value, c.value(2)); + + // test invalid inputs + assert!(!c.is_valid(3)); // "2000-00-00" + assert!(!c.is_valid(4)); // "2000" + } + + #[test] + fn test_cast_utf8_to_date64() { + let a = StringArray::from(vec![ + "2000-01-01T12:00:00", // date + time valid + "2020-12-15T12:34:56", // date + time valid + "2020-2-2T12:34:56", // valid date time without leading 0s + "2000-00-00T12:00:00", // invalid month and day + "2000-01-01 12:00:00", // missing the 'T' + "2000-01-01", // just a date is invalid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + + // test valid inputs + assert!(c.is_valid(0)); // "2000-01-01T12:00:00" + assert_eq!(946728000000, c.value(0)); + assert!(c.is_valid(1)); // "2020-12-15T12:34:56" + assert_eq!(1608035696000, c.value(1)); + assert!(!c.is_valid(2)); // "2020-2-2T12:34:56" + + assert!(!c.is_valid(3)); // "2000-00-00T12:00:00" + assert!(c.is_valid(4)); // "2000-01-01 12:00:00" + assert_eq!(946728000000, c.value(4)); + assert!(c.is_valid(5)); // "2000-01-01" + assert_eq!(946684800000, c.value(5)); + } + + #[test] + fn test_can_cast_fsl_to_fsl() { + let from_array = Arc::new( + FixedSizeListArray::from_iter_primitive::( + [Some([Some(1.0), Some(2.0)]), None], + 2, + ), + ) as ArrayRef; + let to_array = Arc::new( + FixedSizeListArray::from_iter_primitive::( + [ + Some([Some(f16::from_f32(1.0)), Some(f16::from_f32(2.0))]), + None, + ], + 2, + ), + ) as ArrayRef; + + assert!(can_cast_types(from_array.data_type(), to_array.data_type())); + let actual = cast(&from_array, to_array.data_type()).unwrap(); + assert_eq!(actual.data_type(), to_array.data_type()); + + let invalid_target = + DataType::FixedSizeList(Arc::new(Field::new("item", 
DataType::Binary, true)), 2); + assert!(!can_cast_types(from_array.data_type(), &invalid_target)); + + let invalid_size = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 5); + assert!(!can_cast_types(from_array.data_type(), &invalid_size)); + } + + #[test] + fn test_can_cast_types_fixed_size_list_to_list() { + // DataType::List + let array1 = Arc::new(make_fixed_size_list_array()) as ArrayRef; + assert!(can_cast_types( + array1.data_type(), + &DataType::List(Arc::new(Field::new("", DataType::Int32, false))) + )); + + // DataType::LargeList + let array2 = Arc::new(make_fixed_size_list_array_for_large_list()) as ArrayRef; + assert!(can_cast_types( + array2.data_type(), + &DataType::LargeList(Arc::new(Field::new("", DataType::Int64, false))) + )); + } + + #[test] + fn test_cast_fixed_size_list_to_list() { + // Important cases: + // 1. With/without nulls + // 2. LargeList and List + // 3. With and without inner casts + + let cases = [ + // fixed_size_list => list + ( + Arc::new(FixedSizeListArray::from_iter_primitive::( + [[1, 1].map(Some), [2, 2].map(Some)].map(Some), + 2, + )) as ArrayRef, + Arc::new(ListArray::from_iter_primitive::([ + Some([Some(1), Some(1)]), + Some([Some(2), Some(2)]), + ])) as ArrayRef, + ), + // fixed_size_list => list (nullable) + ( + Arc::new(FixedSizeListArray::from_iter_primitive::( + [None, Some([Some(2), Some(2)])], + 2, + )) as ArrayRef, + Arc::new(ListArray::from_iter_primitive::([ + None, + Some([Some(2), Some(2)]), + ])) as ArrayRef, + ), + // fixed_size_list => large_list + ( + Arc::new(FixedSizeListArray::from_iter_primitive::( + [[1, 1].map(Some), [2, 2].map(Some)].map(Some), + 2, + )) as ArrayRef, + Arc::new(LargeListArray::from_iter_primitive::([ + Some([Some(1), Some(1)]), + Some([Some(2), Some(2)]), + ])) as ArrayRef, + ), + // fixed_size_list => large_list (nullable) + ( + Arc::new(FixedSizeListArray::from_iter_primitive::( + [None, Some([Some(2), Some(2)])], + 2, + )) as ArrayRef, + 
Arc::new(LargeListArray::from_iter_primitive::([ + None, + Some([Some(2), Some(2)]), + ])) as ArrayRef, + ), + ]; + + for (array, expected) in cases { + let array = Arc::new(array) as ArrayRef; + + assert!( + can_cast_types(array.data_type(), expected.data_type()), + "can_cast_types claims we cannot cast {:?} to {:?}", + array.data_type(), + expected.data_type() + ); + + let list_array = cast(&array, expected.data_type()) + .unwrap_or_else(|_| panic!("Failed to cast {:?} to {:?}", array, expected)); + assert_eq!( + list_array.as_ref(), + &expected, + "Incorrect result from casting {:?} to {:?}", + array, + expected + ); + } + } + + #[test] + fn test_cast_utf8_to_list() { + // DataType::List + let array = Arc::new(StringArray::from(vec!["5"])) as ArrayRef; + let field = Arc::new(Field::new("", DataType::Int32, false)); + let list_array = cast(&array, &DataType::List(field.clone())).unwrap(); + let actual = list_array.as_list_opt::().unwrap(); + let expect = ListArray::from_iter_primitive::([Some([Some(5)])]); + assert_eq!(&expect.value(0), &actual.value(0)); + + // DataType::LargeList + let list_array = cast(&array, &DataType::LargeList(field.clone())).unwrap(); + let actual = list_array.as_list_opt::().unwrap(); + let expect = LargeListArray::from_iter_primitive::([Some([Some(5)])]); + assert_eq!(&expect.value(0), &actual.value(0)); + + // DataType::FixedSizeList + let list_array = cast(&array, &DataType::FixedSizeList(field.clone(), 1)).unwrap(); + let actual = list_array.as_fixed_size_list_opt().unwrap(); + let expect = + FixedSizeListArray::from_iter_primitive::([Some([Some(5)])], 1); + assert_eq!(&expect.value(0), &actual.value(0)); + } + + #[test] + fn test_cast_single_element_fixed_size_list() { + // FixedSizeList[1] => T + let from_array = Arc::new(FixedSizeListArray::from_iter_primitive::( + [(Some([Some(5)]))], + 1, + )) as ArrayRef; + let casted_array = cast(&from_array, &DataType::Int32).unwrap(); + let actual: &Int32Array = casted_array.as_primitive(); 
+ let expected = Int32Array::from(vec![Some(5)]); + assert_eq!(&expected, actual); + + // FixedSizeList[1] => FixedSizeList[1] + let from_array = Arc::new(FixedSizeListArray::from_iter_primitive::( + [(Some([Some(5)]))], + 1, + )) as ArrayRef; + let to_field = Arc::new(Field::new("dummy", DataType::Float32, false)); + let actual = cast(&from_array, &DataType::FixedSizeList(to_field.clone(), 1)).unwrap(); + let expected = Arc::new(FixedSizeListArray::new( + to_field.clone(), + 1, + Arc::new(Float32Array::from(vec![Some(5.0)])) as ArrayRef, + None, + )) as ArrayRef; + assert_eq!(*expected, *actual); + + // FixedSizeList[1] => FixedSizeList[1]>[1] + let from_array = Arc::new(FixedSizeListArray::from_iter_primitive::( + [(Some([Some(5)]))], + 1, + )) as ArrayRef; + let to_field_inner = Arc::new(Field::new("item", DataType::Float32, false)); + let to_field = Arc::new(Field::new( + "dummy", + DataType::FixedSizeList(to_field_inner.clone(), 1), + false, + )); + let actual = cast(&from_array, &DataType::FixedSizeList(to_field.clone(), 1)).unwrap(); + let expected = Arc::new(FixedSizeListArray::new( + to_field.clone(), + 1, + Arc::new(FixedSizeListArray::new( + to_field_inner.clone(), + 1, + Arc::new(Float32Array::from(vec![Some(5.0)])) as ArrayRef, + None, + )) as ArrayRef, + None, + )) as ArrayRef; + assert_eq!(*expected, *actual); + + // T => FixedSizeList[1] (non-nullable) + let field = Arc::new(Field::new("dummy", DataType::Float32, false)); + let from_array = Arc::new(Int8Array::from(vec![Some(5)])) as ArrayRef; + let casted_array = cast(&from_array, &DataType::FixedSizeList(field.clone(), 1)).unwrap(); + let actual = casted_array.as_fixed_size_list(); + let expected = Arc::new(FixedSizeListArray::new( + field.clone(), + 1, + Arc::new(Float32Array::from(vec![Some(5.0)])) as ArrayRef, + None, + )) as ArrayRef; + assert_eq!(expected.as_ref(), actual); + + // T => FixedSizeList[1] (nullable) + let field = Arc::new(Field::new("nullable", DataType::Float32, true)); + let 
from_array = Arc::new(Int8Array::from(vec![None])) as ArrayRef; + let casted_array = cast(&from_array, &DataType::FixedSizeList(field.clone(), 1)).unwrap(); + let actual = casted_array.as_fixed_size_list(); + let expected = Arc::new(FixedSizeListArray::new( + field.clone(), + 1, + Arc::new(Float32Array::from(vec![None])) as ArrayRef, + None, + )) as ArrayRef; + assert_eq!(expected.as_ref(), actual); + } + + #[test] + fn test_cast_list_containers() { + // large-list to list + let array = Arc::new(make_large_list_array()) as ArrayRef; + let list_array = cast( + &array, + &DataType::List(Arc::new(Field::new("", DataType::Int32, false))), + ) + .unwrap(); + let actual = list_array.as_any().downcast_ref::().unwrap(); + let expected = array.as_any().downcast_ref::().unwrap(); + + assert_eq!(&expected.value(0), &actual.value(0)); + assert_eq!(&expected.value(1), &actual.value(1)); + assert_eq!(&expected.value(2), &actual.value(2)); + + // list to large-list + let array = Arc::new(make_list_array()) as ArrayRef; + let large_list_array = cast( + &array, + &DataType::LargeList(Arc::new(Field::new("", DataType::Int32, false))), + ) + .unwrap(); + let actual = large_list_array + .as_any() + .downcast_ref::() + .unwrap(); + let expected = array.as_any().downcast_ref::().unwrap(); + + assert_eq!(&expected.value(0), &actual.value(0)); + assert_eq!(&expected.value(1), &actual.value(1)); + assert_eq!(&expected.value(2), &actual.value(2)); + } + + #[test] + fn test_cast_list_to_fsl() { + // There four noteworthy cases we should handle: + // 1. No nulls + // 2. Nulls that are always empty + // 3. Nulls that have varying lengths + // 4. 
Nulls that are correctly sized (same as target list size) + + // Non-null case + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let values = vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5), Some(6)]), + ]; + let array = Arc::new(ListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + values, 3, + )) as ArrayRef; + let actual = cast(array.as_ref(), &DataType::FixedSizeList(field.clone(), 3)).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + + // Null cases + // Array is [[1, 2, 3], null, [4, 5, 6], null] + let cases = [ + ( + // Zero-length nulls + vec![1, 2, 3, 4, 5, 6], + vec![3, 0, 3, 0], + ), + ( + // Varying-length nulls + vec![1, 2, 3, 0, 0, 4, 5, 6, 0], + vec![3, 2, 3, 1], + ), + ( + // Correctly-sized nulls + vec![1, 2, 3, 0, 0, 0, 4, 5, 6, 0, 0, 0], + vec![3, 3, 3, 3], + ), + ( + // Mixed nulls + vec![1, 2, 3, 4, 5, 6, 0, 0, 0], + vec![3, 0, 3, 3], + ), + ]; + let null_buffer = NullBuffer::from(vec![true, false, true, false]); + + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5), Some(6)]), + None, + ], + 3, + )) as ArrayRef; + + for (values, lengths) in cases.iter() { + let array = Arc::new(ListArray::new( + field.clone(), + OffsetBuffer::from_lengths(lengths.clone()), + Arc::new(Int32Array::from(values.clone())), + Some(null_buffer.clone()), + )) as ArrayRef; + let actual = cast(array.as_ref(), &DataType::FixedSizeList(field.clone(), 3)).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + } + + #[test] + fn test_cast_list_to_fsl_safety() { + let values = vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6), Some(7), Some(8), Some(9)]), + Some(vec![Some(3), Some(4), Some(5)]), + ]; + let array = Arc::new(ListArray::from_iter_primitive::( + 
values.clone(), + )) as ArrayRef; + + let res = cast_with_options( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + &CastOptions { + safe: false, + ..Default::default() + }, + ); + assert!(res.is_err()); + assert!(format!("{:?}", res) + .contains("Cannot cast to FixedSizeList(3): value at index 1 has length 2")); + + // When safe=true (default), the cast will fill nulls for lists that are + // too short and truncate lists that are too long. + let res = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + ) + .unwrap(); + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, // Too short -> replaced with null + None, // Too long -> replaced with null + Some(vec![Some(3), Some(4), Some(5)]), + ], + 3, + )) as ArrayRef; + assert_eq!(expected.as_ref(), res.as_ref()); + + // The safe option is false and the source array contains a null list. 
+ // issue: https://github.com/apache/arrow-rs/issues/5642 + let array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + ])) as ArrayRef; + let res = cast_with_options( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + &CastOptions { + safe: false, + ..Default::default() + }, + ) + .unwrap(); + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(1), Some(2), Some(3)]), None], + 3, + )) as ArrayRef; + assert_eq!(expected.as_ref(), res.as_ref()); + } + + #[test] + fn test_cast_large_list_to_fsl() { + let values = vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3), Some(4)])]; + let array = Arc::new(LargeListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + values, 2, + )) as ArrayRef; + let actual = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 2), + ) + .unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + + #[test] + fn test_cast_list_to_fsl_subcast() { + let array = Arc::new(LargeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(i32::MAX)]), + ], + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(i32::MAX as i64)]), + ], + 2, + )) as ArrayRef; + let actual = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 2), + ) + .unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + + let res = cast_with_options( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int16, true)), 2), + &CastOptions { + safe: false, + ..Default::default() + }, + ); + assert!(res.is_err()); + assert!(format!("{:?}", res).contains("Can't cast value 2147483647 
to type Int16")); + } + + #[test] + fn test_cast_list_to_fsl_empty() { + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let array = new_empty_array(&DataType::List(field.clone())); + + let target_type = DataType::FixedSizeList(field.clone(), 3); + let expected = new_empty_array(&target_type); + + let actual = cast(array.as_ref(), &target_type).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + + fn make_list_array() -> ListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + ListArray::from(list_data) + } + + fn make_large_list_array() -> LargeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = + DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + LargeListArray::from(list_data) + } + + fn make_fixed_size_list_array() -> FixedSizeListArray { + // Construct a value array + let value_data = 
ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); + let list_data = ArrayData::builder(list_data_type) + .len(2) + .add_child_data(value_data) + .build() + .unwrap(); + FixedSizeListArray::from(list_data) + } + + fn make_fixed_size_list_array_for_large_list() -> FixedSizeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int64) + .len(8) + .add_buffer(Buffer::from_slice_ref([0i64, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 4); + let list_data = ArrayData::builder(list_data_type) + .len(2) + .add_child_data(value_data) + .build() + .unwrap(); + FixedSizeListArray::from(list_data) + } + + #[test] + fn test_cast_map_dont_allow_change_of_order() { + let string_builder = StringBuilder::new(); + let value_builder = StringBuilder::new(); + let mut builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + string_builder, + value_builder, + ); + + builder.keys().append_value("0"); + builder.values().append_value("test_val_1"); + builder.append(true).unwrap(); + builder.keys().append_value("1"); + builder.values().append_value("test_val_2"); + builder.append(true).unwrap(); + + // map builder returns unsorted map by default + let array = builder.finish(); + + let new_ordered = true; + let new_type = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ] + .into(), + ), + false, + )), + new_ordered, + ); + + let new_array_result = cast(&array, &new_type.clone()); + assert!(!can_cast_types(array.data_type(), &new_type)); + assert!( + 
matches!(new_array_result, Err(ArrowError::CastError(t)) if t == r#"Casting from Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false) to Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, true) not supported"#) + ); + } + + #[test] + fn test_cast_map_dont_allow_when_container_cant_cast() { + let string_builder = StringBuilder::new(); + let value_builder = IntervalDayTimeArray::builder(2); + let mut builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + string_builder, + value_builder, + ); + + builder.keys().append_value("0"); + builder.values().append_value(IntervalDayTime::new(1, 1)); + builder.append(true).unwrap(); + builder.keys().append_value("1"); + builder.values().append_value(IntervalDayTime::new(2, 2)); + builder.append(true).unwrap(); + + // map builder returns unsorted map by default + let array = builder.finish(); + + let new_ordered = true; + let new_type = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Duration(TimeUnit::Second), false), + ] + .into(), + ), + false, + )), + new_ordered, + ); + + let new_array_result = cast(&array, &new_type.clone()); + assert!(!can_cast_types(array.data_type(), &new_type)); + assert!( + matches!(new_array_result, Err(ArrowError::CastError(t)) if t == 
r#"Casting from Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Interval(DayTime), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false) to Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Duration(Second), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, true) not supported"#) + ); + } + + #[test] + fn test_cast_map_field_names() { + let string_builder = StringBuilder::new(); + let value_builder = StringBuilder::new(); + let mut builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + string_builder, + value_builder, + ); + + builder.keys().append_value("0"); + builder.values().append_value("test_val_1"); + builder.append(true).unwrap(); + builder.keys().append_value("1"); + builder.values().append_value("test_val_2"); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + + let array = builder.finish(); + + let new_type = DataType::Map( + Arc::new(Field::new( + "entries_new", + DataType::Struct( + vec![ + Field::new("key_new", DataType::Utf8, false), + Field::new("value_values", DataType::Utf8, false), + ] + .into(), + ), + false, + )), + false, + ); + + assert_ne!(new_type, array.data_type().clone()); + + let new_array = cast(&array, &new_type.clone()).unwrap(); + assert_eq!(new_type, new_array.data_type().clone()); + let map_array = new_array.as_map(); + + assert_ne!(new_type, array.data_type().clone()); + assert_eq!(new_type, map_array.data_type().clone()); + + let key_string = map_array + .keys() + 
.as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&key_string, &vec!["0", "1"]); + + let values_string_array = cast(map_array.values(), &DataType::Utf8).unwrap(); + let values_string = values_string_array + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&values_string, &vec!["test_val_1", "test_val_2"]); + + assert_eq!( + map_array.nulls(), + Some(&NullBuffer::from(vec![true, true, false])) + ); + } + + #[test] + fn test_cast_map_contained_values() { + let string_builder = StringBuilder::new(); + let value_builder = Int8Builder::new(); + let mut builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + string_builder, + value_builder, + ); + + builder.keys().append_value("0"); + builder.values().append_value(44); + builder.append(true).unwrap(); + builder.keys().append_value("1"); + builder.values().append_value(22); + builder.append(true).unwrap(); + + let array = builder.finish(); + + let new_type = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ] + .into(), + ), + false, + )), + false, + ); + + let new_array = cast(&array, &new_type.clone()).unwrap(); + assert_eq!(new_type, new_array.data_type().clone()); + let map_array = new_array.as_map(); + + assert_ne!(new_type, array.data_type().clone()); + assert_eq!(new_type, map_array.data_type().clone()); + + let key_string = map_array + .keys() + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&key_string, &vec!["0", "1"]); + + let values_string_array = cast(map_array.values(), &DataType::Utf8).unwrap(); + let values_string = values_string_array + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + 
assert_eq!(&values_string, &vec!["44", "22"]); + } + + #[test] + fn test_utf8_cast_offsets() { + // test if offset of the array is taken into account during cast + let str_array = StringArray::from(vec!["a", "b", "c"]); + let str_array = str_array.slice(1, 2); + + let out = cast(&str_array, &DataType::LargeUtf8).unwrap(); + + let large_str_array = out.as_any().downcast_ref::().unwrap(); + let strs = large_str_array.into_iter().flatten().collect::>(); + assert_eq!(strs, &["b", "c"]) + } + + #[test] + fn test_list_cast_offsets() { + // test if offset of the array is taken into account during cast + let array1 = make_list_array().slice(1, 2); + let array2 = Arc::new(make_list_array()) as ArrayRef; + + let dt = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let out1 = cast(&array1, &dt).unwrap(); + let out2 = cast(&array2, &dt).unwrap(); + + assert_eq!(&out1, &out2.slice(1, 2)) + } + + #[test] + fn test_list_to_string() { + let str_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g", "h"]); + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + let value_data = str_array.into_data(); + + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + let array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + let out = cast(&array, &DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[a, b, c]", "[d, e, f]", "[g, h]"]); + + let out = cast(&array, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[a, b, c]", "[d, e, f]", "[g, h]"]); + + let array = Arc::new(make_list_array()) as ArrayRef; + let out = cast(&array, 
&DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]); + + let array = Arc::new(make_large_list_array()) as ArrayRef; + let out = cast(&array, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]); + } + + #[test] + fn test_cast_f64_to_decimal128() { + // to reproduce https://github.com/apache/arrow-rs/issues/2997 + + let decimal_type = DataType::Decimal128(18, 2); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0650000000), + Some(0.0649999999), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(7_i128), // round up + Some(7_i128), // round up + Some(7_i128), // round up + Some(6_i128), // round down + ] + ); + + let decimal_type = DataType::Decimal128(18, 3); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0650000000), + Some(0.0649999999), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(70_i128), // round up + Some(66_i128), // round up + Some(65_i128), // round down + Some(65_i128), // round up + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal128_overflow() { + let array = Int64Array::from(vec![i64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: false, + 
format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + } + + #[test] + fn test_cast_numeric_to_decimal256_overflow() { + let array = Int64Array::from(vec![i64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 76), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 76), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + } + + #[test] + fn test_cast_floating_point_to_decimal128_precision_overflow() { + let array = Float64Array::from(vec![1.1]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(2, 2), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(2, 2), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Invalid argument error: 110 is too large to store in a Decimal128 of precision 2. 
Max is 99"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_floating_point_to_decimal256_precision_overflow() { + let array = Float64Array::from(vec![1.1]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(2, 2), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(2, 2), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Invalid argument error: 110 is too large to store in a Decimal256 of precision 2. Max is 99"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_floating_point_to_decimal128_overflow() { + let array = Float64Array::from(vec![f64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Cast error: Cannot cast to Decimal128(38, 30)"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_floating_point_to_decimal256_overflow() { + let array = Float64Array::from(vec![f64::MAX]); + let array = Arc::new(array) as 
ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 50), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 50), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Cast error: Cannot cast to Decimal256(76, 50)"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_decimal128_to_decimal128_negative_scale() { + let input_type = DataType::Decimal128(20, 0); + let output_type = DataType::Decimal128(20, -1); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123450), Some(2123455), Some(3123456), None]; + let input_decimal_array = create_decimal_array(array, 20, 0).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(312346_i128), + None + ] + ); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1123450", decimal_arr.value_as_string(0)); + assert_eq!("2123460", decimal_arr.value_as_string(1)); + assert_eq!("3123460", decimal_arr.value_as_string(2)); + } + + #[test] + fn test_cast_numeric_to_decimal128_negative() { + let decimal_type = DataType::Decimal128(38, -1); + let array = Arc::new(Int32Array::from(vec![ + Some(1123456), + Some(2123456), + Some(3123456), + ])) as ArrayRef; + + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1123450", decimal_arr.value_as_string(0)); + assert_eq!("2123450", 
decimal_arr.value_as_string(1)); + assert_eq!("3123450", decimal_arr.value_as_string(2)); + + let array = Arc::new(Float32Array::from(vec![ + Some(1123.456), + Some(2123.456), + Some(3123.456), + ])) as ArrayRef; + + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1120", decimal_arr.value_as_string(0)); + assert_eq!("2120", decimal_arr.value_as_string(1)); + assert_eq!("3120", decimal_arr.value_as_string(2)); + } + + #[test] + fn test_cast_decimal128_to_decimal128_negative() { + let input_type = DataType::Decimal128(10, -1); + let output_type = DataType::Decimal128(10, -2); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(123)]; + let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(12_i128),]); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1200", decimal_arr.value_as_string(0)); + + let array = vec![Some(125)]; + let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(13_i128),]); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1300", decimal_arr.value_as_string(0)); + } + + #[test] + fn test_cast_decimal128_to_decimal256_negative() { + let input_type = DataType::Decimal128(10, 3); + let output_type = DataType::Decimal256(10, 5); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(i128::MAX), Some(i128::MIN)]; + let input_decimal_array = create_decimal_array(array, 10, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + + let hundred = i256::from_i128(100); + 
generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(i128::MAX).mul_wrapping(hundred)), + Some(i256::from_i128(i128::MIN).mul_wrapping(hundred)) + ] + ); + } + + #[test] + fn test_parse_string_to_decimal() { + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("123.45", 2).unwrap(), + 38, + 2, + ), + "123.45" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("12345", 2).unwrap(), + 38, + 2, + ), + "12345.00" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("0.12345", 2).unwrap(), + 38, + 2, + ), + "0.12" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".12345", 2).unwrap(), + 38, + 2, + ), + "0.12" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".1265", 2).unwrap(), + 38, + 2, + ), + "0.13" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".1265", 2).unwrap(), + 38, + 2, + ), + "0.13" + ); + + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("123.45", 3).unwrap(), + 38, + 3, + ), + "123.450" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("12345", 3).unwrap(), + 38, + 3, + ), + "12345.000" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("0.12345", 3).unwrap(), + 38, + 3, + ), + "0.123" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::(".12345", 3).unwrap(), + 38, + 3, + ), + "0.123" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::(".1265", 3).unwrap(), + 38, + 3, + ), + "0.127" + ); + } + + fn test_cast_string_to_decimal(array: ArrayRef) { + // Decimal128 + let output_type = DataType::Decimal128(38, 2); + assert!(can_cast_types(array.data_type(), &output_type)); + + let casted_array = cast(&array, 
&output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("123.45", decimal_arr.value_as_string(0)); + assert_eq!("1.23", decimal_arr.value_as_string(1)); + assert_eq!("0.12", decimal_arr.value_as_string(2)); + assert_eq!("0.13", decimal_arr.value_as_string(3)); + assert_eq!("1.26", decimal_arr.value_as_string(4)); + assert_eq!("12345.00", decimal_arr.value_as_string(5)); + assert_eq!("12345.00", decimal_arr.value_as_string(6)); + assert_eq!("0.12", decimal_arr.value_as_string(7)); + assert_eq!("12.23", decimal_arr.value_as_string(8)); + assert!(decimal_arr.is_null(9)); + assert_eq!("0.00", decimal_arr.value_as_string(10)); + assert_eq!("0.00", decimal_arr.value_as_string(11)); + assert!(decimal_arr.is_null(12)); + assert_eq!("-1.23", decimal_arr.value_as_string(13)); + assert_eq!("-1.24", decimal_arr.value_as_string(14)); + assert_eq!("0.00", decimal_arr.value_as_string(15)); + assert_eq!("-123.00", decimal_arr.value_as_string(16)); + assert_eq!("-123.23", decimal_arr.value_as_string(17)); + assert_eq!("-0.12", decimal_arr.value_as_string(18)); + assert_eq!("1.23", decimal_arr.value_as_string(19)); + assert_eq!("1.24", decimal_arr.value_as_string(20)); + assert_eq!("0.00", decimal_arr.value_as_string(21)); + assert_eq!("123.00", decimal_arr.value_as_string(22)); + assert_eq!("123.23", decimal_arr.value_as_string(23)); + assert_eq!("0.12", decimal_arr.value_as_string(24)); + assert!(decimal_arr.is_null(25)); + assert!(decimal_arr.is_null(26)); + assert!(decimal_arr.is_null(27)); + + // Decimal256 + let output_type = DataType::Decimal256(76, 3); + assert!(can_cast_types(array.data_type(), &output_type)); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("123.450", decimal_arr.value_as_string(0)); + assert_eq!("1.235", decimal_arr.value_as_string(1)); + assert_eq!("0.123", decimal_arr.value_as_string(2)); + assert_eq!("0.127", decimal_arr.value_as_string(3)); + 
assert_eq!("1.263", decimal_arr.value_as_string(4)); + assert_eq!("12345.000", decimal_arr.value_as_string(5)); + assert_eq!("12345.000", decimal_arr.value_as_string(6)); + assert_eq!("0.123", decimal_arr.value_as_string(7)); + assert_eq!("12.234", decimal_arr.value_as_string(8)); + assert!(decimal_arr.is_null(9)); + assert_eq!("0.000", decimal_arr.value_as_string(10)); + assert_eq!("0.000", decimal_arr.value_as_string(11)); + assert!(decimal_arr.is_null(12)); + assert_eq!("-1.235", decimal_arr.value_as_string(13)); + assert_eq!("-1.236", decimal_arr.value_as_string(14)); + assert_eq!("0.000", decimal_arr.value_as_string(15)); + assert_eq!("-123.000", decimal_arr.value_as_string(16)); + assert_eq!("-123.234", decimal_arr.value_as_string(17)); + assert_eq!("-0.123", decimal_arr.value_as_string(18)); + assert_eq!("1.235", decimal_arr.value_as_string(19)); + assert_eq!("1.236", decimal_arr.value_as_string(20)); + assert_eq!("0.000", decimal_arr.value_as_string(21)); + assert_eq!("123.000", decimal_arr.value_as_string(22)); + assert_eq!("123.234", decimal_arr.value_as_string(23)); + assert_eq!("0.123", decimal_arr.value_as_string(24)); + assert!(decimal_arr.is_null(25)); + assert!(decimal_arr.is_null(26)); + assert!(decimal_arr.is_null(27)); + } + + #[test] + fn test_cast_utf8_to_decimal() { + let str_array = StringArray::from(vec![ + Some("123.45"), + Some("1.2345"), + Some("0.12345"), + Some("0.1267"), + Some("1.263"), + Some("12345.0"), + Some("12345"), + Some("000.123"), + Some("12.234000"), + None, + Some(""), + Some(" "), + None, + Some("-1.23499999"), + Some("-1.23599999"), + Some("-0.00001"), + Some("-123"), + Some("-123.234000"), + Some("-000.123"), + Some("+1.23499999"), + Some("+1.23599999"), + Some("+0.00001"), + Some("+123"), + Some("+123.234000"), + Some("+000.123"), + Some("1.-23499999"), + Some("-1.-23499999"), + Some("--1.23499999"), + ]); + let array = Arc::new(str_array) as ArrayRef; + + test_cast_string_to_decimal(array); + } + + #[test] + fn 
test_cast_large_utf8_to_decimal() { + let str_array = LargeStringArray::from(vec![ + Some("123.45"), + Some("1.2345"), + Some("0.12345"), + Some("0.1267"), + Some("1.263"), + Some("12345.0"), + Some("12345"), + Some("000.123"), + Some("12.234000"), + None, + Some(""), + Some(" "), + None, + Some("-1.23499999"), + Some("-1.23599999"), + Some("-0.00001"), + Some("-123"), + Some("-123.234000"), + Some("-000.123"), + Some("+1.23499999"), + Some("+1.23599999"), + Some("+0.00001"), + Some("+123"), + Some("+123.234000"), + Some("+000.123"), + Some("1.-23499999"), + Some("-1.-23499999"), + Some("--1.23499999"), + ]); + let array = Arc::new(str_array) as ArrayRef; + + test_cast_string_to_decimal(array); + } + + #[test] + fn test_cast_invalid_utf8_to_decimal() { + let str_array = StringArray::from(vec!["4.4.5", ". 0.123"]); + let array = Arc::new(str_array) as ArrayRef; + + // Safe cast + let output_type = DataType::Decimal128(38, 2); + let casted_array = cast(&array, &output_type).unwrap(); + assert!(casted_array.is_null(0)); + assert!(casted_array.is_null(1)); + + let output_type = DataType::Decimal256(76, 2); + let casted_array = cast(&array, &output_type).unwrap(); + assert!(casted_array.is_null(0)); + assert!(casted_array.is_null(1)); + + // Non-safe cast + let output_type = DataType::Decimal128(38, 2); + let str_array = StringArray::from(vec!["4.4.5"]); + let array = Arc::new(str_array) as ArrayRef; + let option = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); + assert!(casted_err + .to_string() + .contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type")); + + let str_array = StringArray::from(vec![". 0.123"]); + let array = Arc::new(str_array) as ArrayRef; + let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); + assert!(casted_err + .to_string() + .contains("Cannot cast string '. 
0.123' to value of Decimal128(38, 10) type")); + } + + fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { + let output_type = DataType::Decimal128(38, 2); + let casted_array = cast(&overflow_array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert!(decimal_arr.is_null(0)); + assert!(decimal_arr.is_null(1)); + assert!(decimal_arr.is_null(2)); + assert_eq!( + "999999999999999999999999999999999999.99", + decimal_arr.value_as_string(3) + ); + assert_eq!( + "100000000000000000000000000000000000.00", + decimal_arr.value_as_string(4) + ); + } + + #[test] + fn test_cast_string_to_decimal128_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. 
Max is 9999999999", err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_utf8_to_decimal128_overflow() { + let overflow_str_array = StringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal128_overflow(overflow_array); + } + + #[test] + fn test_cast_large_utf8_to_decimal128_overflow() { + let overflow_str_array = LargeStringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal128_overflow(overflow_array); + } + + fn test_cast_string_to_decimal256_overflow(overflow_array: ArrayRef) { + let output_type = DataType::Decimal256(76, 2); + let casted_array = cast(&overflow_array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!( + "170141183460469231731687303715884105727.00", + decimal_arr.value_as_string(0) + ); + assert_eq!( + "-170141183460469231731687303715884105728.00", + decimal_arr.value_as_string(1) + ); + assert_eq!( + "99999999999999999999999999999999999999.00", + decimal_arr.value_as_string(2) + ); + assert_eq!( + "999999999999999999999999999999999999.99", + decimal_arr.value_as_string(3) + ); + assert_eq!( + "100000000000000000000000000000000000.00", + decimal_arr.value_as_string(4) + ); + assert!(decimal_arr.is_null(5)); + assert!(decimal_arr.is_null(6)); + } + + #[test] + fn test_cast_string_to_decimal256_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let 
casted_array = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_utf8_to_decimal256_overflow() { + let overflow_str_array = StringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + i256::MAX.to_string(), + i256::MIN.to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal256_overflow(overflow_array); + } + + #[test] + fn test_cast_large_utf8_to_decimal256_overflow() { + let overflow_str_array = LargeStringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + i256::MAX.to_string(), + i256::MIN.to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal256_overflow(overflow_array); + } + + #[test] + fn test_cast_outside_supported_range_for_nanoseconds() { + const EXPECTED_ERROR_MESSAGE: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; + + let array = StringArray::from(vec![Some("1650-01-01 01:01:01.000001")]); + + let cast_options = CastOptions { + safe: false, + format_options: 
FormatOptions::default(), + }; + + let result = cast_string_to_timestamp::( + &array, + &None::>, + &cast_options, + ); + + let err = result.unwrap_err(); + assert_eq!( + err.to_string(), + format!( + "Cast error: Overflow converting {} to Nanosecond. {}", + array.value(0), + EXPECTED_ERROR_MESSAGE + ) + ); + } + + #[test] + fn test_cast_date32_to_timestamp() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1609459200, c.value(0)); + assert_eq!(1640995200, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_timestamp_ms() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Millisecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1609459200000, c.value(0)); + assert_eq!(1640995200000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_timestamp_us() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1609459200000000, c.value(0)); + assert_eq!(1640995200000000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_timestamp_ns() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1609459200000000000, c.value(0)); + assert_eq!(1640995200000000000, c.value(1)); + 
assert!(c.is_null(2)); + } + + #[test] + fn test_timezone_cast() { + let a = StringArray::from(vec![ + "2000-01-01T12:00:00", // date + time valid + "2020-12-15T12:34:56", // date + time valid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); + let v = b.as_primitive::(); + + assert_eq!(v.value(0), 946728000000000000); + assert_eq!(v.value(1), 1608035696000000000); + + let b = cast( + &b, + &DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + ) + .unwrap(); + let v = b.as_primitive::(); + + assert_eq!(v.value(0), 946728000000000000); + assert_eq!(v.value(1), 1608035696000000000); + + let b = cast( + &b, + &DataType::Timestamp(TimeUnit::Millisecond, Some("+02:00".into())), + ) + .unwrap(); + let v = b.as_primitive::(); + + assert_eq!(v.value(0), 946728000000); + assert_eq!(v.value(1), 1608035696000); + } + + #[test] + fn test_cast_utf8_to_timestamp() { + fn test_tz(tz: Arc) { + let valid = StringArray::from(vec![ + "2023-01-01 04:05:06.789000-08:00", + "2023-01-01 04:05:06.789000-07:00", + "2023-01-01 04:05:06.789 -0800", + "2023-01-01 04:05:06.789 -08:00", + "2023-01-01 040506 +0730", + "2023-01-01 040506 +07:30", + "2023-01-01 04:05:06.789", + "2023-01-01 04:05:06", + "2023-01-01", + ]); + + let array = Arc::new(valid) as ArrayRef; + let b = cast_with_options( + &array, + &DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.clone())), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap(); + + let tz = tz.as_ref().parse().unwrap(); + + let as_tz = + |v: i64| as_datetime_with_timezone::(v, tz).unwrap(); + + let as_utc = |v: &i64| as_tz(*v).naive_utc().to_string(); + let as_local = |v: &i64| as_tz(*v).naive_local().to_string(); + + let values = b.as_primitive::().values(); + let utc_results: Vec<_> = values.iter().map(as_utc).collect(); + let local_results: Vec<_> = values.iter().map(as_local).collect(); + + // Absolute timestamps 
should be parsed preserving the same UTC instant + assert_eq!( + &utc_results[..6], + &[ + "2023-01-01 12:05:06.789".to_string(), + "2023-01-01 11:05:06.789".to_string(), + "2023-01-01 12:05:06.789".to_string(), + "2023-01-01 12:05:06.789".to_string(), + "2022-12-31 20:35:06".to_string(), + "2022-12-31 20:35:06".to_string(), + ] + ); + // Non-absolute timestamps should be parsed preserving the same local instant + assert_eq!( + &local_results[6..], + &[ + "2023-01-01 04:05:06.789".to_string(), + "2023-01-01 04:05:06".to_string(), + "2023-01-01 00:00:00".to_string() + ] + ) + } + + test_tz("+00:00".into()); + test_tz("+02:00".into()); + } + + #[test] + fn test_cast_invalid_utf8() { + let v1: &[u8] = b"\xFF invalid"; + let v2: &[u8] = b"\x00 Foo"; + let s = BinaryArray::from(vec![v1, v2]); + let options = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + let array = cast_with_options(&s, &DataType::Utf8, &options).unwrap(); + let a = array.as_string::(); + a.to_data().validate_full().unwrap(); + + assert_eq!(a.null_count(), 1); + assert_eq!(a.len(), 2); + assert!(a.is_null(0)); + assert_eq!(a.value(0), ""); + assert_eq!(a.value(1), "\x00 Foo"); + } + + #[test] + fn test_cast_utf8_to_timestamptz() { + let valid = StringArray::from(vec!["2023-01-01"]); + + let array = Arc::new(valid) as ArrayRef; + let b = cast( + &array, + &DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + ) + .unwrap(); + + let expect = DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())); + + assert_eq!(b.data_type(), &expect); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1672531200000000000, c.value(0)); + } + + #[test] + fn test_cast_decimal_to_utf8() { + fn test_decimal_to_string( + output_type: DataType, + array: PrimitiveArray, + ) { + let b = cast(&array, &output_type).unwrap(); + + assert_eq!(b.data_type(), &output_type); + let c = b.as_string::(); + + assert_eq!("1123.454", c.value(0)); + 
assert_eq!("2123.456", c.value(1)); + assert_eq!("-3123.453", c.value(2)); + assert_eq!("-3123.456", c.value(3)); + assert_eq!("0.000", c.value(4)); + assert_eq!("0.123", c.value(5)); + assert_eq!("1234.567", c.value(6)); + assert_eq!("-1234.567", c.value(7)); + assert!(c.is_null(8)); + } + let array128: Vec> = vec![ + Some(1123454), + Some(2123456), + Some(-3123453), + Some(-3123456), + Some(0), + Some(123), + Some(123456789), + Some(-123456789), + None, + ]; + + let array256: Vec> = array128.iter().map(|v| v.map(i256::from_i128)).collect(); + + test_decimal_to_string::( + DataType::Utf8, + create_decimal_array(array128.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal_array(array128, 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::Utf8, + create_decimal256_array(array256.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal256_array(array256, 7, 3).unwrap(), + ); + } + + #[test] + fn test_cast_numeric_to_decimal128_precision_overflow() { + let array = Int64Array::from(vec![1234567]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(7, 3), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal128(7, 3), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal128 of precision 7. 
Max is 9999999", err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_numeric_to_decimal256_precision_overflow() { + let array = Int64Array::from(vec![1234567]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(7, 3), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal256(7, 3), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal256 of precision 7. Max is 9999999", err.unwrap_err().to_string()); + } + + /// helper function to test casting from duration to interval + fn cast_from_duration_to_interval>( + array: Vec, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let array = PrimitiveArray::::new(array.into(), None); + let array = Arc::new(array) as ArrayRef; + let interval = DataType::Interval(IntervalUnit::MonthDayNano); + let out = cast_with_options(&array, &interval, cast_options)?; + let out = out.as_primitive::().clone(); + Ok(out) + } + + #[test] + fn test_cast_from_duration_to_interval() { + // from duration second to interval month day nano + let array = vec![1234567]; + let casted_array = + cast_from_duration_to_interval::(array, &CastOptions::default()) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!( + casted_array.value(0), + IntervalMonthDayNano::new(0, 0, 1234567000000000) + ); + + let array = vec![i64::MAX]; + let casted_array = cast_from_duration_to_interval::( + array.clone(), + &CastOptions::default(), + ) + .unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + 
}, + ); + assert!(casted_array.is_err()); + + // from duration millisecond to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!( + casted_array.value(0), + IntervalMonthDayNano::new(0, 0, 1234567000000) + ); + + let array = vec![i64::MAX]; + let casted_array = cast_from_duration_to_interval::( + array.clone(), + &CastOptions::default(), + ) + .unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + + // from duration microsecond to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!( + casted_array.value(0), + IntervalMonthDayNano::new(0, 0, 1234567000) + ); + + let array = vec![i64::MAX]; + let casted_array = cast_from_duration_to_interval::( + array.clone(), + &CastOptions::default(), + ) + .unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + + // from duration nanosecond to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!( + casted_array.value(0), + IntervalMonthDayNano::new(0, 0, 1234567) + ); + + let array = vec![i64::MAX]; + let casted_array = 
cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap(); + assert_eq!( + casted_array.value(0), + IntervalMonthDayNano::new(0, 0, i64::MAX) + ); + } + + /// helper function to test casting from interval to duration + fn cast_from_interval_to_duration( + array: &IntervalMonthDayNanoArray, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let casted_array = cast_with_options(&array, &T::DATA_TYPE, cast_options)?; + casted_array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError(format!("Failed to downcast to {}", T::DATA_TYPE)) + }) + .cloned() + } + + #[test] + fn test_cast_from_interval_to_duration() { + let nullable = CastOptions::default(); + let fallible = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let v = IntervalMonthDayNano::new(0, 0, 1234567); + + // from interval month day nano to duration second + let array = vec![v].into(); + let casted_array: DurationSecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 0); + + let array = vec![IntervalMonthDayNano::MAX].into(); + let casted_array: DurationSecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let res = cast_from_interval_to_duration::(&array, &fallible); + assert!(res.is_err()); + + // from interval month day nano to duration millisecond + let array = vec![v].into(); + let casted_array: DurationMillisecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 1); + + let array = vec![IntervalMonthDayNano::MAX].into(); + let casted_array: DurationMillisecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let res = cast_from_interval_to_duration::(&array, &fallible); + assert!(res.is_err()); + + // from interval 
month day nano to duration microsecond + let array = vec![v].into(); + let casted_array: DurationMicrosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 1234); + + let array = vec![IntervalMonthDayNano::MAX].into(); + let casted_array = + cast_from_interval_to_duration::(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = + cast_from_interval_to_duration::(&array, &fallible); + assert!(casted_array.is_err()); + + // from interval month day nano to duration nanosecond + let array = vec![v].into(); + let casted_array: DurationNanosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 1234567); + + let array = vec![IntervalMonthDayNano::MAX].into(); + let casted_array: DurationNanosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = + cast_from_interval_to_duration::(&array, &fallible); + assert!(casted_array.is_err()); + + let array = vec![ + IntervalMonthDayNanoType::make_value(0, 1, 0), + IntervalMonthDayNanoType::make_value(-1, 0, 0), + IntervalMonthDayNanoType::make_value(1, 1, 0), + IntervalMonthDayNanoType::make_value(1, 0, 1), + IntervalMonthDayNanoType::make_value(0, 0, -1), + ] + .into(); + let casted_array = + cast_from_interval_to_duration::(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + assert!(!casted_array.is_valid(1)); + assert!(!casted_array.is_valid(2)); + assert!(!casted_array.is_valid(3)); + assert!(casted_array.is_valid(4)); + assert_eq!(casted_array.value(4), -1); + } + + /// helper function to test casting from interval year month to interval month day nano + fn cast_from_interval_year_month_to_interval_month_day_nano( + array: Vec, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let array = PrimitiveArray::::from(array); + let array = Arc::new(array) as ArrayRef; + let 
casted_array = cast_with_options( + &array, + &DataType::Interval(IntervalUnit::MonthDayNano), + cast_options, + )?; + casted_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ComputeError( + "Failed to downcast to IntervalMonthDayNanoArray".to_string(), + ) + }) + .cloned() + } + + #[test] + fn test_cast_from_interval_year_month_to_interval_month_day_nano() { + // from interval year month to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_interval_year_month_to_interval_month_day_nano( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!( + casted_array.value(0), + IntervalMonthDayNano::new(1234567, 0, 0) + ); + } + + /// helper function to test casting from interval day time to interval month day nano + fn cast_from_interval_day_time_to_interval_month_day_nano( + array: Vec, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let array = PrimitiveArray::::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Interval(IntervalUnit::MonthDayNano), + cast_options, + )?; + Ok(casted_array + .as_primitive::() + .clone()) + } + + #[test] + fn test_cast_from_interval_day_time_to_interval_month_day_nano() { + // from interval day time to interval month day nano + let array = vec![IntervalDayTime::new(123, 0)]; + let casted_array = + cast_from_interval_day_time_to_interval_month_day_nano(array, &CastOptions::default()) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), IntervalMonthDayNano::new(0, 123, 0)); + } + + #[test] + fn test_cast_below_unixtimestamp() { + let valid = StringArray::from(vec![ + "1900-01-03 23:59:59", + "1969-12-31 00:00:01", + "1989-12-31 00:00:01", + ]); + + let array = Arc::new(valid) as ArrayRef; + let casted_array = 
cast_with_options( + &array, + &DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap(); + + let ts_array = casted_array + .as_primitive::() + .values() + .iter() + .map(|ts| ts / 1_000_000) + .collect::>(); + + let array = TimestampMillisecondArray::from(ts_array).with_timezone("+00:00".to_string()); + let casted_array = cast(&array, &DataType::Date32).unwrap(); + let date_array = casted_array.as_primitive::(); + let casted_array = cast(&date_array, &DataType::Utf8).unwrap(); + let string_array = casted_array.as_string::(); + assert_eq!("1900-01-03", string_array.value(0)); + assert_eq!("1969-12-31", string_array.value(1)); + assert_eq!("1989-12-31", string_array.value(2)); + } + + #[test] + fn test_nested_list() { + let mut list = ListBuilder::new(Int32Builder::new()); + list.append_value([Some(1), Some(2), Some(3)]); + list.append_value([Some(4), None, Some(6)]); + let list = list.finish(); + + let to_field = Field::new("nested", list.data_type().clone(), false); + let to = DataType::List(Arc::new(to_field)); + let out = cast(&list, &to).unwrap(); + let opts = FormatOptions::default().with_null("null"); + let formatted = ArrayFormatter::try_new(out.as_ref(), &opts).unwrap(); + + assert_eq!(formatted.value(0).to_string(), "[[1], [2], [3]]"); + assert_eq!(formatted.value(1).to_string(), "[[4], [null], [6]]"); + } + + #[test] + fn test_nested_list_cast() { + let mut builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + builder.append_value([Some([Some(1), Some(2), None]), None]); + builder.append_value([None, Some([]), None]); + builder.append_null(); + builder.append_value([Some([Some(2), Some(3)])]); + let start = builder.finish(); + + let mut builder = LargeListBuilder::new(LargeListBuilder::new(Int8Builder::new())); + builder.append_value([Some([Some(1), Some(2), None]), None]); + builder.append_value([None, Some([]), None]); + 
builder.append_null(); + builder.append_value([Some([Some(2), Some(3)])]); + let expected = builder.finish(); + + let actual = cast(&start, expected.data_type()).unwrap(); + assert_eq!(actual.as_ref(), &expected); + } + + const CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: true, + format_options: FormatOptions::new(), + }; + + #[test] + #[allow(clippy::assertions_on_constants)] + fn test_const_options() { + assert!(CAST_OPTIONS.safe) + } + + #[test] + fn test_list_format_options() { + let options = CastOptions { + safe: false, + format_options: FormatOptions::default().with_null("null"), + }; + let array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(0), None, Some(2)]), + ]); + let a = cast_with_options(&array, &DataType::Utf8, &options).unwrap(); + let r: Vec<_> = a.as_string::().iter().flatten().collect(); + assert_eq!(r, &["[0, 1, 2]", "[0, null, 2]"]); + } + #[test] + fn test_cast_string_to_timestamp_invalid_tz() { + // content after Z should be ignored + let bad_timestamp = "2023-12-05T21:58:10.45ZZTOP"; + let array = StringArray::from(vec![Some(bad_timestamp)]); + + let data_types = [ + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ]; + + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + for dt in data_types { + assert_eq!( + cast_with_options(&array, &dt, &cast_options) + .unwrap_err() + .to_string(), + "Parser error: Invalid timezone \"ZZTOP\": only offset based timezones supported without chrono-tz feature" + ); + } + } + #[test] + fn test_cast_struct_to_struct() { + let struct_type = DataType::Struct( + vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int32, false), + ] + .into(), + ); + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + 
Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + let casted_array = cast(&struct_array, &to_type).unwrap(); + let casted_array = casted_array.as_struct(); + assert_eq!(casted_array.data_type(), &to_type); + let casted_boolean_array = casted_array + .column(0) + .as_string::() + .into_iter() + .flatten() + .collect::>(); + let casted_int_array = casted_array + .column(1) + .as_string::() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(casted_boolean_array, vec!["false", "false", "true", "true"]); + assert_eq!(casted_int_array, vec!["42", "28", "19", "31"]); + + // test for can't cast + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Date32, false), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + assert!(!can_cast_types(&struct_type, &to_type)); + let result = cast(&struct_array, &to_type); + assert_eq!( + "Cast error: Casting from Boolean to Date32 not supported", + result.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_struct_to_struct_nullability() { + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None])); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, true)), + int.clone() as ArrayRef, + ), + ]); + + // okay: nullable to nullable + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, true), + ] + .into(), + ); + 
cast(&struct_array, &to_type).expect("Cast nullable to nullable struct field should work"); + + // error: nullable to non-nullable + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + cast(&struct_array, &to_type) + .expect_err("Cast nullable to non-nullable struct field should fail"); + + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100])); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + + // okay: non-nullable to non-nullable + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + cast(&struct_array, &to_type) + .expect("Cast non-nullable to non-nullable struct field should work"); + + // err: non-nullable to non-nullable but overflowing return null during casting + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int8, false), + ] + .into(), + ); + cast(&struct_array, &to_type).expect_err( + "Cast non-nullable to non-nullable struct field returning null should fail", + ); + } +} diff --git a/arrow-cast/src/cast/string.rs b/arrow-cast/src/cast/string.rs new file mode 100644 index 000000000000..7d0e7e21c859 --- /dev/null +++ b/arrow-cast/src/cast/string.rs @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; +use arrow_buffer::NullBuffer; + +pub(crate) fn value_to_string( + array: &dyn Array, + options: &CastOptions, +) -> Result { + let mut builder = GenericStringBuilder::::new(); + let formatter = ArrayFormatter::try_new(array, &options.format_options)?; + let nulls = array.nulls(); + for i in 0..array.len() { + match nulls.map(|x| x.is_null(i)).unwrap_or_default() { + true => builder.append_null(), + false => { + formatter.value(i).write(&mut builder)?; + // tell the builder the row is finished + builder.append_value(""); + } + } + } + Ok(Arc::new(builder.finish())) +} + +/// Parse UTF-8 +pub(crate) fn parse_string( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let string_array = array.as_string::(); + parse_string_iter::(string_array.iter(), cast_options, || { + string_array.nulls().cloned() + }) +} + +/// Parse UTF-8 View +pub(crate) fn parse_string_view( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let string_view_array = array.as_string_view(); + parse_string_iter::(string_view_array.iter(), cast_options, || { + string_view_array.nulls().cloned() + }) +} + +fn parse_string_iter< + 'a, + P: Parser, + I: Iterator>, + F: FnOnce() -> Option, +>( + iter: I, + cast_options: &CastOptions, + nulls: F, +) -> Result { + let array = if cast_options.safe { + let iter = iter.map(|x| x.and_then(P::parse)); + + // Benefit: + // 
20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { PrimitiveArray::

::from_trusted_len_iter(iter) } + } else { + let v = iter + .map(|x| match x { + Some(v) => P::parse(v).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + P::DATA_TYPE + )) + }), + None => Ok(P::Native::default()), + }) + .collect::, ArrowError>>()?; + PrimitiveArray::new(v.into(), nulls()) + }; + + Ok(Arc::new(array) as ArrayRef) +} + +/// Casts generic string arrays to an ArrowTimestampType (TimeStampNanosecondArray, etc.) +pub(crate) fn cast_string_to_timestamp( + array: &dyn Array, + to_tz: &Option>, + cast_options: &CastOptions, +) -> Result { + let array = array.as_string::(); + let out: PrimitiveArray = match to_tz { + Some(tz) => { + let tz: Tz = tz.as_ref().parse()?; + cast_string_to_timestamp_impl(array.iter(), &tz, cast_options)? + } + None => cast_string_to_timestamp_impl(array.iter(), &Utc, cast_options)?, + }; + Ok(Arc::new(out.with_timezone_opt(to_tz.clone()))) +} + +/// Casts string view arrays to an ArrowTimestampType (TimeStampNanosecondArray, etc.) +pub(crate) fn cast_view_to_timestamp( + array: &dyn Array, + to_tz: &Option>, + cast_options: &CastOptions, +) -> Result { + let array = array.as_string_view(); + let out: PrimitiveArray = match to_tz { + Some(tz) => { + let tz: Tz = tz.as_ref().parse()?; + cast_string_to_timestamp_impl(array.iter(), &tz, cast_options)? + } + None => cast_string_to_timestamp_impl(array.iter(), &Utc, cast_options)?, + }; + Ok(Arc::new(out.with_timezone_opt(to_tz.clone()))) +} + +fn cast_string_to_timestamp_impl< + 'a, + I: Iterator>, + T: ArrowTimestampType, + Tz: TimeZone, +>( + iter: I, + tz: &Tz, + cast_options: &CastOptions, +) -> Result, ArrowError> { + if cast_options.safe { + let iter = iter.map(|v| { + v.and_then(|v| { + let naive = string_to_datetime(tz, v).ok()?.naive_utc(); + T::make_value(naive) + }) + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. 
+ + Ok(unsafe { PrimitiveArray::from_trusted_len_iter(iter) }) + } else { + let vec = iter + .map(|v| { + v.map(|v| { + let naive = string_to_datetime(tz, v)?.naive_utc(); + T::make_value(naive).ok_or_else(|| match T::UNIT { + TimeUnit::Nanosecond => ArrowError::CastError(format!( + "Overflow converting {naive} to Nanosecond. The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804" + )), + _ => ArrowError::CastError(format!( + "Overflow converting {naive} to {:?}", + T::UNIT + )) + }) + }) + .transpose() + }) + .collect::>, _>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { PrimitiveArray::from_trusted_len_iter(vec.iter()) }) + } +} + +pub(crate) fn cast_string_to_interval( + array: &dyn Array, + cast_options: &CastOptions, + parse_function: F, +) -> Result +where + Offset: OffsetSizeTrait, + ArrowType: ArrowPrimitiveType, + F: Fn(&str) -> Result + Copy, +{ + let string_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + cast_string_to_interval_impl::<_, ArrowType, F>( + string_array.iter(), + cast_options, + parse_function, + ) +} + +pub(crate) fn cast_string_to_year_month_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_string_to_interval::( + array, + cast_options, + parse_interval_year_month, + ) +} + +pub(crate) fn cast_string_to_day_time_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_string_to_interval::( + array, + cast_options, + parse_interval_day_time, + ) +} + +pub(crate) fn cast_string_to_month_day_nano_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_string_to_interval::( + array, + cast_options, + parse_interval_month_day_nano, + ) +} + +pub(crate) fn cast_view_to_interval( + array: &dyn Array, + cast_options: &CastOptions, + parse_function: F, +) -> Result +where + 
ArrowType: ArrowPrimitiveType, + F: Fn(&str) -> Result + Copy, +{ + let string_view_array = array.as_any().downcast_ref::().unwrap(); + cast_string_to_interval_impl::<_, ArrowType, F>( + string_view_array.iter(), + cast_options, + parse_function, + ) +} + +pub(crate) fn cast_view_to_year_month_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_view_to_interval::<_, IntervalYearMonthType>( + array, + cast_options, + parse_interval_year_month, + ) +} + +pub(crate) fn cast_view_to_day_time_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_view_to_interval::<_, IntervalDayTimeType>(array, cast_options, parse_interval_day_time) +} + +pub(crate) fn cast_view_to_month_day_nano_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_view_to_interval::<_, IntervalMonthDayNanoType>( + array, + cast_options, + parse_interval_month_day_nano, + ) +} + +fn cast_string_to_interval_impl<'a, I, ArrowType, F>( + iter: I, + cast_options: &CastOptions, + parse_function: F, +) -> Result +where + I: Iterator>, + ArrowType: ArrowPrimitiveType, + F: Fn(&str) -> Result + Copy, +{ + let interval_array = if cast_options.safe { + let iter = iter.map(|v| v.and_then(|v| parse_function(v).ok())); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { PrimitiveArray::::from_trusted_len_iter(iter) } + } else { + let vec = iter + .map(|v| v.map(parse_function).transpose()) + .collect::, ArrowError>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { PrimitiveArray::::from_trusted_len_iter(vec) } + }; + Ok(Arc::new(interval_array) as ArrayRef) +} + +/// A specified helper to cast from `GenericBinaryArray` to `GenericStringArray` when they have same +/// offset size so re-encoding offset is unnecessary. 
+pub(crate) fn cast_binary_to_string( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array + .as_any() + .downcast_ref::>>() + .unwrap(); + + match GenericStringArray::::try_from_binary(array.clone()) { + Ok(a) => Ok(Arc::new(a)), + Err(e) => match cast_options.safe { + true => { + // Fallback to slow method to convert invalid sequences to nulls + let mut builder = + GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); + + let iter = array + .iter() + .map(|v| v.and_then(|v| std::str::from_utf8(v).ok())); + + builder.extend(iter); + Ok(Arc::new(builder.finish())) + } + false => Err(e), + }, + } +} + +/// Casts Utf8 to Boolean +pub(crate) fn cast_utf8_to_boolean( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); + + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => Ok(Some(true)), + "f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off" | "0" => { + Ok(Some(false)) + } + invalid_value => match cast_options.safe { + true => Ok(None), + false => Err(ArrowError::CastError(format!( + "Cannot cast value '{invalid_value}' to value of Boolean type", + ))), + }, + }, + None => Ok(None), + }) + .collect::>()?; + + Ok(Arc::new(output_array)) +} diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs new file mode 100644 index 000000000000..df96816ea23a --- /dev/null +++ b/arrow-cast/src/display.rs @@ -0,0 +1,1222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Functions for printing array values as human-readable strings. +//! +//! This is often used for debugging or logging purposes. +//! +//! See the [`pretty`] crate for additional functions for +//! record batch pretty printing. +//! +//! [`pretty`]: crate::pretty +use std::fmt::{Display, Formatter, Write}; +use std::ops::Range; + +use arrow_array::cast::*; +use arrow_array::temporal_conversions::*; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::*; +use chrono::{NaiveDate, NaiveDateTime, SecondsFormat, TimeZone, Utc}; +use lexical_core::FormattedSize; + +type TimeFormat<'a> = Option<&'a str>; + +/// Format for displaying durations +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum DurationFormat { + /// ISO 8601 - `P198DT72932.972880S` + ISO8601, + /// A human readable representation - `198 days 16 hours 34 mins 15.407810000 secs` + Pretty, +} + +/// Options for formatting arrays +/// +/// By default nulls are formatted as `""` and temporal types formatted +/// according to RFC3339 +/// +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FormatOptions<'a> { + /// If set to `true` any formatting errors will be written to the output + /// instead of being converted into a [`std::fmt::Error`] + safe: bool, + /// Format string for nulls + null: &'a str, + /// 
Date format for date arrays + date_format: TimeFormat<'a>, + /// Format for DateTime arrays + datetime_format: TimeFormat<'a>, + /// Timestamp format for timestamp arrays + timestamp_format: TimeFormat<'a>, + /// Timestamp format for timestamp with timezone arrays + timestamp_tz_format: TimeFormat<'a>, + /// Time format for time arrays + time_format: TimeFormat<'a>, + /// Duration format + duration_format: DurationFormat, +} + +impl<'a> Default for FormatOptions<'a> { + fn default() -> Self { + Self::new() + } +} + +impl<'a> FormatOptions<'a> { + /// Creates a new set of format options + pub const fn new() -> Self { + Self { + safe: true, + null: "", + date_format: None, + datetime_format: None, + timestamp_format: None, + timestamp_tz_format: None, + time_format: None, + duration_format: DurationFormat::ISO8601, + } + } + + /// If set to `true` any formatting errors will be written to the output + /// instead of being converted into a [`std::fmt::Error`] + pub const fn with_display_error(mut self, safe: bool) -> Self { + self.safe = safe; + self + } + + /// Overrides the string used to represent a null + /// + /// Defaults to `""` + pub const fn with_null(self, null: &'a str) -> Self { + Self { null, ..self } + } + + /// Overrides the format used for [`DataType::Date32`] columns + pub const fn with_date_format(self, date_format: Option<&'a str>) -> Self { + Self { + date_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Date64`] columns + pub const fn with_datetime_format(self, datetime_format: Option<&'a str>) -> Self { + Self { + datetime_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Timestamp`] columns without a timezone + pub const fn with_timestamp_format(self, timestamp_format: Option<&'a str>) -> Self { + Self { + timestamp_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Timestamp`] columns with a timezone + pub const fn with_timestamp_tz_format(self, timestamp_tz_format: 
Option<&'a str>) -> Self { + Self { + timestamp_tz_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Time32`] and [`DataType::Time64`] columns + pub const fn with_time_format(self, time_format: Option<&'a str>) -> Self { + Self { + time_format, + ..self + } + } + + /// Overrides the format used for duration columns + /// + /// Defaults to [`DurationFormat::ISO8601`] + pub const fn with_duration_format(self, duration_format: DurationFormat) -> Self { + Self { + duration_format, + ..self + } + } +} + +/// Implements [`Display`] for a specific array value +pub struct ValueFormatter<'a> { + idx: usize, + formatter: &'a ArrayFormatter<'a>, +} + +impl<'a> ValueFormatter<'a> { + /// Writes this value to the provided [`Write`] + /// + /// Note: this ignores [`FormatOptions::with_display_error`] and + /// will return an error on formatting issue + pub fn write(&self, s: &mut dyn Write) -> Result<(), ArrowError> { + match self.formatter.format.write(self.idx, s) { + Ok(_) => Ok(()), + Err(FormatError::Arrow(e)) => Err(e), + Err(FormatError::Format(_)) => Err(ArrowError::CastError("Format error".to_string())), + } + } + + /// Fallibly converts this to a string + pub fn try_to_string(&self) -> Result { + let mut s = String::new(); + self.write(&mut s)?; + Ok(s) + } +} + +impl<'a> Display for ValueFormatter<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.formatter.format.write(self.idx, f) { + Ok(()) => Ok(()), + Err(FormatError::Arrow(e)) if self.formatter.safe => { + write!(f, "ERROR: {e}") + } + Err(_) => Err(std::fmt::Error), + } + } +} + +/// A string formatter for an [`Array`] +/// +/// This can be used with [`std::write`] to write type-erased `dyn Array` +/// +/// ``` +/// # use std::fmt::{Display, Formatter, Write}; +/// # use arrow_array::{Array, ArrayRef, Int32Array}; +/// # use arrow_cast::display::{ArrayFormatter, FormatOptions}; +/// # use arrow_schema::ArrowError; +/// struct MyContainer { +/// values: 
ArrayRef, +/// } +/// +/// impl Display for MyContainer { +/// fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +/// let options = FormatOptions::default(); +/// let formatter = ArrayFormatter::try_new(self.values.as_ref(), &options) +/// .map_err(|_| std::fmt::Error)?; +/// +/// let mut iter = 0..self.values.len(); +/// if let Some(idx) = iter.next() { +/// write!(f, "{}", formatter.value(idx))?; +/// } +/// for idx in iter { +/// write!(f, ", {}", formatter.value(idx))?; +/// } +/// Ok(()) +/// } +/// } +/// ``` +/// +/// [`ValueFormatter::write`] can also be used to get a semantic error, instead of the +/// opaque [`std::fmt::Error`] +/// +/// ``` +/// # use std::fmt::Write; +/// # use arrow_array::Array; +/// # use arrow_cast::display::{ArrayFormatter, FormatOptions}; +/// # use arrow_schema::ArrowError; +/// fn format_array( +/// f: &mut dyn Write, +/// array: &dyn Array, +/// options: &FormatOptions, +/// ) -> Result<(), ArrowError> { +/// let formatter = ArrayFormatter::try_new(array, options)?; +/// for i in 0..array.len() { +/// formatter.value(i).write(f)? +/// } +/// Ok(()) +/// } +/// ``` +/// +pub struct ArrayFormatter<'a> { + format: Box, + safe: bool, +} + +impl<'a> ArrayFormatter<'a> { + /// Returns an [`ArrayFormatter`] that can be used to format `array` + /// + /// This returns an error if an array of the given data type cannot be formatted + pub fn try_new(array: &'a dyn Array, options: &FormatOptions<'a>) -> Result { + Ok(Self { + format: make_formatter(array, options)?, + safe: options.safe, + }) + } + + /// Returns a [`ValueFormatter`] that implements [`Display`] for + /// the value of the array at `idx` + pub fn value(&self, idx: usize) -> ValueFormatter<'_> { + ValueFormatter { + formatter: self, + idx, + } + } +} + +fn make_formatter<'a>( + array: &'a dyn Array, + options: &FormatOptions<'a>, +) -> Result, ArrowError> { + downcast_primitive_array! 
{ + array => array_format(array, options), + DataType::Null => array_format(as_null_array(array), options), + DataType::Boolean => array_format(as_boolean_array(array), options), + DataType::Utf8 => array_format(array.as_string::(), options), + DataType::LargeUtf8 => array_format(array.as_string::(), options), + DataType::Utf8View => array_format(array.as_string_view(), options), + DataType::Binary => array_format(array.as_binary::(), options), + DataType::BinaryView => array_format(array.as_binary_view(), options), + DataType::LargeBinary => array_format(array.as_binary::(), options), + DataType::FixedSizeBinary(_) => { + let a = array.as_any().downcast_ref::().unwrap(); + array_format(a, options) + } + DataType::Dictionary(_, _) => downcast_dictionary_array! { + array => array_format(array, options), + _ => unreachable!() + } + DataType::List(_) => array_format(as_generic_list_array::(array), options), + DataType::LargeList(_) => array_format(as_generic_list_array::(array), options), + DataType::FixedSizeList(_, _) => { + let a = array.as_any().downcast_ref::().unwrap(); + array_format(a, options) + } + DataType::Struct(_) => array_format(as_struct_array(array), options), + DataType::Map(_, _) => array_format(as_map_array(array), options), + DataType::Union(_, _) => array_format(as_union_array(array), options), + DataType::RunEndEncoded(_, _) => downcast_run_array! 
{ + array => array_format(array, options), + _ => unreachable!() + }, + d => Err(ArrowError::NotYetImplemented(format!("formatting {d} is not yet supported"))), + } +} + +/// Either an [`ArrowError`] or [`std::fmt::Error`] +enum FormatError { + Format(std::fmt::Error), + Arrow(ArrowError), +} + +type FormatResult = Result<(), FormatError>; + +impl From for FormatError { + fn from(value: std::fmt::Error) -> Self { + Self::Format(value) + } +} + +impl From for FormatError { + fn from(value: ArrowError) -> Self { + Self::Arrow(value) + } +} + +/// [`Display`] but accepting an index +trait DisplayIndex { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult; +} + +/// [`DisplayIndex`] with additional state +trait DisplayIndexState<'a> { + type State; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result; + + fn write(&self, state: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult; +} + +impl<'a, T: DisplayIndex> DisplayIndexState<'a> for T { + type State = (); + + fn prepare(&self, _options: &FormatOptions<'a>) -> Result { + Ok(()) + } + + fn write(&self, _: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + DisplayIndex::write(self, idx, f) + } +} + +struct ArrayFormat<'a, F: DisplayIndexState<'a>> { + state: F::State, + array: F, + null: &'a str, +} + +fn array_format<'a, F>( + array: F, + options: &FormatOptions<'a>, +) -> Result, ArrowError> +where + F: DisplayIndexState<'a> + Array + 'a, +{ + let state = array.prepare(options)?; + Ok(Box::new(ArrayFormat { + state, + array, + null: options.null, + })) +} + +impl<'a, F: DisplayIndexState<'a> + Array> DisplayIndex for ArrayFormat<'a, F> { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + if self.array.is_null(idx) { + if !self.null.is_empty() { + f.write_str(self.null)? 
+ } + return Ok(()); + } + DisplayIndexState::write(&self.array, &self.state, idx, f) + } +} + +impl<'a> DisplayIndex for &'a BooleanArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", self.value(idx))?; + Ok(()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a NullArray { + type State = &'a str; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok(options.null) + } + + fn write(&self, state: &Self::State, _idx: usize, f: &mut dyn Write) -> FormatResult { + f.write_str(state)?; + Ok(()) + } +} + +macro_rules! primitive_display { + ($($t:ty),+) => { + $(impl<'a> DisplayIndex for &'a PrimitiveArray<$t> + { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let mut buffer = [0u8; <$t as ArrowPrimitiveType>::Native::FORMATTED_SIZE]; + let b = lexical_core::write(value, &mut buffer); + // Lexical core produces valid UTF-8 + let s = unsafe { std::str::from_utf8_unchecked(b) }; + f.write_str(s)?; + Ok(()) + } + })+ + }; +} + +macro_rules! primitive_display_float { + ($($t:ty),+) => { + $(impl<'a> DisplayIndex for &'a PrimitiveArray<$t> + { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let mut buffer = ryu::Buffer::new(); + f.write_str(buffer.format(value))?; + Ok(()) + } + })+ + }; +} + +primitive_display!(Int8Type, Int16Type, Int32Type, Int64Type); +primitive_display!(UInt8Type, UInt16Type, UInt32Type, UInt64Type); +primitive_display_float!(Float32Type, Float64Type); + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", self.value(idx))?; + Ok(()) + } +} + +macro_rules! 
decimal_display { + ($($t:ty),+) => { + $(impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = (u8, i8); + + fn prepare(&self, _options: &FormatOptions<'a>) -> Result { + Ok((self.precision(), self.scale())) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", <$t>::format_decimal(self.values()[idx], s.0, s.1))?; + Ok(()) + } + })+ + }; +} + +decimal_display!(Decimal128Type, Decimal256Type); + +fn write_timestamp( + f: &mut dyn Write, + naive: NaiveDateTime, + timezone: Option, + format: Option<&str>, +) -> FormatResult { + match timezone { + Some(tz) => { + let date = Utc.from_utc_datetime(&naive).with_timezone(&tz); + match format { + Some(s) => write!(f, "{}", date.format(s))?, + None => write!(f, "{}", date.to_rfc3339_opts(SecondsFormat::AutoSi, true))?, + } + } + None => match format { + Some(s) => write!(f, "{}", naive.format(s))?, + None => write!(f, "{naive:?}")?, + }, + } + Ok(()) +} + +macro_rules! timestamp_display { + ($($t:ty),+) => { + $(impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = (Option, TimeFormat<'a>); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + match self.data_type() { + DataType::Timestamp(_, Some(tz)) => Ok((Some(tz.parse()?), options.timestamp_tz_format)), + DataType::Timestamp(_, None) => Ok((None, options.timestamp_format)), + _ => unreachable!(), + } + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let naive = as_datetime::<$t>(value).ok_or_else(|| { + ArrowError::CastError(format!( + "Failed to convert {} to datetime for {}", + value, + self.data_type() + )) + })?; + + write_timestamp(f, naive, s.0, s.1.clone()) + } + })+ + }; +} + +timestamp_display!( + TimestampSecondType, + TimestampMillisecondType, + TimestampMicrosecondType, + TimestampNanosecondType +); + +macro_rules! 
temporal_display { + ($convert:ident, $format:ident, $t:ty) => { + impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = TimeFormat<'a>; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok(options.$format) + } + + fn write(&self, fmt: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let naive = $convert(value as _).ok_or_else(|| { + ArrowError::CastError(format!( + "Failed to convert {} to temporal for {}", + value, + self.data_type() + )) + })?; + + match fmt { + Some(s) => write!(f, "{}", naive.format(s))?, + None => write!(f, "{naive:?}")?, + } + Ok(()) + } + } + }; +} + +#[inline] +fn date32_to_date(value: i32) -> Option { + Some(date32_to_datetime(value)?.date()) +} + +temporal_display!(date32_to_date, date_format, Date32Type); +temporal_display!(date64_to_datetime, datetime_format, Date64Type); +temporal_display!(time32s_to_time, time_format, Time32SecondType); +temporal_display!(time32ms_to_time, time_format, Time32MillisecondType); +temporal_display!(time64us_to_time, time_format, Time64MicrosecondType); +temporal_display!(time64ns_to_time, time_format, Time64NanosecondType); + +macro_rules! duration_display { + ($convert:ident, $t:ty, $scale:tt) => { + impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = DurationFormat; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok(options.duration_format) + } + + fn write(&self, fmt: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let v = self.value(idx); + match fmt { + DurationFormat::ISO8601 => write!(f, "{}", $convert(v))?, + DurationFormat::Pretty => duration_fmt!(f, v, $scale)?, + } + Ok(()) + } + } + }; +} + +macro_rules! 
duration_fmt { + ($f:ident, $v:expr, 0) => {{ + let secs = $v; + let mins = secs / 60; + let hours = mins / 60; + let days = hours / 24; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + let hours = hours - (days * 24); + write!($f, "{days} days {hours} hours {mins} mins {secs} secs") + }}; + ($f:ident, $v:expr, $scale:tt) => {{ + let subsec = $v; + let secs = subsec / 10_i64.pow($scale); + let mins = secs / 60; + let hours = mins / 60; + let days = hours / 24; + + let subsec = subsec - (secs * 10_i64.pow($scale)); + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + let hours = hours - (days * 24); + match subsec.is_negative() { + true => { + write!( + $f, + concat!("{} days {} hours {} mins -{}.{:0", $scale, "} secs"), + days, + hours, + mins, + secs.abs(), + subsec.abs() + ) + } + false => { + write!( + $f, + concat!("{} days {} hours {} mins {}.{:0", $scale, "} secs"), + days, hours, mins, secs, subsec + ) + } + } + }}; +} + +duration_display!(duration_s_to_duration, DurationSecondType, 0); +duration_display!(duration_ms_to_duration, DurationMillisecondType, 3); +duration_display!(duration_us_to_duration, DurationMicrosecondType, 6); +duration_display!(duration_ns_to_duration, DurationNanosecondType, 9); + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let interval = self.value(idx) as f64; + let years = (interval / 12_f64).floor(); + let month = interval - (years * 12_f64); + + write!(f, "{years} years {month} mons",)?; + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let mut prefix = ""; + + if value.days != 0 { + write!(f, "{prefix}{} days", value.days)?; + prefix = " "; + } + + if value.milliseconds != 0 { + let millis_fmt = MillisecondsFormatter { + milliseconds: value.milliseconds, + prefix, + }; + + 
f.write_fmt(format_args!("{millis_fmt}"))?; + } + + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let mut prefix = ""; + + if value.months != 0 { + write!(f, "{prefix}{} mons", value.months)?; + prefix = " "; + } + + if value.days != 0 { + write!(f, "{prefix}{} days", value.days)?; + prefix = " "; + } + + if value.nanoseconds != 0 { + let nano_fmt = NanosecondsFormatter { + nanoseconds: value.nanoseconds, + prefix, + }; + f.write_fmt(format_args!("{nano_fmt}"))?; + } + + Ok(()) + } +} + +struct NanosecondsFormatter<'a> { + nanoseconds: i64, + prefix: &'a str, +} + +impl<'a> Display for NanosecondsFormatter<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut prefix = self.prefix; + + let secs = self.nanoseconds / 1_000_000_000; + let mins = secs / 60; + let hours = mins / 60; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + + let nanoseconds = self.nanoseconds % 1_000_000_000; + + if hours != 0 { + write!(f, "{prefix}{} hours", hours)?; + prefix = " "; + } + + if mins != 0 { + write!(f, "{prefix}{} mins", mins)?; + prefix = " "; + } + + if secs != 0 || nanoseconds != 0 { + let secs_sign = if secs < 0 || nanoseconds < 0 { "-" } else { "" }; + write!( + f, + "{prefix}{}{}.{:09} secs", + secs_sign, + secs.abs(), + nanoseconds.abs() + )?; + } + + Ok(()) + } +} + +struct MillisecondsFormatter<'a> { + milliseconds: i32, + prefix: &'a str, +} + +impl<'a> Display for MillisecondsFormatter<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut prefix = self.prefix; + + let secs = self.milliseconds / 1_000; + let mins = secs / 60; + let hours = mins / 60; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + + let milliseconds = self.milliseconds % 1_000; + + if hours != 0 { + write!(f, "{prefix}{} hours", hours,)?; + prefix = " "; + } + + if mins != 0 { + write!(f, 
"{prefix}{} mins", mins,)?; + prefix = " "; + } + + if secs != 0 || milliseconds != 0 { + let secs_sign = if secs < 0 || milliseconds < 0 { + "-" + } else { + "" + }; + + write!( + f, + "{prefix}{}{}.{:03} secs", + secs_sign, + secs.abs(), + milliseconds.abs() + )?; + } + + Ok(()) + } +} + +impl<'a, O: OffsetSizeTrait> DisplayIndex for &'a GenericStringArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", self.value(idx))?; + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a StringViewArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", self.value(idx))?; + Ok(()) + } +} + +impl<'a, O: OffsetSizeTrait> DisplayIndex for &'a GenericBinaryArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let v = self.value(idx); + for byte in v { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a BinaryViewArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let v = self.value(idx); + for byte in v { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a FixedSizeBinaryArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let v = self.value(idx); + for byte in v { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl<'a, K: ArrowDictionaryKeyType> DisplayIndexState<'a> for &'a DictionaryArray { + type State = Box; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + make_formatter(self.values().as_ref(), options) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let value_idx = self.keys().values()[idx].as_usize(); + s.as_ref().write(value_idx, f) + } +} + +impl<'a, K: RunEndIndexType> DisplayIndexState<'a> for &'a RunArray { + type State = Box; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + make_formatter(self.values().as_ref(), options) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) 
-> FormatResult { + let value_idx = self.get_physical_index(idx); + s.as_ref().write(value_idx, f) + } +} + +fn write_list( + f: &mut dyn Write, + mut range: Range, + values: &dyn DisplayIndex, +) -> FormatResult { + f.write_char('[')?; + if let Some(idx) = range.next() { + values.write(idx, f)?; + } + for idx in range { + write!(f, ", ")?; + values.write(idx, f)?; + } + f.write_char(']')?; + Ok(()) +} + +impl<'a, O: OffsetSizeTrait> DisplayIndexState<'a> for &'a GenericListArray { + type State = Box; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + make_formatter(self.values().as_ref(), options) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let offsets = self.value_offsets(); + let end = offsets[idx + 1].as_usize(); + let start = offsets[idx].as_usize(); + write_list(f, start..end, s.as_ref()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a FixedSizeListArray { + type State = (usize, Box); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let values = make_formatter(self.values().as_ref(), options)?; + let length = self.value_length(); + Ok((length as usize, values)) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let start = idx * s.0; + let end = start + s.0; + write_list(f, start..end, s.1.as_ref()) + } +} + +/// Pairs a boxed [`DisplayIndex`] with its field name +type FieldDisplay<'a> = (&'a str, Box); + +impl<'a> DisplayIndexState<'a> for &'a StructArray { + type State = Vec>; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let fields = match (*self).data_type() { + DataType::Struct(f) => f, + _ => unreachable!(), + }; + + self.columns() + .iter() + .zip(fields) + .map(|(a, f)| { + let format = make_formatter(a.as_ref(), options)?; + Ok((f.name().as_str(), format)) + }) + .collect() + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let mut iter = s.iter(); + f.write_char('{')?; + 
if let Some((name, display)) = iter.next() { + write!(f, "{name}: ")?; + display.as_ref().write(idx, f)?; + } + for (name, display) in iter { + write!(f, ", {name}: ")?; + display.as_ref().write(idx, f)?; + } + f.write_char('}')?; + Ok(()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a MapArray { + type State = (Box, Box); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let keys = make_formatter(self.keys().as_ref(), options)?; + let values = make_formatter(self.values().as_ref(), options)?; + Ok((keys, values)) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let offsets = self.value_offsets(); + let end = offsets[idx + 1].as_usize(); + let start = offsets[idx].as_usize(); + let mut iter = start..end; + + f.write_char('{')?; + if let Some(idx) = iter.next() { + s.0.write(idx, f)?; + write!(f, ": ")?; + s.1.write(idx, f)?; + } + + for idx in iter { + write!(f, ", ")?; + s.0.write(idx, f)?; + write!(f, ": ")?; + s.1.write(idx, f)?; + } + + f.write_char('}')?; + Ok(()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a UnionArray { + type State = ( + Vec)>>, + UnionMode, + ); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let (fields, mode) = match (*self).data_type() { + DataType::Union(fields, mode) => (fields, mode), + _ => unreachable!(), + }; + + let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default() as usize; + let mut out: Vec> = (0..max_id + 1).map(|_| None).collect(); + for (i, field) in fields.iter() { + let formatter = make_formatter(self.child(i).as_ref(), options)?; + out[i as usize] = Some((field.name().as_str(), formatter)) + } + Ok((out, *mode)) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let id = self.type_id(idx); + let idx = match s.1 { + UnionMode::Dense => self.value_offset(idx), + UnionMode::Sparse => idx, + }; + let (name, field) = s.0[id as usize].as_ref().unwrap(); + + write!(f, "{{{name}=")?; + 
field.write(idx, f)?; + f.write_char('}')?; + Ok(()) + } +} + +/// Get the value at the given row in an array as a String. +/// +/// Note this function is quite inefficient and is unlikely to be +/// suitable for converting large arrays or record batches. +/// +/// Please see [`ArrayFormatter`] for a more performant interface +pub fn array_value_to_string(column: &dyn Array, row: usize) -> Result { + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(column, &options)?; + Ok(formatter.value(row).to_string()) +} + +/// Converts numeric type to a `String` +pub fn lexical_to_string(n: N) -> String { + let mut buf = Vec::::with_capacity(N::FORMATTED_SIZE_DECIMAL); + unsafe { + // JUSTIFICATION + // Benefit + // Allows using the faster serializer lexical core and convert to string + // Soundness + // Length of buf is set as written length afterwards. lexical_core + // creates a valid string, so doesn't need to be checked. + let slice = std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.capacity()); + let len = lexical_core::write(n, slice).len(); + buf.set_len(len); + String::from_utf8_unchecked(buf) + } +} + +#[cfg(test)] +mod tests { + use arrow_array::builder::StringRunBuilder; + + use super::*; + + /// Test to verify options can be constant. 
See #4580 + const TEST_CONST_OPTIONS: FormatOptions<'static> = FormatOptions::new() + .with_date_format(Some("foo")) + .with_timestamp_format(Some("404")); + + #[test] + fn test_const_options() { + assert_eq!(TEST_CONST_OPTIONS.date_format, Some("foo")); + } + + #[test] + fn test_map_array_to_string() { + let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; + let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]); + + // Construct a buffer for value offsets, for the nested array: + // [[a, b, c], [d, e, f], [g, h]] + let entry_offsets = [0, 3, 6, 8]; + + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); + assert_eq!( + "{d: 30, e: 40, f: 50}", + array_value_to_string(&map_array, 1).unwrap() + ); + } + + fn format_array(array: &dyn Array, fmt: &FormatOptions) -> Vec { + let fmt = ArrayFormatter::try_new(array, fmt).unwrap(); + (0..array.len()).map(|x| fmt.value(x).to_string()).collect() + } + + #[test] + fn test_array_value_to_string_duration() { + let iso_fmt = FormatOptions::new(); + let pretty_fmt = FormatOptions::new().with_duration_format(DurationFormat::Pretty); + + let array = DurationNanosecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000_000 + 123456789, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000_000 - 123456789, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.000000001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.000000001 secs"); + assert_eq!(iso[1], "-PT0.000000001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.000000001 secs"); + assert_eq!(iso[2], "PT0.000001S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 0.000001000 secs"); + assert_eq!(iso[3], "-PT0.000001S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -0.000001000 secs"); + assert_eq!(iso[4], "PT3938554.123456789S"); + 
assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123456789 secs"); + assert_eq!(iso[5], "-PT3938554.123456789S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123456789 secs"); + + let array = DurationMicrosecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000 + 123456, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000 - 123456, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.000001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.000001 secs"); + assert_eq!(iso[1], "-PT0.000001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.000001 secs"); + assert_eq!(iso[2], "PT0.001S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 0.001000 secs"); + assert_eq!(iso[3], "-PT0.001S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -0.001000 secs"); + assert_eq!(iso[4], "PT3938554.123456S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123456 secs"); + assert_eq!(iso[5], "-PT3938554.123456S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123456 secs"); + + let array = DurationMillisecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000 + 123, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000 - 123, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.001 secs"); + assert_eq!(iso[1], "-PT0.001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.001 secs"); + assert_eq!(iso[2], "PT1S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 1.000 secs"); + assert_eq!(iso[3], "-PT1S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -1.000 secs"); + assert_eq!(iso[4], "PT3938554.123S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123 secs"); + assert_eq!(iso[5], "-PT3938554.123S"); + 
assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123 secs"); + + let array = DurationSecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + 45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34, + -45 * 60 * 60 * 24 - 14 * 60 * 60 - 2 * 60 - 34, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT1S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 1 secs"); + assert_eq!(iso[1], "-PT1S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -1 secs"); + assert_eq!(iso[2], "PT1000S"); + assert_eq!(pretty[2], "0 days 0 hours 16 mins 40 secs"); + assert_eq!(iso[3], "-PT1000S"); + assert_eq!(pretty[3], "0 days 0 hours -16 mins -40 secs"); + assert_eq!(iso[4], "PT3938554S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34 secs"); + assert_eq!(iso[5], "-PT3938554S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34 secs"); + } + + #[test] + fn test_null() { + let array = NullArray::new(2); + let options = FormatOptions::new().with_null("NULL"); + let formatted = format_array(&array, &options); + assert_eq!(formatted, &["NULL".to_string(), "NULL".to_string()]) + } + + #[test] + fn test_string_run_arry_to_string() { + let mut builder = StringRunBuilder::::new(); + + builder.append_value("input_value"); + builder.append_value("input_value"); + builder.append_value("input_value"); + builder.append_value("input_value1"); + + let map_array = builder.finish(); + assert_eq!("input_value", array_value_to_string(&map_array, 1).unwrap()); + assert_eq!( + "input_value1", + array_value_to_string(&map_array, 3).unwrap() + ); + } +} diff --git a/arrow-cast/src/lib.rs b/arrow-cast/src/lib.rs new file mode 100644 index 000000000000..6eac1be37c88 --- /dev/null +++ b/arrow-cast/src/lib.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Functions for converting from one data type to another in [Apache Arrow](https://docs.rs/arrow) + +#![warn(missing_docs)] +pub mod cast; +pub use cast::*; +pub mod display; +pub mod parse; +#[cfg(feature = "prettyprint")] +pub mod pretty; + +pub mod base64; diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs new file mode 100644 index 000000000000..e332e5bbaaec --- /dev/null +++ b/arrow-cast/src/parse.rs @@ -0,0 +1,2772 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`Parser`] implementations for converting strings to Arrow types +//! +//! Used by the CSV and JSON readers to convert strings to Arrow types +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::ArrowNativeTypeOp; +use arrow_buffer::ArrowNativeType; +use arrow_schema::ArrowError; +use chrono::prelude::*; +use half::f16; +use std::str::FromStr; + +/// Parse nanoseconds from the first `N` values in digits, subtracting the offset `O` +#[inline] +fn parse_nanos(digits: &[u8]) -> u32 { + digits[..N] + .iter() + .fold(0_u32, |acc, v| acc * 10 + v.wrapping_sub(O) as u32) + * 10_u32.pow((9 - N) as _) +} + +/// Helper for parsing RFC3339 timestamps +struct TimestampParser { + /// The timestamp bytes to parse minus `b'0'` + /// + /// This makes interpretation as an integer inexpensive + digits: [u8; 32], + /// A mask containing a `1` bit where the corresponding byte is a valid ASCII digit + mask: u32, +} + +impl TimestampParser { + fn new(bytes: &[u8]) -> Self { + let mut digits = [0; 32]; + let mut mask = 0; + + // Treating all bytes the same way, helps LLVM vectorise this correctly + for (idx, (o, i)) in digits.iter_mut().zip(bytes).enumerate() { + *o = i.wrapping_sub(b'0'); + mask |= ((*o < 10) as u32) << idx + } + + Self { digits, mask } + } + + /// Returns true if the byte at `idx` in the original string equals `b` + fn test(&self, idx: usize, b: u8) -> bool { + self.digits[idx] == b.wrapping_sub(b'0') + } + + /// Parses a date of the form `1997-01-31` + fn date(&self) -> Option { + if self.mask & 0b1111111111 != 0b1101101111 || !self.test(4, b'-') || !self.test(7, b'-') { + return None; + } + + let year = self.digits[0] as u16 * 1000 + + self.digits[1] as u16 * 100 + + self.digits[2] as u16 * 10 + + self.digits[3] as u16; + + let month = self.digits[5] * 10 + self.digits[6]; + let day = self.digits[8] * 10 + self.digits[9]; + + NaiveDate::from_ymd_opt(year as _, month as _, day as _) + } + + /// Parses a time of any of forms + /// - `09:26:56` 
+ /// - `09:26:56.123` + /// - `09:26:56.123456` + /// - `09:26:56.123456789` + /// - `092656` + /// + /// Returning the end byte offset + fn time(&self) -> Option<(NaiveTime, usize)> { + // Make a NaiveTime handling leap seconds + let time = |hour, min, sec, nano| match sec { + 60 => { + let nano = 1_000_000_000 + nano; + NaiveTime::from_hms_nano_opt(hour as _, min as _, 59, nano) + } + _ => NaiveTime::from_hms_nano_opt(hour as _, min as _, sec as _, nano), + }; + + match (self.mask >> 11) & 0b11111111 { + // 09:26:56 + 0b11011011 if self.test(13, b':') && self.test(16, b':') => { + let hour = self.digits[11] * 10 + self.digits[12]; + let minute = self.digits[14] * 10 + self.digits[15]; + let second = self.digits[17] * 10 + self.digits[18]; + + match self.test(19, b'.') { + true => { + let digits = (self.mask >> 20).trailing_ones(); + let nanos = match digits { + 0 => return None, + 1 => parse_nanos::<1, 0>(&self.digits[20..21]), + 2 => parse_nanos::<2, 0>(&self.digits[20..22]), + 3 => parse_nanos::<3, 0>(&self.digits[20..23]), + 4 => parse_nanos::<4, 0>(&self.digits[20..24]), + 5 => parse_nanos::<5, 0>(&self.digits[20..25]), + 6 => parse_nanos::<6, 0>(&self.digits[20..26]), + 7 => parse_nanos::<7, 0>(&self.digits[20..27]), + 8 => parse_nanos::<8, 0>(&self.digits[20..28]), + _ => parse_nanos::<9, 0>(&self.digits[20..29]), + }; + Some((time(hour, minute, second, nanos)?, 20 + digits as usize)) + } + false => Some((time(hour, minute, second, 0)?, 19)), + } + } + // 092656 + 0b111111 => { + let hour = self.digits[11] * 10 + self.digits[12]; + let minute = self.digits[13] * 10 + self.digits[14]; + let second = self.digits[15] * 10 + self.digits[16]; + let time = time(hour, minute, second, 0)?; + Some((time, 17)) + } + _ => None, + } + } +} + +/// Accepts a string and parses it relative to the provided `timezone` +/// +/// In addition to RFC3339 / ISO8601 standard timestamps, it also +/// accepts strings that use a space ` ` to separate the date and time +/// as well 
as strings that have no explicit timezone offset. +/// +/// Examples of accepted inputs: +/// * `1997-01-31T09:26:56.123Z` # RFC3339 +/// * `1997-01-31T09:26:56.123-05:00` # RFC3339 +/// * `1997-01-31 09:26:56.123-05:00` # close to RFC3339 but with a space rather than T +/// * `2023-01-01 04:05:06.789 -08` # close to RFC3339 but with spaces and an hour-only offset +/// * `1997-01-31T09:26:56.123` # close to RFC3339 but no timezone offset specified +/// * `1997-01-31 09:26:56.123` # close to RFC3339 but uses a space and no timezone offset +/// * `1997-01-31 09:26:56` # close to RFC3339, no fractional seconds +/// * `1997-01-31 092656` # close to RFC3339, no fractional seconds +/// * `1997-01-31 092656+04:00` # close to RFC3339, no fractional seconds or time separator +/// * `1997-01-31` # close to RFC3339, only date no time +/// +/// [IANA timezones] are only supported if the `arrow-array/chrono-tz` feature is enabled +/// +/// * `2023-01-01 040506 America/Los_Angeles` +/// +/// If a timestamp is ambiguous, for example as a result of daylight-savings time, an error +/// will be returned +/// +/// Some formats supported by PostgreSQL +/// are not supported, like +/// +/// * "2023-01-01 04:05:06.789 +07:30:00", +/// * "2023-01-01 040506 +07:30:00", +/// * "2023-01-01 04:05:06.789 PST", +/// +/// [IANA timezones]: https://www.iana.org/time-zones +pub fn string_to_datetime(timezone: &T, s: &str) -> Result, ArrowError> { + let err = + |ctx: &str| ArrowError::ParseError(format!("Error parsing timestamp from '{s}': {ctx}")); + + let bytes = s.as_bytes(); + if bytes.len() < 10 { + return Err(err("timestamp must contain at least 10 characters")); + } + + let parser = TimestampParser::new(bytes); + let date = parser.date().ok_or_else(|| err("error parsing date"))?; + if bytes.len() == 10 { + let datetime = date.and_time(NaiveTime::from_hms_opt(0, 0, 0).unwrap()); + return timezone + .from_local_datetime(&datetime) + .single() + .ok_or_else(|| err("error computing timezone 
offset")); + } + + if !parser.test(10, b'T') && !parser.test(10, b't') && !parser.test(10, b' ') { + return Err(err("invalid timestamp separator")); + } + + let (time, mut tz_offset) = parser.time().ok_or_else(|| err("error parsing time"))?; + let datetime = date.and_time(time); + + if tz_offset == 32 { + // Decimal overrun + while tz_offset < bytes.len() && bytes[tz_offset].is_ascii_digit() { + tz_offset += 1; + } + } + + if bytes.len() <= tz_offset { + return timezone + .from_local_datetime(&datetime) + .single() + .ok_or_else(|| err("error computing timezone offset")); + } + + if (bytes[tz_offset] == b'z' || bytes[tz_offset] == b'Z') && tz_offset == bytes.len() - 1 { + return Ok(timezone.from_utc_datetime(&datetime)); + } + + // Parse remainder of string as timezone + let parsed_tz: Tz = s[tz_offset..].trim_start().parse()?; + let parsed = parsed_tz + .from_local_datetime(&datetime) + .single() + .ok_or_else(|| err("error computing timezone offset"))?; + + Ok(parsed.with_timezone(timezone)) +} + +/// Accepts a string in RFC3339 / ISO8601 standard format and some +/// variants and converts it to a nanosecond precision timestamp. +/// +/// See [`string_to_datetime`] for the full set of supported formats +/// +/// Implements the `to_timestamp` function to convert a string to a +/// timestamp, following the model of Spark SQL’s `to_timestamp`. +/// +/// Internally, this function uses the `chrono` library for the +/// datetime parsing +/// +/// We hope to extend this function in the future with a second +/// parameter specifying the format string. +/// +/// ## Timestamp Precision +/// +/// This function uses the maximum timestamp precision supported by +/// Arrow (nanoseconds stored as a 64-bit integer). This +/// means the range of dates that timestamps can represent is ~1677 AD +/// to 2262 AD +/// +/// ## Timezone / Offset Handling +/// +/// Numerical values of timestamps are stored relative to UTC.
+/// +/// This function interprets string without an explicit time zone as timestamps +/// relative to UTC, see [`string_to_datetime`] for alternative semantics +/// +/// In particular: +/// +/// ``` +/// # use arrow_cast::parse::string_to_timestamp_nanos; +/// // Note all three of these timestamps are parsed as the same value +/// let a = string_to_timestamp_nanos("1997-01-31 09:26:56.123Z").unwrap(); +/// let b = string_to_timestamp_nanos("1997-01-31T09:26:56.123").unwrap(); +/// let c = string_to_timestamp_nanos("1997-01-31T14:26:56.123+05:00").unwrap(); +/// +/// assert_eq!(a, b); +/// assert_eq!(b, c); +/// ``` +/// +#[inline] +pub fn string_to_timestamp_nanos(s: &str) -> Result { + to_timestamp_nanos(string_to_datetime(&Utc, s)?.naive_utc()) +} + +/// Fallible conversion of [`NaiveDateTime`] to `i64` nanoseconds +#[inline] +fn to_timestamp_nanos(dt: NaiveDateTime) -> Result { + dt.and_utc() + .timestamp_nanos_opt() + .ok_or_else(|| ArrowError::ParseError(ERR_NANOSECONDS_NOT_SUPPORTED.to_string())) +} + +/// Accepts a string in ISO8601 standard format and some +/// variants and converts it to nanoseconds since midnight. +/// +/// Examples of accepted inputs: +/// +/// * `09:26:56.123 AM` +/// * `23:59:59` +/// * `6:00 pm` +/// +/// Internally, this function uses the `chrono` library for the time parsing +/// +/// ## Timezone / Offset Handling +/// +/// This function does not support parsing strings with a timezone +/// or offset specified, as it considers only time since midnight. +pub fn string_to_time_nanoseconds(s: &str) -> Result { + let nt = string_to_time(s) + .ok_or_else(|| ArrowError::ParseError(format!("Failed to parse \'{s}\' as time")))?; + Ok(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64) +} + +fn string_to_time(s: &str) -> Option { + let bytes = s.as_bytes(); + if bytes.len() < 4 { + return None; + } + + let (am, bytes) = match bytes.get(bytes.len() - 3..) 
{ + Some(b" AM" | b" am" | b" Am" | b" aM") => (Some(true), &bytes[..bytes.len() - 3]), + Some(b" PM" | b" pm" | b" pM" | b" Pm") => (Some(false), &bytes[..bytes.len() - 3]), + _ => (None, bytes), + }; + + if bytes.len() < 4 { + return None; + } + + let mut digits = [b'0'; 6]; + + // Extract hour + let bytes = match (bytes[1], bytes[2]) { + (b':', _) => { + digits[1] = bytes[0]; + &bytes[2..] + } + (_, b':') => { + digits[0] = bytes[0]; + digits[1] = bytes[1]; + &bytes[3..] + } + _ => return None, + }; + + if bytes.len() < 2 { + return None; // Minutes required + } + + // Extract minutes + digits[2] = bytes[0]; + digits[3] = bytes[1]; + + let nanoseconds = match bytes.get(2) { + Some(b':') => { + if bytes.len() < 5 { + return None; + } + + // Extract seconds + digits[4] = bytes[3]; + digits[5] = bytes[4]; + + // Extract sub-seconds if any + match bytes.get(5) { + Some(b'.') => { + let decimal = &bytes[6..]; + if decimal.iter().any(|x| !x.is_ascii_digit()) { + return None; + } + match decimal.len() { + 0 => return None, + 1 => parse_nanos::<1, b'0'>(decimal), + 2 => parse_nanos::<2, b'0'>(decimal), + 3 => parse_nanos::<3, b'0'>(decimal), + 4 => parse_nanos::<4, b'0'>(decimal), + 5 => parse_nanos::<5, b'0'>(decimal), + 6 => parse_nanos::<6, b'0'>(decimal), + 7 => parse_nanos::<7, b'0'>(decimal), + 8 => parse_nanos::<8, b'0'>(decimal), + _ => parse_nanos::<9, b'0'>(decimal), + } + } + Some(_) => return None, + None => 0, + } + } + Some(_) => return None, + None => 0, + }; + + digits.iter_mut().for_each(|x| *x = x.wrapping_sub(b'0')); + if digits.iter().any(|x| *x > 9) { + return None; + } + + let hour = match (digits[0] * 10 + digits[1], am) { + (12, Some(true)) => 0, // 12:00 AM -> 00:00 + (h @ 1..=11, Some(true)) => h, // 1:00 AM -> 01:00 + (12, Some(false)) => 12, // 12:00 PM -> 12:00 + (h @ 1..=11, Some(false)) => h + 12, // 1:00 PM -> 13:00 + (_, Some(_)) => return None, + (h, None) => h, + }; + + // Handle leap second + let (second, nanoseconds) = match 
digits[4] * 10 + digits[5] { + 60 => (59, nanoseconds + 1_000_000_000), + s => (s, nanoseconds), + }; + + NaiveTime::from_hms_nano_opt( + hour as _, + (digits[2] * 10 + digits[3]) as _, + second as _, + nanoseconds, + ) +} + +/// Specialized parsing implementations to convert strings to Arrow types. +/// +/// This is used by csv and json reader and can be used directly as well. +/// +/// # Example +/// +/// To parse a string to a [`Date32Type`]: +/// +/// ``` +/// use arrow_cast::parse::Parser; +/// use arrow_array::types::Date32Type; +/// let date = Date32Type::parse("2021-01-01").unwrap(); +/// assert_eq!(date, 18628); +/// ``` +/// +/// To parse a string to a [`TimestampNanosecondType`]: +/// +/// ``` +/// use arrow_cast::parse::Parser; +/// use arrow_array::types::TimestampNanosecondType; +/// let ts = TimestampNanosecondType::parse("2021-01-01T00:00:00.123456789Z").unwrap(); +/// assert_eq!(ts, 1609459200123456789); +/// ``` +pub trait Parser: ArrowPrimitiveType { + /// Parse a string to the native type + fn parse(string: &str) -> Option; + + /// Parse a string to the native type with a format string + /// + /// When not implemented, the format string is unused, and this method is equivalent to [parse](#tymethod.parse) + fn parse_formatted(string: &str, _format: &str) -> Option { + Self::parse(string) + } +} + +impl Parser for Float16Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()) + .ok() + .map(f16::from_f32) + } +} + +impl Parser for Float32Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()).ok() + } +} + +impl Parser for Float64Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()).ok() + } +} + +/// This API is only stable since 1.70 so can't use it when current MSRV is lower +#[inline(always)] +fn is_some_and(opt: Option, f: impl FnOnce(T) -> bool) -> bool { + match opt { + None => false, + Some(x) => f(x), + } +} + +macro_rules! 
parser_primitive { + ($t:ty) => { + impl Parser for $t { + fn parse(string: &str) -> Option { + if !is_some_and(string.as_bytes().last(), |x| x.is_ascii_digit()) { + return None; + } + match atoi::FromRadix10SignedChecked::from_radix_10_signed_checked( + string.as_bytes(), + ) { + (Some(n), x) if x == string.len() => Some(n), + _ => None, + } + } + } + }; +} +parser_primitive!(UInt64Type); +parser_primitive!(UInt32Type); +parser_primitive!(UInt16Type); +parser_primitive!(UInt8Type); +parser_primitive!(Int64Type); +parser_primitive!(Int32Type); +parser_primitive!(Int16Type); +parser_primitive!(Int8Type); + +impl Parser for TimestampNanosecondType { + fn parse(string: &str) -> Option { + string_to_timestamp_nanos(string).ok() + } +} + +impl Parser for TimestampMicrosecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1000) + } +} + +impl Parser for TimestampMillisecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1_000_000) + } +} + +impl Parser for TimestampSecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1_000_000_000) + } +} + +impl Parser for Time64NanosecondType { + // Will truncate any fractions of a nanosecond + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64) + } +} + +impl Parser for Time64MicrosecondType { + // Will truncate any fractions of a microsecond + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .map(|nanos| nanos / 1_000) + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { 
+ let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i64 * 1_000_000 + nt.nanosecond() as i64 / 1_000) + } +} + +impl Parser for Time32MillisecondType { + // Will truncate any fractions of a millisecond + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .map(|nanos| (nanos / 1_000_000) as i32) + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i32 * 1_000 + nt.nanosecond() as i32 / 1_000_000) + } +} + +impl Parser for Time32SecondType { + // Will truncate any fractions of a second + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .map(|nanos| (nanos / 1_000_000_000) as i32) + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i32 + nt.nanosecond() as i32 / 1_000_000_000) + } +} + +/// Number of days between 0001-01-01 and 1970-01-01 +const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// Error message if nanosecond conversion request beyond supported interval +const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; + +fn parse_date(string: &str) -> Option { + if string.len() > 10 { + // Try to parse as datetime and return just the date part + return string_to_datetime(&Utc, string) + .map(|dt| dt.date_naive()) + .ok(); + }; + let mut digits = [0; 10]; + let mut mask = 0; + + // Treating all bytes the same way, helps LLVM vectorise this correctly + for (idx, (o, i)) in digits.iter_mut().zip(string.bytes()).enumerate() { + *o = i.wrapping_sub(b'0'); + mask |= ((*o < 10) as u16) << idx + } + + const HYPHEN: u8 = b'-'.wrapping_sub(b'0'); + + // refer 
to https://www.rfc-editor.org/rfc/rfc3339#section-3 + if digits[4] != HYPHEN { + let (year, month, day) = match (mask, string.len()) { + (0b11111111, 8) => ( + digits[0] as u16 * 1000 + + digits[1] as u16 * 100 + + digits[2] as u16 * 10 + + digits[3] as u16, + digits[4] * 10 + digits[5], + digits[6] * 10 + digits[7], + ), + _ => return None, + }; + return NaiveDate::from_ymd_opt(year as _, month as _, day as _); + } + + let (month, day) = match mask { + 0b1101101111 => { + if digits[7] != HYPHEN { + return None; + } + (digits[5] * 10 + digits[6], digits[8] * 10 + digits[9]) + } + 0b101101111 => { + if digits[7] != HYPHEN { + return None; + } + (digits[5] * 10 + digits[6], digits[8]) + } + 0b110101111 => { + if digits[6] != HYPHEN { + return None; + } + (digits[5], digits[7] * 10 + digits[8]) + } + 0b10101111 => { + if digits[6] != HYPHEN { + return None; + } + (digits[5], digits[7]) + } + _ => return None, + }; + + let year = + digits[0] as u16 * 1000 + digits[1] as u16 * 100 + digits[2] as u16 * 10 + digits[3] as u16; + + NaiveDate::from_ymd_opt(year as _, month as _, day as _) +} + +impl Parser for Date32Type { + fn parse(string: &str) -> Option { + let date = parse_date(string)?; + Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let date = NaiveDate::parse_from_str(string, format).ok()?; + Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } +} + +impl Parser for Date64Type { + fn parse(string: &str) -> Option { + if string.len() <= 10 { + let datetime = NaiveDateTime::new(parse_date(string)?, NaiveTime::default()); + Some(datetime.and_utc().timestamp_millis()) + } else { + let date_time = string_to_datetime(&Utc, string).ok()?; + Some(date_time.timestamp_millis()) + } + } + + fn parse_formatted(string: &str, format: &str) -> Option { + use chrono::format::Fixed; + use chrono::format::StrftimeItems; + let fmt = StrftimeItems::new(format); + let has_zone = fmt.into_iter().any(|item| match 
item { + chrono::format::Item::Fixed(fixed_item) => matches!( + fixed_item, + Fixed::RFC2822 + | Fixed::RFC3339 + | Fixed::TimezoneName + | Fixed::TimezoneOffsetColon + | Fixed::TimezoneOffsetColonZ + | Fixed::TimezoneOffset + | Fixed::TimezoneOffsetZ + ), + _ => false, + }); + if has_zone { + let date_time = chrono::DateTime::parse_from_str(string, format).ok()?; + Some(date_time.timestamp_millis()) + } else { + let date_time = NaiveDateTime::parse_from_str(string, format).ok()?; + Some(date_time.and_utc().timestamp_millis()) + } + } +} + +fn parse_e_notation( + s: &str, + mut digits: u16, + mut fractionals: i16, + mut result: T::Native, + index: usize, + precision: u16, + scale: i16, +) -> Result { + let mut exp: i16 = 0; + let base = T::Native::usize_as(10); + + let mut exp_start: bool = false; + // e has a plus sign + let mut pos_shift_direction: bool = true; + + // skip to point or exponent index + let mut bs; + if fractionals > 0 { + // it's a fraction, so the point index needs to be skipped, so +1 + bs = s.as_bytes().iter().skip(index + fractionals as usize + 1); + } else { + // it's actually an integer that is already written into the result, so let's skip on to e + bs = s.as_bytes().iter().skip(index); + } + + while let Some(b) = bs.next() { + match b { + b'0'..=b'9' => { + result = result.mul_wrapping(base); + result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + if fractionals > 0 { + fractionals += 1; + } + digits += 1; + } + &b'e' | &b'E' => { + exp_start = true; + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + }; + + if exp_start { + pos_shift_direction = match bs.next() { + Some(&b'-') => false, + Some(&b'+') => true, + Some(b) => { + if !b.is_ascii_digit() { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + + exp *= 10; + exp += (b - b'0') as i16; + + true + } + None => { + return 
Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))) + } + }; + + for b in bs.by_ref() { + if !b.is_ascii_digit() { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + exp *= 10; + exp += (b - b'0') as i16; + } + } + } + + if digits == 0 && fractionals == 0 && exp == 0 { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + + if !pos_shift_direction { + // exponent has a large negative sign + // 1.12345e-30 => 0.0{29}12345, scale = 5 + if exp - (digits as i16 + scale) > 0 { + return Ok(T::Native::usize_as(0)); + } + exp *= -1; + } + + // point offset + exp = fractionals - exp; + // We have zeros on the left, we need to count them + if !pos_shift_direction && exp > digits as i16 { + digits = exp as u16; + } + // Number of numbers to be removed or added + exp = scale - exp; + + if (digits as i16 + exp) as u16 > precision { + return Err(ArrowError::ParseError(format!( + "parse decimal overflow ({s})" + ))); + } + + if exp < 0 { + result = result.div_wrapping(base.pow_wrapping(-exp as _)); + } else { + result = result.mul_wrapping(base.pow_wrapping(exp as _)); + } + + Ok(result) +} + +/// Parse the string format decimal value to i128/i256 format and checking the precision and scale. +/// The result value can't be out of bounds. 
+pub fn parse_decimal( + s: &str, + precision: u8, + scale: i8, +) -> Result { + let mut result = T::Native::usize_as(0); + let mut fractionals: i8 = 0; + let mut digits: u8 = 0; + let base = T::Native::usize_as(10); + + let bs = s.as_bytes(); + let (bs, negative) = match bs.first() { + Some(b'-') => (&bs[1..], true), + Some(b'+') => (&bs[1..], false), + _ => (bs, false), + }; + + if bs.is_empty() { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + + let mut bs = bs.iter().enumerate(); + + let mut is_e_notation = false; + + // Overflow checks are not required if 10^(precision - 1) <= T::MAX holds. + // Thus, if we validate the precision correctly, we can skip overflow checks. + while let Some((index, b)) = bs.next() { + match b { + b'0'..=b'9' => { + if digits == 0 && *b == b'0' { + // Ignore leading zeros. + continue; + } + digits += 1; + result = result.mul_wrapping(base); + result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + } + b'.' => { + let point_index = index; + + for (_, b) in bs.by_ref() { + if !b.is_ascii_digit() { + if *b == b'e' || *b == b'E' { + result = match parse_e_notation::( + s, + digits as u16, + fractionals as i16, + result, + point_index, + precision as u16, + scale as i16, + ) { + Err(e) => return Err(e), + Ok(v) => v, + }; + + is_e_notation = true; + + break; + } + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + if fractionals == scale && scale != 0 { + // We have processed all the digits that we need. All that + // is left is to validate that the rest of the string contains + // valid digits. + continue; + } + fractionals += 1; + digits += 1; + result = result.mul_wrapping(base); + result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + } + + if is_e_notation { + break; + } + + // Fail on "." 
+ if digits == 0 { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + } + b'e' | b'E' => { + result = match parse_e_notation::( + s, + digits as u16, + fractionals as i16, + result, + index, + precision as u16, + scale as i16, + ) { + Err(e) => return Err(e), + Ok(v) => v, + }; + + is_e_notation = true; + + break; + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + } + } + + if !is_e_notation { + if fractionals < scale { + let exp = scale - fractionals; + if exp as u8 + digits > precision { + return Err(ArrowError::ParseError(format!( + "parse decimal overflow ({s})" + ))); + } + let mul = base.pow_wrapping(exp as _); + result = result.mul_wrapping(mul); + } else if digits > precision { + return Err(ArrowError::ParseError(format!( + "parse decimal overflow ({s})" + ))); + } + } + + Ok(if negative { + result.neg_wrapping() + } else { + result + }) +} + +/// Parse human-readable interval string to Arrow [IntervalYearMonthType] +pub fn parse_interval_year_month( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Year); + let interval = Interval::parse(value, &config)?; + + let months = interval.to_year_months().map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast {value} to IntervalYearMonth. Only year and month fields are allowed." 
+ )) + })?; + + Ok(IntervalYearMonthType::make_value(0, months)) +} + +/// Parse human-readable interval string to Arrow [IntervalDayTimeType] +pub fn parse_interval_day_time( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Day); + let interval = Interval::parse(value, &config)?; + + let (days, millis) = interval.to_day_time().map_err(|_| ArrowError::CastError(format!( + "Cannot cast {value} to IntervalDayTime because the nanos part isn't multiple of milliseconds" + )))?; + + Ok(IntervalDayTimeType::make_value(days, millis)) +} + +/// Parse human-readable interval string to Arrow [IntervalMonthDayNanoType] +pub fn parse_interval_month_day_nano_config( + value: &str, + config: IntervalParseConfig, +) -> Result<::Native, ArrowError> { + let interval = Interval::parse(value, &config)?; + + let (months, days, nanos) = interval.to_month_day_nanos(); + + Ok(IntervalMonthDayNanoType::make_value(months, days, nanos)) +} + +/// Parse human-readable interval string to Arrow [IntervalMonthDayNanoType] +pub fn parse_interval_month_day_nano( + value: &str, +) -> Result<::Native, ArrowError> { + parse_interval_month_day_nano_config(value, IntervalParseConfig::new(IntervalUnit::Month)) +} + +const NANOS_PER_MILLIS: i64 = 1_000_000; +const NANOS_PER_SECOND: i64 = 1_000 * NANOS_PER_MILLIS; +const NANOS_PER_MINUTE: i64 = 60 * NANOS_PER_SECOND; +const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MINUTE; +#[cfg(test)] +const NANOS_PER_DAY: i64 = 24 * NANOS_PER_HOUR; + +/// Config to parse interval strings +/// +/// Currently stores the `default_unit` to use if the string doesn't have one specified +#[derive(Debug, Clone)] +pub struct IntervalParseConfig { + /// The default unit to use if none is specified + /// e.g. 
`INTERVAL 1` represents `INTERVAL 1 SECOND` when default_unit = [IntervalUnit::Second] + default_unit: IntervalUnit, +} + +impl IntervalParseConfig { + /// Create a new [IntervalParseConfig] with the given default unit + pub fn new(default_unit: IntervalUnit) -> Self { + Self { default_unit } + } +} + +#[rustfmt::skip] +#[derive(Debug, Clone, Copy)] +#[repr(u16)] +/// Represents the units of an interval, with each variant +/// corresponding to a bit in the interval's bitfield representation +pub enum IntervalUnit { + /// A Century + Century = 0b_0000_0000_0001, + /// A Decade + Decade = 0b_0000_0000_0010, + /// A Year + Year = 0b_0000_0000_0100, + /// A Month + Month = 0b_0000_0000_1000, + /// A Week + Week = 0b_0000_0001_0000, + /// A Day + Day = 0b_0000_0010_0000, + /// An Hour + Hour = 0b_0000_0100_0000, + /// A Minute + Minute = 0b_0000_1000_0000, + /// A Second + Second = 0b_0001_0000_0000, + /// A Millisecond + Millisecond = 0b_0010_0000_0000, + /// A Microsecond + Microsecond = 0b_0100_0000_0000, + /// A Nanosecond + Nanosecond = 0b_1000_0000_0000, +} + +/// Logic for parsing interval unit strings +/// +/// See +/// for a list of unit names supported by PostgreSQL which we try to match here. 
+impl FromStr for IntervalUnit { + type Err = ArrowError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "c" | "cent" | "cents" | "century" | "centuries" => Ok(Self::Century), + "dec" | "decs" | "decade" | "decades" => Ok(Self::Decade), + "y" | "yr" | "yrs" | "year" | "years" => Ok(Self::Year), + "mon" | "mons" | "month" | "months" => Ok(Self::Month), + "w" | "week" | "weeks" => Ok(Self::Week), + "d" | "day" | "days" => Ok(Self::Day), + "h" | "hr" | "hrs" | "hour" | "hours" => Ok(Self::Hour), + "m" | "min" | "mins" | "minute" | "minutes" => Ok(Self::Minute), + "s" | "sec" | "secs" | "second" | "seconds" => Ok(Self::Second), + "ms" | "msec" | "msecs" | "msecond" | "mseconds" | "millisecond" | "milliseconds" => { + Ok(Self::Millisecond) + } + "us" | "usec" | "usecs" | "usecond" | "useconds" | "microsecond" | "microseconds" => { + Ok(Self::Microsecond) + } + "nanosecond" | "nanoseconds" => Ok(Self::Nanosecond), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unknown interval type: {s}" + ))), + } + } +} + +impl IntervalUnit { + fn from_str_or_config( + s: Option<&str>, + config: &IntervalParseConfig, + ) -> Result { + match s { + Some(s) => s.parse(), + None => Ok(config.default_unit), + } + } +} + +/// A tuple representing (months, days, nanoseconds) in an interval +pub type MonthDayNano = (i32, i32, i64); + +/// Chosen based on the number of decimal digits in 1 week in nanoseconds +const INTERVAL_PRECISION: u32 = 15; + +#[derive(Clone, Copy, Debug, PartialEq)] +struct IntervalAmount { + /// The integer component of the interval amount + integer: i64, + /// The fractional component multiplied by 10^INTERVAL_PRECISION + frac: i64, +} + +#[cfg(test)] +impl IntervalAmount { + fn new(integer: i64, frac: i64) -> Self { + Self { integer, frac } + } +} + +impl FromStr for IntervalAmount { + type Err = ArrowError; + + fn from_str(s: &str) -> Result { + match s.split_once('.') { + Some((integer, frac)) + if frac.len() <= INTERVAL_PRECISION 
as usize + && !frac.is_empty() + && !frac.starts_with('-') => + { + // integer will be "" for values like ".5" + // and "-" for values like "-.5" + let explicit_neg = integer.starts_with('-'); + let integer = if integer.is_empty() || integer == "-" { + Ok(0) + } else { + integer.parse::().map_err(|_| { + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) + }) + }?; + + let frac_unscaled = frac.parse::().map_err(|_| { + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) + })?; + + // scale fractional part by interval precision + let frac = frac_unscaled * 10_i64.pow(INTERVAL_PRECISION - frac.len() as u32); + + // propagate the sign of the integer part to the fractional part + let frac = if integer < 0 || explicit_neg { + -frac + } else { + frac + }; + + let result = Self { integer, frac }; + + Ok(result) + } + Some((_, frac)) if frac.starts_with('-') => Err(ArrowError::ParseError(format!( + "Failed to parse {s} as interval amount" + ))), + Some((_, frac)) if frac.len() > INTERVAL_PRECISION as usize => { + Err(ArrowError::ParseError(format!( + "{s} exceeds the precision available for interval amount" + ))) + } + Some(_) | None => { + let integer = s.parse::().map_err(|_| { + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) + })?; + + let result = Self { integer, frac: 0 }; + Ok(result) + } + } + } +} + +#[derive(Debug, Default, PartialEq)] +struct Interval { + months: i32, + days: i32, + nanos: i64, +} + +impl Interval { + fn new(months: i32, days: i32, nanos: i64) -> Self { + Self { + months, + days, + nanos, + } + } + + fn to_year_months(&self) -> Result { + match (self.months, self.days, self.nanos) { + (months, days, nanos) if days == 0 && nanos == 0 => Ok(months), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unable to represent interval with days and nanos as year-months: {:?}", + self + ))), + } + } + + fn to_day_time(&self) -> Result<(i32, i32), ArrowError> { + let days = 
self.months.mul_checked(30)?.add_checked(self.days)?; + + match self.nanos { + nanos if nanos % NANOS_PER_MILLIS == 0 => { + let millis = (self.nanos / 1_000_000).try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Unable to represent {} nanos as milliseconds in a signed 32-bit integer", + self.nanos + )) + })?; + + Ok((days, millis)) + } + nanos => Err(ArrowError::InvalidArgumentError(format!( + "Unable to represent {nanos} as milliseconds" + ))), + } + } + + fn to_month_day_nanos(&self) -> (i32, i32, i64) { + (self.months, self.days, self.nanos) + } + + /// Parse string value in traditional Postgres format such as + /// `1 year 2 months 3 days 4 hours 5 minutes 6 seconds` + fn parse(value: &str, config: &IntervalParseConfig) -> Result { + let components = parse_interval_components(value, config)?; + + components + .into_iter() + .try_fold(Self::default(), |result, (amount, unit)| { + result.add(amount, unit) + }) + } + + /// Interval addition following Postgres behavior. Fractional units will be spilled into smaller units. + /// When the interval unit is larger than months, the result is rounded to total months and not spilled to days/nanos. + /// Fractional parts of weeks and days are represented using days and nanoseconds. + /// e.g. INTERVAL '0.5 MONTH' = 15 days, INTERVAL '1.5 MONTH' = 1 month 15 days + /// e.g. INTERVAL '0.5 DAY' = 12 hours, INTERVAL '1.5 DAY' = 1 day 12 hours + /// [Postgres reference](https://www.postgresql.org/docs/15/datatype-datetime.html#DATATYPE-INTERVAL-INPUT:~:text=Field%20values%20can,fractional%20on%20output.) + fn add(&self, amount: IntervalAmount, unit: IntervalUnit) -> Result { + let result = match unit { + IntervalUnit::Century => { + let months_int = amount.integer.mul_checked(100)?.mul_checked(12)?; + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 2); + let months = months_int + .add_checked(month_frac)? 
+ .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} centuries as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Decade => { + let months_int = amount.integer.mul_checked(10)?.mul_checked(12)?; + + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 1); + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} decades as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Year => { + let months_int = amount.integer.mul_checked(12)?; + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION); + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} years as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Month => { + let months = amount.integer.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} months in a signed 32-bit integer", + &amount.integer + )) + })?; + + let days = amount.frac * 3 / 10_i64.pow(INTERVAL_PRECISION - 1); + let days = days.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} months as days in a signed 32-bit integer", + amount.frac / 10_i64.pow(INTERVAL_PRECISION) + )) + })?; + + Self::new( + self.months.add_checked(months)?, + self.days.add_checked(days)?, + self.nanos, + ) + } + IntervalUnit::Week => { + let days = amount.integer.mul_checked(7)?.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} weeks as days in a signed 32-bit integer", + &amount.integer + )) + })?; + + let nanos = 
amount.frac * 7 * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + + Self::new( + self.months, + self.days.add_checked(days)?, + self.nanos.add_checked(nanos)?, + ) + } + IntervalUnit::Day => { + let days = amount.integer.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Unable to represent {} days in a signed 32-bit integer", + amount.integer + )) + })?; + + let nanos = amount.frac * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + + Self::new( + self.months, + self.days.add_checked(days)?, + self.nanos.add_checked(nanos)?, + ) + } + IntervalUnit::Hour => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_HOUR)?; + let nanos_frac = amount.frac * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Minute => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_MINUTE)?; + let nanos_frac = amount.frac * 6 / 10_i64.pow(INTERVAL_PRECISION - 10); + + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Second => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_SECOND)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 9); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Millisecond => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_MILLIS)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 6); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) 
+ } + IntervalUnit::Microsecond => { + let nanos_int = amount.integer.mul_checked(1_000)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 3); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Nanosecond => { + let nanos_int = amount.integer; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + }; + + Ok(result) + } +} + +/// parse the string into a vector of interval components i.e. (amount, unit) tuples +fn parse_interval_components( + value: &str, + config: &IntervalParseConfig, +) -> Result, ArrowError> { + let raw_pairs = split_interval_components(value); + + // parse amounts and units + let Ok(pairs): Result, ArrowError> = raw_pairs + .iter() + .map(|(a, u)| Ok((a.parse()?, IntervalUnit::from_str_or_config(*u, config)?))) + .collect() + else { + return Err(ArrowError::ParseError(format!( + "Invalid input syntax for type interval: {value:?}" + ))); + }; + + // collect parsed results + let (amounts, units): (Vec<_>, Vec<_>) = pairs.into_iter().unzip(); + + // duplicate units? + let mut observed_interval_types = 0; + for (unit, (_, raw_unit)) in units.iter().zip(raw_pairs) { + if observed_interval_types & (*unit as u16) != 0 { + return Err(ArrowError::ParseError(format!( + "Invalid input syntax for type interval: {:?}. Repeated type '{}'", + value, + raw_unit.unwrap_or_default(), + ))); + } + + observed_interval_types |= *unit as u16; + } + + let result = amounts.iter().copied().zip(units.iter().copied()); + + Ok(result.collect::>()) +} + +/// Split an interval into a vec of amounts and units. +/// +/// Pairs are separated by spaces, but within a pair the amount and unit may or may not be separated by a space. +/// +/// This should match the behavior of PostgreSQL's interval parser. 
/// Split an interval string into raw `(amount, unit)` pairs.
///
/// Pairs are separated by whitespace. Within a pair the amount and the unit may be written
/// together (`"1mon"`) or separated by whitespace (`"1 mon"`). A trailing amount with no unit
/// is returned with `None` as its unit.
///
/// Leading, trailing and repeated whitespace is ignored, so `"  1   month "` tokenizes exactly
/// like `"1 month"`. Blank input yields a single pair with an empty amount, which callers that
/// parse the amount will reject.
///
/// This should match the behavior of PostgreSQL's interval parser.
fn split_interval_components(value: &str) -> Vec<(&str, Option<&str>)> {
    let mut result = vec![];
    // `split_whitespace` never yields empty tokens, unlike `split(char::is_whitespace)`, which
    // produces one for every extra whitespace character and made such inputs unparsable.
    let mut words = value.split_whitespace();
    while let Some(word) = words.next() {
        if let Some(split_word_at) = word.find(not_interval_amount) {
            // Amount and unit are glued together: split at the first non-numeric character.
            let (amount, unit) = word.split_at(split_word_at);
            result.push((amount, Some(unit)));
        } else if let Some(unit) = words.next() {
            // The word is purely numeric, so the following word is its unit.
            result.push((word, Some(unit)));
        } else {
            // Purely numeric last word: no unit was given.
            result.push((word, None));
            break;
        }
    }
    if result.is_empty() {
        // Blank input: keep reporting one empty amount so it is still treated as invalid
        // instead of silently becoming an interval with no components.
        result.push(("", None));
    }
    result
}

/// test if a character is NOT part of an interval numeric amount
fn not_interval_amount(c: char) -> bool {
    !c.is_ascii_digit() && c != '.' && c != '-'
}
+ parse_timestamp("2020-09-08 13:42:29.190855-05:00").unwrap() + ); + } + + #[test] + #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function: mktime + fn string_to_timestamp_no_timezone() { + // This test is designed to succeed in regardless of the local + // timezone the test machine is running. Thus it is still + // somewhat susceptible to bugs in the use of chrono + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.and_utc().timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29.190855").unwrap() + ); + + assert_eq!( + naive_datetime.and_utc().timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29.190855").unwrap() + ); + + // Also ensure that parsing timestamps with no fractional + // second part works as well + let datetime_whole_secs = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_opt(13, 42, 29).unwrap(), + ) + .and_utc(); + + // Ensure both T and ' ' variants work + assert_eq!( + datetime_whole_secs.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29").unwrap() + ); + + assert_eq!( + datetime_whole_secs.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29").unwrap() + ); + + // ensure without time work + // no time, should be the nano second at + // 2020-09-08 0:0:0 + let datetime_no_time = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_opt(0, 0, 0).unwrap(), + ) + .and_utc(); + + assert_eq!( + datetime_no_time.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08").unwrap() + ) + } + + #[test] + fn string_to_timestamp_chrono() { + let cases = [ + "2020-09-08T13:42:29Z", + "1969-01-01T00:00:00.1Z", + "2020-09-08T12:00:12.12345678+00:00", + "2020-09-08T12:00:12+00:00", + 
"2020-09-08T12:00:12.1+00:00", + "2020-09-08T12:00:12.12+00:00", + "2020-09-08T12:00:12.123+00:00", + "2020-09-08T12:00:12.1234+00:00", + "2020-09-08T12:00:12.12345+00:00", + "2020-09-08T12:00:12.123456+00:00", + "2020-09-08T12:00:12.1234567+00:00", + "2020-09-08T12:00:12.12345678+00:00", + "2020-09-08T12:00:12.123456789+00:00", + "2020-09-08T12:00:12.12345678912z", + "2020-09-08T12:00:12.123456789123Z", + "2020-09-08T12:00:12.123456789123+02:00", + "2020-09-08T12:00:12.12345678912345Z", + "2020-09-08T12:00:12.1234567891234567+02:00", + "2020-09-08T12:00:60Z", + "2020-09-08T12:00:60.123Z", + "2020-09-08T12:00:60.123456+02:00", + "2020-09-08T12:00:60.1234567891234567+02:00", + "2020-09-08T12:00:60.999999999+02:00", + "2020-09-08t12:00:12.12345678+00:00", + "2020-09-08t12:00:12+00:00", + "2020-09-08t12:00:12Z", + ]; + + for case in cases { + let chrono = DateTime::parse_from_rfc3339(case).unwrap(); + let chrono_utc = chrono.with_timezone(&Utc); + + let custom = string_to_datetime(&Utc, case).unwrap(); + assert_eq!(chrono_utc, custom) + } + } + + #[test] + fn string_to_timestamp_naive() { + let cases = [ + "2018-11-13T17:11:10.011375885995", + "2030-12-04T17:11:10.123", + "2030-12-04T17:11:10.1234", + "2030-12-04T17:11:10.123456", + ]; + for case in cases { + let chrono = NaiveDateTime::parse_from_str(case, "%Y-%m-%dT%H:%M:%S%.f").unwrap(); + let custom = string_to_datetime(&Utc, case).unwrap(); + assert_eq!(chrono, custom.naive_utc()) + } + } + + #[test] + fn string_to_timestamp_invalid() { + // Test parsing invalid formats + let cases = [ + ("", "timestamp must contain at least 10 characters"), + ("SS", "timestamp must contain at least 10 characters"), + ("Wed, 18 Feb 2015 23:16:09 GMT", "error parsing date"), + ("1997-01-31H09:26:56.123Z", "invalid timestamp separator"), + ("1997-01-31 09:26:56.123Z", "error parsing time"), + ("1997:01:31T09:26:56.123Z", "error parsing date"), + ("1997:1:31T09:26:56.123Z", "error parsing date"), + ("1997-01-32T09:26:56.123Z", 
"error parsing date"), + ("1997-13-32T09:26:56.123Z", "error parsing date"), + ("1997-02-29T09:26:56.123Z", "error parsing date"), + ("2015-02-30T17:35:20-08:00", "error parsing date"), + ("1997-01-10T9:26:56.123Z", "error parsing time"), + ("2015-01-20T25:35:20-08:00", "error parsing time"), + ("1997-01-10T09:61:56.123Z", "error parsing time"), + ("1997-01-10T09:61:90.123Z", "error parsing time"), + ("1997-01-10T12:00:6.123Z", "error parsing time"), + ("1997-01-31T092656.123Z", "error parsing time"), + ("1997-01-10T12:00:06.", "error parsing time"), + ("1997-01-10T12:00:06. ", "error parsing time"), + ]; + + for (s, ctx) in cases { + let expected = format!("Parser error: Error parsing timestamp from '{s}': {ctx}"); + let actual = string_to_datetime(&Utc, s).unwrap_err().to_string(); + assert_eq!(actual, expected) + } + } + + // Parse a timestamp to timestamp int with a useful human readable error message + fn parse_timestamp(s: &str) -> Result { + let result = string_to_timestamp_nanos(s); + if let Err(e) = &result { + eprintln!("Error parsing timestamp '{s}': {e:?}"); + } + result + } + + #[test] + fn string_without_timezone_to_timestamp() { + // string without timezone should always output the same regardless the local or session timezone + + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.and_utc().timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29.190855").unwrap() + ); + + assert_eq!( + naive_datetime.and_utc().timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29.190855").unwrap() + ); + + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 0).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + 
naive_datetime.and_utc().timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29").unwrap() + ); + + assert_eq!( + naive_datetime.and_utc().timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29").unwrap() + ); + + let tz: Tz = "+02:00".parse().unwrap(); + let date = string_to_datetime(&tz, "2020-09-08 13:42:29").unwrap(); + let utc = date.naive_utc().to_string(); + assert_eq!(utc, "2020-09-08 11:42:29"); + let local = date.naive_local().to_string(); + assert_eq!(local, "2020-09-08 13:42:29"); + + let date = string_to_datetime(&tz, "2020-09-08 13:42:29Z").unwrap(); + let utc = date.naive_utc().to_string(); + assert_eq!(utc, "2020-09-08 13:42:29"); + let local = date.naive_local().to_string(); + assert_eq!(local, "2020-09-08 15:42:29"); + + let dt = + NaiveDateTime::parse_from_str("2020-09-08T13:42:29Z", "%Y-%m-%dT%H:%M:%SZ").unwrap(); + let local: Tz = "+08:00".parse().unwrap(); + + // Parsed as offset from UTC + let date = string_to_datetime(&local, "2020-09-08T13:42:29Z").unwrap(); + assert_eq!(dt, date.naive_utc()); + assert_ne!(dt, date.naive_local()); + + // Parsed as offset from local + let date = string_to_datetime(&local, "2020-09-08 13:42:29").unwrap(); + assert_eq!(dt, date.naive_local()); + assert_ne!(dt, date.naive_utc()); + } + + #[test] + fn parse_date32() { + let cases = [ + "2020-09-08", + "2020-9-8", + "2020-09-8", + "2020-9-08", + "2020-12-1", + "1690-2-5", + "2020-09-08 01:02:03", + ]; + for case in cases { + let v = date32_to_datetime(Date32Type::parse(case).unwrap()).unwrap(); + let expected = NaiveDate::parse_from_str(case, "%Y-%m-%d") + .or(NaiveDate::parse_from_str(case, "%Y-%m-%d %H:%M:%S")) + .unwrap(); + assert_eq!(v.date(), expected); + } + + let err_cases = [ + "", + "80-01-01", + "342", + "Foo", + "2020-09-08-03", + "2020--04-03", + "2020--", + "2020-09-08 01", + "2020-09-08 01:02", + "2020-09-08 01-02-03", + "2020-9-8 01:02:03", + "2020-09-08 1:2:3", + ]; + for case in err_cases { + 
assert_eq!(Date32Type::parse(case), None); + } + } + + #[test] + fn parse_time64_nanos() { + assert_eq!( + Time64NanosecondType::parse("02:10:01.1234567899999999"), + Some(7_801_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("02:10:01.1234567"), + Some(7_801_123_456_700) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.1234567"), + Some(7_801_123_456_700) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 AM"), + Some(601_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 am"), + Some(601_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.12345678 PM"), + Some(51_001_123_456_780) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.12345678 pm"), + Some(51_001_123_456_780) + ); + assert_eq!( + Time64NanosecondType::parse("02:10:01"), + Some(7_801_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01"), + Some(7_801_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01 AM"), + Some(601_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01 am"), + Some(601_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01 PM"), + Some(51_001_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01 pm"), + Some(51_001_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("02:10"), + Some(7_800_000_000_000) + ); + assert_eq!(Time64NanosecondType::parse("2:10"), Some(7_800_000_000_000)); + assert_eq!( + Time64NanosecondType::parse("12:10 AM"), + Some(600_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10 am"), + Some(600_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10 PM"), + Some(51_000_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10 pm"), + Some(51_000_000_000_000) + ); + + // parse directly as nanoseconds + assert_eq!(Time64NanosecondType::parse("1"), Some(1)); + + // leap second + assert_eq!( + Time64NanosecondType::parse("23:59:60"), 
+ Some(86_400_000_000_000) + ); + + // custom format + assert_eq!( + Time64NanosecondType::parse_formatted("02 - 10 - 01 - .1234567", "%H - %M - %S - %.f"), + Some(7_801_123_456_700) + ); + } + + #[test] + fn parse_time64_micros() { + // expected formats + assert_eq!( + Time64MicrosecondType::parse("02:10:01.1234"), + Some(7_801_123_400) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.1234"), + Some(7_801_123_400) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 AM"), + Some(601_123_456) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 am"), + Some(601_123_456) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.12345 PM"), + Some(51_001_123_450) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.12345 pm"), + Some(51_001_123_450) + ); + assert_eq!( + Time64MicrosecondType::parse("02:10:01"), + Some(7_801_000_000) + ); + assert_eq!(Time64MicrosecondType::parse("2:10:01"), Some(7_801_000_000)); + assert_eq!( + Time64MicrosecondType::parse("12:10:01 AM"), + Some(601_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01 am"), + Some(601_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01 PM"), + Some(51_001_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01 pm"), + Some(51_001_000_000) + ); + assert_eq!(Time64MicrosecondType::parse("02:10"), Some(7_800_000_000)); + assert_eq!(Time64MicrosecondType::parse("2:10"), Some(7_800_000_000)); + assert_eq!(Time64MicrosecondType::parse("12:10 AM"), Some(600_000_000)); + assert_eq!(Time64MicrosecondType::parse("12:10 am"), Some(600_000_000)); + assert_eq!( + Time64MicrosecondType::parse("2:10 PM"), + Some(51_000_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10 pm"), + Some(51_000_000_000) + ); + + // parse directly as microseconds + assert_eq!(Time64MicrosecondType::parse("1"), Some(1)); + + // leap second + assert_eq!( + Time64MicrosecondType::parse("23:59:60"), + Some(86_400_000_000) + ); + + // custom 
format + assert_eq!( + Time64MicrosecondType::parse_formatted("02 - 10 - 01 - .1234", "%H - %M - %S - %.f"), + Some(7_801_123_400) + ); + } + + #[test] + fn parse_time32_millis() { + // expected formats + assert_eq!(Time32MillisecondType::parse("02:10:01.1"), Some(7_801_100)); + assert_eq!(Time32MillisecondType::parse("2:10:01.1"), Some(7_801_100)); + assert_eq!( + Time32MillisecondType::parse("12:10:01.123 AM"), + Some(601_123) + ); + assert_eq!( + Time32MillisecondType::parse("12:10:01.123 am"), + Some(601_123) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 PM"), + Some(51_001_120) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 pm"), + Some(51_001_120) + ); + assert_eq!(Time32MillisecondType::parse("02:10:01"), Some(7_801_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01"), Some(7_801_000)); + assert_eq!(Time32MillisecondType::parse("12:10:01 AM"), Some(601_000)); + assert_eq!(Time32MillisecondType::parse("12:10:01 am"), Some(601_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01 PM"), Some(51_001_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01 pm"), Some(51_001_000)); + assert_eq!(Time32MillisecondType::parse("02:10"), Some(7_800_000)); + assert_eq!(Time32MillisecondType::parse("2:10"), Some(7_800_000)); + assert_eq!(Time32MillisecondType::parse("12:10 AM"), Some(600_000)); + assert_eq!(Time32MillisecondType::parse("12:10 am"), Some(600_000)); + assert_eq!(Time32MillisecondType::parse("2:10 PM"), Some(51_000_000)); + assert_eq!(Time32MillisecondType::parse("2:10 pm"), Some(51_000_000)); + + // parse directly as milliseconds + assert_eq!(Time32MillisecondType::parse("1"), Some(1)); + + // leap second + assert_eq!(Time32MillisecondType::parse("23:59:60"), Some(86_400_000)); + + // custom format + assert_eq!( + Time32MillisecondType::parse_formatted("02 - 10 - 01 - .1", "%H - %M - %S - %.f"), + Some(7_801_100) + ); + } + + #[test] + fn parse_time32_secs() { + // expected formats + 
assert_eq!(Time32SecondType::parse("02:10:01.1"), Some(7_801)); + assert_eq!(Time32SecondType::parse("02:10:01"), Some(7_801)); + assert_eq!(Time32SecondType::parse("2:10:01"), Some(7_801)); + assert_eq!(Time32SecondType::parse("12:10:01 AM"), Some(601)); + assert_eq!(Time32SecondType::parse("12:10:01 am"), Some(601)); + assert_eq!(Time32SecondType::parse("2:10:01 PM"), Some(51_001)); + assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001)); + assert_eq!(Time32SecondType::parse("02:10"), Some(7_800)); + assert_eq!(Time32SecondType::parse("2:10"), Some(7_800)); + assert_eq!(Time32SecondType::parse("12:10 AM"), Some(600)); + assert_eq!(Time32SecondType::parse("12:10 am"), Some(600)); + assert_eq!(Time32SecondType::parse("2:10 PM"), Some(51_000)); + assert_eq!(Time32SecondType::parse("2:10 pm"), Some(51_000)); + + // parse directly as seconds + assert_eq!(Time32SecondType::parse("1"), Some(1)); + + // leap second + assert_eq!(Time32SecondType::parse("23:59:60"), Some(86400)); + + // custom format + assert_eq!( + Time32SecondType::parse_formatted("02 - 10 - 01", "%H - %M - %S"), + Some(7_801) + ); + } + + #[test] + fn test_string_to_time_invalid() { + let cases = [ + "25:00", + "9:00:", + "009:00", + "09:0:00", + "25:00:00", + "13:00 AM", + "13:00 PM", + "12:00. 
AM", + "09:0:00", + "09:01:0", + "09:01:1", + "9:1:0", + "09:01:0", + "1:00.123", + "1:00:00.123f", + " 9:00:00", + ":09:00", + "T9:00:00", + "AM", + ]; + for case in cases { + assert!(string_to_time(case).is_none(), "{case}"); + } + } + + #[test] + fn test_string_to_time_chrono() { + let cases = [ + ("1:00", "%H:%M"), + ("12:00", "%H:%M"), + ("13:00", "%H:%M"), + ("24:00", "%H:%M"), + ("1:00:00", "%H:%M:%S"), + ("12:00:30", "%H:%M:%S"), + ("13:00:59", "%H:%M:%S"), + ("24:00:60", "%H:%M:%S"), + ("09:00:00", "%H:%M:%S%.f"), + ("0:00:30.123456", "%H:%M:%S%.f"), + ("0:00 AM", "%I:%M %P"), + ("1:00 AM", "%I:%M %P"), + ("12:00 AM", "%I:%M %P"), + ("13:00 AM", "%I:%M %P"), + ("0:00 PM", "%I:%M %P"), + ("1:00 PM", "%I:%M %P"), + ("12:00 PM", "%I:%M %P"), + ("13:00 PM", "%I:%M %P"), + ("1:00 pM", "%I:%M %P"), + ("1:00 Pm", "%I:%M %P"), + ("1:00 aM", "%I:%M %P"), + ("1:00 Am", "%I:%M %P"), + ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789123 PM", "%I:%M:%S%.f %P"), + ("1:00:30.1234 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789123456789 PM", "%I:%M:%S%.f %P"), + ("1:00:30.12F456 PM", "%I:%M:%S%.f %P"), + ]; + for (s, format) in cases { + let chrono = NaiveTime::parse_from_str(s, format).ok(); + let custom = string_to_time(s); + assert_eq!(chrono, custom, "{s}"); + } + } + + #[test] + fn test_parse_interval() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + assert_eq!( + Interval::new(1i32, 0i32, 0i64), + Interval::parse("1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(2i32, 0i32, 0i64), + Interval::parse("2 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-1i32, -18i32, -(NANOS_PER_DAY / 5)), + Interval::parse("-1.5 months -3.2 days", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 15i32, 0), + Interval::parse("0.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 15i32, 0), + 
Interval::parse(".5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -15i32, 0), + Interval::parse("-0.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -15i32, 0), + Interval::parse("-.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(2i32, 10i32, 9 * NANOS_PER_HOUR), + Interval::parse("2.1 months 7.25 days 3 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::parse("1 centurys 1 month", &config) + .unwrap_err() + .to_string(), + r#"Parser error: Invalid input syntax for type interval: "1 centurys 1 month""# + ); + + assert_eq!( + Interval::new(37i32, 0i32, 0i64), + Interval::parse("3 year 1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(35i32, 0i32, 0i64), + Interval::parse("3 year -1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-37i32, 0i32, 0i64), + Interval::parse("-3 year -1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-35i32, 0i32, 0i64), + Interval::parse("-3 year 1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 5i32, 0i64), + Interval::parse("5 days", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, 3 * NANOS_PER_HOUR), + Interval::parse("7 days 3 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, 5 * NANOS_PER_MINUTE), + Interval::parse("7 days 5 minutes", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, -5 * NANOS_PER_MINUTE), + Interval::parse("7 days -5 minutes", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -7i32, 5 * NANOS_PER_HOUR), + Interval::parse("-7 days 5 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::new( + 0i32, + -7i32, + -5 * NANOS_PER_HOUR - 5 * NANOS_PER_MINUTE - 5 * NANOS_PER_SECOND + ), + Interval::parse("-7 days -5 hours -5 minutes -5 seconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 0i32, 25 * NANOS_PER_MILLIS), + Interval::parse("1 year 25 millisecond", &config).unwrap(), + 
); + + assert_eq!( + Interval::new( + 12i32, + 1i32, + (NANOS_PER_SECOND as f64 * 0.000000001_f64) as i64 + ), + Interval::parse("1 year 1 day 0.000000001 seconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, NANOS_PER_MILLIS / 10), + Interval::parse("1 year 1 day 0.1 milliseconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, 1000i64), + Interval::parse("1 year 1 day 1 microsecond", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, 1i64), + Interval::parse("1 year 1 day 1 nanoseconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(1i32, 0i32, -NANOS_PER_SECOND), + Interval::parse("1 month -1 second", &config).unwrap(), + ); + + assert_eq!( + Interval::new( + -13i32, + -8i32, + -NANOS_PER_HOUR + - NANOS_PER_MINUTE + - NANOS_PER_SECOND + - (1.11_f64 * NANOS_PER_MILLIS as f64) as i64 + ), + Interval::parse( + "-1 year -1 month -1 week -1 day -1 hour -1 minute -1 second -1.11 millisecond", + &config + ) + .unwrap(), + ); + + // no units + assert_eq!( + Interval::new(1, 0, 0), + Interval::parse("1", &config).unwrap() + ); + assert_eq!( + Interval::new(42, 0, 0), + Interval::parse("42", &config).unwrap() + ); + assert_eq!( + Interval::new(0, 0, 42_000_000_000), + Interval::parse("42", &IntervalParseConfig::new(IntervalUnit::Second)).unwrap() + ); + + // shorter units + assert_eq!( + Interval::new(1, 0, 0), + Interval::parse("1 mon", &config).unwrap() + ); + assert_eq!( + Interval::new(1, 0, 0), + Interval::parse("1 mons", &config).unwrap() + ); + assert_eq!( + Interval::new(0, 0, 1_000_000), + Interval::parse("1 ms", &config).unwrap() + ); + assert_eq!( + Interval::new(0, 0, 1_000), + Interval::parse("1 us", &config).unwrap() + ); + + // no space + assert_eq!( + Interval::new(0, 0, 1_000), + Interval::parse("1us", &config).unwrap() + ); + assert_eq!( + Interval::new(0, 0, NANOS_PER_SECOND), + Interval::parse("1s", &config).unwrap() + ); + assert_eq!( + Interval::new(1, 2, 10_864_000_000_000), 
+ Interval::parse("1mon 2days 3hr 1min 4sec", &config).unwrap() + ); + + assert_eq!( + Interval::new( + -13i32, + -8i32, + -NANOS_PER_HOUR + - NANOS_PER_MINUTE + - NANOS_PER_SECOND + - (1.11_f64 * NANOS_PER_MILLIS as f64) as i64 + ), + Interval::parse( + "-1year -1month -1week -1day -1 hour -1 minute -1 second -1.11millisecond", + &config + ) + .unwrap(), + ); + + assert_eq!( + Interval::parse("1h s", &config).unwrap_err().to_string(), + r#"Parser error: Invalid input syntax for type interval: "1h s""# + ); + + assert_eq!( + Interval::parse("1XX", &config).unwrap_err().to_string(), + r#"Parser error: Invalid input syntax for type interval: "1XX""# + ); + } + + #[test] + fn test_duplicate_interval_type() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + let err = Interval::parse("1 month 1 second 1 second", &config) + .expect_err("parsing interval should have failed"); + assert_eq!( + r#"ParseError("Invalid input syntax for type interval: \"1 month 1 second 1 second\". Repeated type 'second'")"#, + format!("{err:?}") + ); + + // test with singular and plural forms + let err = Interval::parse("1 century 2 centuries", &config) + .expect_err("parsing interval should have failed"); + assert_eq!( + r#"ParseError("Invalid input syntax for type interval: \"1 century 2 centuries\". 
Repeated type 'centuries'")"#, + format!("{err:?}") + ); + } + + #[test] + fn test_interval_amount_parsing() { + // integer + let result = IntervalAmount::from_str("123").unwrap(); + let expected = IntervalAmount::new(123, 0); + + assert_eq!(result, expected); + + // positive w/ fractional + let result = IntervalAmount::from_str("0.3").unwrap(); + let expected = IntervalAmount::new(0, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)); + + assert_eq!(result, expected); + + // negative w/ fractional + let result = IntervalAmount::from_str("-3.5").unwrap(); + let expected = IntervalAmount::new(-3, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)); + + assert_eq!(result, expected); + + // invalid: missing fractional + let result = IntervalAmount::from_str("3."); + assert!(result.is_err()); + + // invalid: sign in fractional + let result = IntervalAmount::from_str("3.-5"); + assert!(result.is_err()); + } + + #[test] + fn test_interval_precision() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + let result = Interval::parse("100000.1 days", &config).unwrap(); + let expected = Interval::new(0_i32, 100_000_i32, NANOS_PER_DAY / 10); + + assert_eq!(result, expected); + } + + #[test] + fn test_interval_addition() { + // add 4.1 centuries + let start = Interval::new(1, 2, 3); + let expected = Interval::new(4921, 2, 3); + + let result = start + .add( + IntervalAmount::new(4, 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Century, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 10.25 decades + let start = Interval::new(1, 2, 3); + let expected = Interval::new(1231, 2, 3); + + let result = start + .add( + IntervalAmount::new(10, 25 * 10_i64.pow(INTERVAL_PRECISION - 2)), + IntervalUnit::Decade, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 30.3 years (reminder: Postgres logic does not spill to days/nanos when interval is larger than a month) + let start = Interval::new(1, 2, 3); + let expected = Interval::new(364, 2, 3); + + let result = start + 
.add( + IntervalAmount::new(30, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Year, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 1.5 months + let start = Interval::new(1, 2, 3); + let expected = Interval::new(2, 17, 3); + + let result = start + .add( + IntervalAmount::new(1, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Month, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add -2 weeks + let start = Interval::new(1, 25, 3); + let expected = Interval::new(1, 11, 3); + + let result = start + .add(IntervalAmount::new(-2, 0), IntervalUnit::Week) + .unwrap(); + + assert_eq!(result, expected); + + // add 2.2 days + let start = Interval::new(12, 15, 3); + let expected = Interval::new(12, 17, 3 + 17_280 * NANOS_PER_SECOND); + + let result = start + .add( + IntervalAmount::new(2, 2 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Day, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 12.5 hours + let start = Interval::new(1, 2, 3); + let expected = Interval::new(1, 2, 3 + 45_000 * NANOS_PER_SECOND); + + let result = start + .add( + IntervalAmount::new(12, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Hour, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add -1.5 minutes + let start = Interval::new(0, 0, -3); + let expected = Interval::new(0, 0, -90_000_000_000 - 3); + + let result = start + .add( + IntervalAmount::new(-1, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Minute, + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn string_to_timestamp_old() { + parse_timestamp("1677-06-14T07:29:01.256") + .map_err(|e| assert!(e.to_string().ends_with(ERR_NANOSECONDS_NOT_SUPPORTED))) + .unwrap_err(); + } + + #[test] + fn test_parse_decimal_with_parameter() { + let tests = [ + ("0", 0i128), + ("123.123", 123123i128), + ("123.1234", 123123i128), + ("123.1", 123100i128), + ("123", 123000i128), + ("-123.123", -123123i128), + ("-123.1234", -123123i128), + ("-123.1", 
-123100i128), + ("-123", -123000i128), + ("0.0000123", 0i128), + ("12.", 12000i128), + ("-12.", -12000i128), + ("00.1", 100i128), + ("-00.1", -100i128), + ("12345678912345678.1234", 12345678912345678123i128), + ("-12345678912345678.1234", -12345678912345678123i128), + ("99999999999999999.999", 99999999999999999999i128), + ("-99999999999999999.999", -99999999999999999999i128), + (".123", 123i128), + ("-.123", -123i128), + ("123.", 123000i128), + ("-123.", -123000i128), + ]; + for (s, i) in tests { + let result_128 = parse_decimal::(s, 20, 3); + assert_eq!(i, result_128.unwrap()); + let result_256 = parse_decimal::(s, 20, 3); + assert_eq!(i256::from_i128(i), result_256.unwrap()); + } + + let e_notation_tests = [ + ("1.23e3", "1230.0", 2), + ("5.6714e+2", "567.14", 4), + ("5.6714e-2", "0.056714", 4), + ("5.6714e-2", "0.056714", 3), + ("5.6741214125e2", "567.41214125", 4), + ("8.91E4", "89100.0", 2), + ("3.14E+5", "314000.0", 2), + ("2.718e0", "2.718", 2), + ("9.999999e-1", "0.9999999", 4), + ("1.23e+3", "1230", 2), + ("1.234559e+3", "1234.559", 2), + ("1.00E-10", "0.0000000001", 11), + ("1.23e-4", "0.000123", 2), + ("9.876e7", "98760000.0", 2), + ("5.432E+8", "543200000.0", 10), + ("1.234567e9", "1234567000.0", 2), + ("1.234567e2", "123.45670000", 2), + ("4749.3e-5", "0.047493", 10), + ("4749.3e+5", "474930000", 10), + ("4749.3e-5", "0.047493", 1), + ("4749.3e+5", "474930000", 1), + ("0E-8", "0", 10), + ("0E+6", "0", 10), + ("1E-8", "0.00000001", 10), + ("12E+6", "12000000", 10), + ("12E-6", "0.000012", 10), + ("0.1e-6", "0.0000001", 10), + ("0.1e+6", "100000", 10), + ("0.12e-6", "0.00000012", 10), + ("0.12e+6", "120000", 10), + ("000000000001e0", "000000000001", 3), + ("000001.1034567002e0", "000001.1034567002", 3), + ("1.234e16", "12340000000000000", 0), + ("123.4e16", "1234000000000000000", 0), + ]; + for (e, d, scale) in e_notation_tests { + let result_128_e = parse_decimal::(e, 20, scale); + let result_128_d = parse_decimal::(d, 20, scale); + 
assert_eq!(result_128_e.unwrap(), result_128_d.unwrap()); + let result_256_e = parse_decimal::(e, 20, scale); + let result_256_d = parse_decimal::(d, 20, scale); + assert_eq!(result_256_e.unwrap(), result_256_d.unwrap()); + } + let can_not_parse_tests = [ + "123,123", + ".", + "123.123.123", + "", + "+", + "-", + "e", + "1.3e+e3", + "5.6714ee-2", + "4.11ee-+4", + "4.11e++4", + "1.1e.12", + "1.23e+3.", + "1.23e+3.1", + ]; + for s in can_not_parse_tests { + let result_128 = parse_decimal::(s, 20, 3); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result_128.unwrap_err().to_string() + ); + let result_256 = parse_decimal::(s, 20, 3); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result_256.unwrap_err().to_string() + ); + } + let overflow_parse_tests = [ + ("12345678", 3), + ("1.2345678e7", 3), + ("12345678.9", 3), + ("1.23456789e+7", 3), + ("99999999.99", 3), + ("9.999999999e7", 3), + ("12345678908765.123456", 3), + ("123456789087651234.56e-4", 3), + ("1234560000000", 0), + ("1.23456e12", 0), + ]; + for (s, scale) in overflow_parse_tests { + let result_128 = parse_decimal::(s, 10, scale); + let expected_128 = "Parser error: parse decimal overflow"; + let actual_128 = result_128.unwrap_err().to_string(); + + assert!( + actual_128.contains(expected_128), + "actual: '{actual_128}', expected: '{expected_128}'" + ); + + let result_256 = parse_decimal::(s, 10, scale); + let expected_256 = "Parser error: parse decimal overflow"; + let actual_256 = result_256.unwrap_err().to_string(); + + assert!( + actual_256.contains(expected_256), + "actual: '{actual_256}', expected: '{expected_256}'" + ); + } + + let edge_tests_128 = [ + ( + "99999999999999999999999999999999999999", + 99999999999999999999999999999999999999i128, + 0, + ), + ( + "999999999999999999999999999999999999.99", + 99999999999999999999999999999999999999i128, + 2, + ), + ( + "9999999999999999999999999.9999999999999", + 
99999999999999999999999999999999999999i128, + 13, + ), + ( + "9999999999999999999999999", + 99999999999999999999999990000000000000i128, + 13, + ), + ( + "0.99999999999999999999999999999999999999", + 99999999999999999999999999999999999999i128, + 38, + ), + ( + "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001016744", + 0i128, + 15, + ), + ( + "1.016744e-320", + 0i128, + 15, + ), + ]; + for (s, i, scale) in edge_tests_128 { + let result_128 = parse_decimal::(s, 38, scale); + assert_eq!(i, result_128.unwrap()); + } + let edge_tests_256 = [ + ( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), + 0, + ), + ( + "999999999999999999999999999999999999999999999999999999999999999999999999.9999", + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), + 4, + ), + ( + "99999999999999999999999999999999999999999999999999.99999999999999999999999999", + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), + 26, + ), + ( + "9.999999999999999999999999999999999999999999999999999999999999999999999999999e49", + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), + 26, + ), + ( + "99999999999999999999999999999999999999999999999999", + i256::from_string( + "9999999999999999999999999999999999999999999999999900000000000000000000000000", + ) + .unwrap(), + 26, + ), + ( + "9.9999999999999999999999999999999999999999999999999e+49", + i256::from_string( + 
"9999999999999999999999999999999999999999999999999900000000000000000000000000", + ) + .unwrap(), + 26, + ), + ]; + for (s, i, scale) in edge_tests_256 { + let result = parse_decimal::(s, 76, scale); + assert_eq!(i, result.unwrap()); + } + } + + #[test] + fn test_parse_empty() { + assert_eq!(Int32Type::parse(""), None); + assert_eq!(Int64Type::parse(""), None); + assert_eq!(UInt32Type::parse(""), None); + assert_eq!(UInt64Type::parse(""), None); + assert_eq!(Float32Type::parse(""), None); + assert_eq!(Float64Type::parse(""), None); + assert_eq!(Int32Type::parse("+"), None); + assert_eq!(Int64Type::parse("+"), None); + assert_eq!(UInt32Type::parse("+"), None); + assert_eq!(UInt64Type::parse("+"), None); + assert_eq!(Float32Type::parse("+"), None); + assert_eq!(Float64Type::parse("+"), None); + assert_eq!(TimestampNanosecondType::parse(""), None); + assert_eq!(Date32Type::parse(""), None); + } + + #[test] + fn test_parse_interval_month_day_nano_config() { + let interval = parse_interval_month_day_nano_config( + "1", + IntervalParseConfig::new(IntervalUnit::Second), + ) + .unwrap(); + assert_eq!(interval.months, 0); + assert_eq!(interval.days, 0); + assert_eq!(interval.nanoseconds, NANOS_PER_SECOND); + } +} diff --git a/arrow/src/util/pretty.rs b/arrow-cast/src/pretty.rs similarity index 50% rename from arrow/src/util/pretty.rs rename to arrow-cast/src/pretty.rs index b0013619b50c..4a3cbda283a5 100644 --- a/arrow/src/util/pretty.rs +++ b/arrow-cast/src/pretty.rs @@ -15,44 +15,68 @@ // specific language governing permissions and limitations // under the License. -//! Utilities for printing record batches. Note this module is not -//! available unless `feature = "prettyprint"` is enabled. +//! Utilities for pretty printing [`RecordBatch`]es and [`Array`]s. +//! +//! Note this module is not available unless `feature = "prettyprint"` is enabled. +//! +//! [`RecordBatch`]: arrow_array::RecordBatch +//! 
[`Array`]: arrow_array::Array -use crate::{array::ArrayRef, record_batch::RecordBatch}; -use comfy_table::{Cell, Table}; use std::fmt::Display; -use crate::error::Result; +use comfy_table::{Cell, Table}; + +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::ArrowError; -use super::display::array_value_to_string; +use crate::display::{ArrayFormatter, FormatOptions}; -///! Create a visual representation of record batches -pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { - create_table(results) +/// Create a visual representation of record batches +pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { + let options = FormatOptions::default().with_display_error(true); + pretty_format_batches_with_options(results, &options) } -///! Create a visual representation of columns +/// Create a visual representation of record batches +pub fn pretty_format_batches_with_options( + results: &[RecordBatch], + options: &FormatOptions, +) -> Result { + create_table(results, options) +} + +/// Create a visual representation of columns pub fn pretty_format_columns( col_name: &str, results: &[ArrayRef], -) -> Result { - create_column(col_name, results) +) -> Result { + let options = FormatOptions::default().with_display_error(true); + pretty_format_columns_with_options(col_name, results, &options) } -///! Prints a visual representation of record batches to stdout -pub fn print_batches(results: &[RecordBatch]) -> Result<()> { - println!("{}", create_table(results)?); +/// Utility function to create a visual representation of columns with options +fn pretty_format_columns_with_options( + col_name: &str, + results: &[ArrayRef], + options: &FormatOptions, +) -> Result { + create_column(col_name, results, options) +} + +/// Prints a visual representation of record batches to stdout +pub fn print_batches(results: &[RecordBatch]) -> Result<(), ArrowError> { + println!("{}", pretty_format_batches(results)?); Ok(()) } -///! 
Prints a visual representation of a list of column to stdout -pub fn print_columns(col_name: &str, results: &[ArrayRef]) -> Result<()> { - println!("{}", create_column(col_name, results)?); +/// Prints a visual representation of a list of column to stdout +pub fn print_columns(col_name: &str, results: &[ArrayRef]) -> Result<(), ArrowError> { + println!("{}", pretty_format_columns(col_name, results)?); Ok(()) } -///! Convert a series of record batches into a table -fn create_table(results: &[RecordBatch]) -> Result { +/// Convert a series of record batches into a table +fn create_table(results: &[RecordBatch], options: &FormatOptions) -> Result { let mut table = Table::new(); table.load_preset("||--+-++| ++++++"); @@ -64,16 +88,21 @@ fn create_table(results: &[RecordBatch]) -> Result
{ let mut header = Vec::new(); for field in schema.fields() { - header.push(Cell::new(&field.name())); + header.push(Cell::new(field.name())); } table.set_header(header); for batch in results { + let formatters = batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), options)) + .collect::, ArrowError>>()?; + for row in 0..batch.num_rows() { let mut cells = Vec::new(); - for col in 0..batch.num_columns() { - let column = batch.column(col); - cells.push(Cell::new(&array_value_to_string(column, row)?)); + for formatter in &formatters { + cells.push(Cell::new(formatter.value(row))); } table.add_row(cells); } @@ -82,7 +111,11 @@ fn create_table(results: &[RecordBatch]) -> Result
{ Ok(table) } -fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ +fn create_column( + field: &str, + columns: &[ArrayRef], + options: &FormatOptions, +) -> Result { let mut table = Table::new(); table.load_preset("||--+-++| ++++++"); @@ -94,8 +127,9 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ table.set_header(header); for col in columns { + let formatter = ArrayFormatter::try_new(col.as_ref(), options)?; for row in 0..col.len() { - let cells = vec![Cell::new(&array_value_to_string(col, row)?)]; + let cells = vec![Cell::new(formatter.value(row))]; table.add_row(cells); } } @@ -105,28 +139,23 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ #[cfg(test)] mod tests { - use crate::{ - array::{ - self, new_null_array, Array, Date32Array, Date64Array, - FixedSizeBinaryBuilder, Float16Array, Int32Array, PrimitiveBuilder, - StringArray, StringBuilder, StringDictionaryBuilder, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UnionArray, UnionBuilder, - }, - buffer::Buffer, - datatypes::{DataType, Field, Float64Type, Int32Type, Schema, UnionMode}, - }; - - use super::*; - use crate::array::{Decimal128Array, FixedSizeListBuilder}; use std::fmt::Write; use std::sync::Arc; use half::f16; + use arrow_array::builder::*; + use arrow_array::types::*; + use arrow_array::*; + use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; + use arrow_schema::*; + + use crate::display::array_value_to_string; + + use super::*; + #[test] - fn test_pretty_format_batches() -> Result<()> { + fn test_pretty_format_batches() { // define a schema. 
let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, true), @@ -150,9 +179,10 @@ mod tests { Some(100), ])), ], - )?; + ) + .unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+---+-----+", @@ -167,13 +197,11 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_columns() -> Result<()> { + fn test_pretty_format_columns() { let columns = vec![ Arc::new(array::StringArray::from(vec![ Some("a"), @@ -184,18 +212,16 @@ mod tests { Arc::new(array::StringArray::from(vec![Some("e"), None, Some("g")])), ]; - let table = pretty_format_columns("a", &columns)?.to_string(); + let table = pretty_format_columns("a", &columns).unwrap().to_string(); let expected = vec![ - "+---+", "| a |", "+---+", "| a |", "| b |", "| |", "| d |", "| e |", - "| |", "| g |", "+---+", + "+---+", "| a |", "+---+", "| a |", "| b |", "| |", "| d |", "| e |", "| |", + "| g |", "+---+", ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] @@ -231,28 +257,25 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{:#?}", table); + assert_eq!(expected, actual, "Actual result:\n{table:#?}"); } #[test] - fn test_pretty_format_dictionary() -> Result<()> { + fn test_pretty_format_dictionary() { // define a schema. 
- let field_type = - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); - let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + let field = Field::new_dictionary("d1", DataType::Int32, DataType::Utf8, true); + let schema = Arc::new(Schema::new(vec![field])); - let keys_builder = PrimitiveBuilder::::with_capacity(10); - let values_builder = StringBuilder::new(); - let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + let mut builder = StringDictionaryBuilder::::new(); - builder.append("one")?; + builder.append_value("one"); builder.append_null(); - builder.append("three")?; + builder.append_value("three"); let array = Arc::new(builder.finish()); - let batch = RecordBatch::try_new(schema, vec![array])?; + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+-------+", @@ -266,18 +289,14 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_fixed_size_list() -> Result<()> { + fn test_pretty_format_fixed_size_list() { // define a schema. 
- let field_type = DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Int32, true)), - 3, - ); + let field_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let keys_builder = Int32Array::builder(3); @@ -292,8 +311,8 @@ mod tests { let array = Arc::new(builder.finish()); - let batch = RecordBatch::try_new(schema, vec![array])?; - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+-----------+", "| d1 |", @@ -306,27 +325,101 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); + assert_eq!(expected, actual, "Actual result:\n{table}"); + } + + #[test] + fn test_pretty_format_string_view() { + let schema = Arc::new(Schema::new(vec![Field::new( + "d1", + DataType::Utf8View, + true, + )])); + + // Use a small capacity so we end up with multiple views + let mut builder = StringViewBuilder::with_capacity(20); + builder.append_value("hello"); + builder.append_null(); + builder.append_value("longer than 12 bytes"); + builder.append_value("another than 12 bytes"); + builder.append_null(); + builder.append_value("small"); + + let array: ArrayRef = Arc::new(builder.finish()); + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + let expected = vec![ + "+-----------------------+", + "| d1 |", + "+-----------------------+", + "| hello |", + "| |", + "| longer than 12 bytes |", + "| another than 12 bytes |", + "| |", + "| small |", + "+-----------------------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{table:#?}"); + } + + #[test] + fn test_pretty_format_binary_view() { + 
let schema = Arc::new(Schema::new(vec![Field::new( + "d1", + DataType::BinaryView, + true, + )])); + + // Use a small capacity so we end up with multiple views + let mut builder = BinaryViewBuilder::with_capacity(20); + builder.append_value(b"hello"); + builder.append_null(); + builder.append_value(b"longer than 12 bytes"); + builder.append_value(b"another than 12 bytes"); + builder.append_null(); + builder.append_value(b"small"); - Ok(()) + let array: ArrayRef = Arc::new(builder.finish()); + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + let expected = vec![ + "+--------------------------------------------+", + "| d1 |", + "+--------------------------------------------+", + "| 68656c6c6f |", + "| |", + "| 6c6f6e676572207468616e203132206279746573 |", + "| 616e6f74686572207468616e203132206279746573 |", + "| |", + "| 736d616c6c |", + "+--------------------------------------------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n\n{table:#?}"); } #[test] - fn test_pretty_format_fixed_size_binary() -> Result<()> { + fn test_pretty_format_fixed_size_binary() { // define a schema. 
let field_type = DataType::FixedSizeBinary(3); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 3); - builder.append_value(&[1, 2, 3]).unwrap(); + builder.append_value([1, 2, 3]).unwrap(); builder.append_null(); - builder.append_value(&[7, 8, 9]).unwrap(); + builder.append_value([7, 8, 9]).unwrap(); let array = Arc::new(builder.finish()); - let batch = RecordBatch::try_new(schema, vec![array])?; - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+--------+", "| d1 |", @@ -339,9 +432,7 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } /// Generate an array with type $ARRAYTYPE with a numeric value of @@ -368,17 +459,49 @@ mod tests { let expected = $EXPECTED_RESULT; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n\n{:#?}\n\n", actual); + assert_eq!(expected, actual, "Actual result:\n\n{actual:#?}\n\n"); }; } + fn timestamp_batch(timezone: &str, value: T::Native) -> RecordBatch { + let mut builder = PrimitiveBuilder::::with_capacity(10); + builder.append_value(value); + builder.append_null(); + let array = builder.finish(); + let array = array.with_timezone(timezone); + + let schema = Arc::new(Schema::new(vec![Field::new( + "f", + array.data_type().clone(), + true, + )])); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() + } + + #[test] + fn test_pretty_format_timestamp_second_with_fixed_offset_timezone() { + let batch = timestamp_batch::("+08:00", 11111111); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + + let expected = vec![ + "+---------------------------+", + "| f |", + 
"+---------------------------+", + "| 1970-05-09T22:25:11+08:00 |", + "| |", + "+---------------------------+", + ]; + let actual: Vec<&str> = table.lines().collect(); + assert_eq!(expected, actual, "Actual result:\n\n{actual:#?}\n\n"); + } + #[test] fn test_pretty_format_timestamp_second() { let expected = vec![ "+---------------------+", "| f |", "+---------------------+", - "| 1970-05-09 14:25:11 |", + "| 1970-05-09T14:25:11 |", "| |", "+---------------------+", ]; @@ -391,7 +514,7 @@ mod tests { "+-------------------------+", "| f |", "+-------------------------+", - "| 1970-01-01 03:05:11.111 |", + "| 1970-01-01T03:05:11.111 |", "| |", "+-------------------------+", ]; @@ -404,7 +527,7 @@ mod tests { "+----------------------------+", "| f |", "+----------------------------+", - "| 1970-01-01 00:00:11.111111 |", + "| 1970-01-01T00:00:11.111111 |", "| |", "+----------------------------+", ]; @@ -417,7 +540,7 @@ mod tests { "+-------------------------------+", "| f |", "+-------------------------------+", - "| 1970-01-01 00:00:00.011111111 |", + "| 1970-01-01T00:00:00.011111111 |", "| |", "+-------------------------------+", ]; @@ -440,12 +563,12 @@ mod tests { #[test] fn test_pretty_format_date_64() { let expected = vec![ - "+------------+", - "| f |", - "+------------+", - "| 2005-03-18 |", - "| |", - "+------------+", + "+---------------------+", + "| f |", + "+---------------------+", + "| 2005-03-18T01:58:20 |", + "| |", + "+---------------------+", ]; check_datetime!(Date64Array, 1111111100000, expected); } @@ -503,7 +626,7 @@ mod tests { } #[test] - fn test_int_display() -> Result<()> { + fn test_int_display() { let array = Arc::new(Int32Array::from(vec![6, 3])) as ArrayRef; let actual_one = array_value_to_string(&array, 0).unwrap(); let expected_one = "6"; @@ -512,11 +635,10 @@ mod tests { let expected_two = "3"; assert_eq!(actual_one, expected_one); assert_eq!(actual_two, expected_two); - Ok(()) } #[test] - fn test_decimal_display() -> Result<()> { + fn 
test_decimal_display() { let precision = 10; let scale = 2; @@ -534,9 +656,9 @@ mod tests { true, )])); - let batch = RecordBatch::try_new(schema, vec![dm])?; + let batch = RecordBatch::try_new(schema, vec![dm]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+-------+", @@ -550,13 +672,11 @@ mod tests { ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_decimal_display_zero_scale() -> Result<()> { + fn test_decimal_display_zero_scale() { let precision = 5; let scale = 0; @@ -574,33 +694,31 @@ mod tests { true, )])); - let batch = RecordBatch::try_new(schema, vec![dm])?; + let batch = RecordBatch::try_new(schema, vec![dm]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - "+------+", "| f |", "+------+", "| 101 |", "| |", "| 200 |", - "| 3040 |", "+------+", + "+------+", "| f |", "+------+", "| 101 |", "| |", "| 200 |", "| 3040 |", + "+------+", ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_struct() -> Result<()> { + fn test_pretty_format_struct() { let schema = Schema::new(vec![ - Field::new( + Field::new_struct( "c1", - DataType::Struct(vec![ - Field::new("c11", DataType::Int32, false), - Field::new( + vec![ + Field::new("c11", DataType::Int32, true), + Field::new_struct( "c12", - DataType::Struct(vec![Field::new("c121", DataType::Utf8, false)]), + vec![Field::new("c121", DataType::Utf8, false)], false, ), - ]), + ], false, ), Field::new("c2", DataType::Utf8, false), @@ -608,47 +726,43 @@ mod tests { let c1 = 
StructArray::from(vec![ ( - Field::new("c11", DataType::Int32, false), + Arc::new(Field::new("c11", DataType::Int32, true)), Arc::new(Int32Array::from(vec![Some(1), None, Some(5)])) as ArrayRef, ), ( - Field::new( + Arc::new(Field::new_struct( "c12", - DataType::Struct(vec![Field::new("c121", DataType::Utf8, false)]), + vec![Field::new("c121", DataType::Utf8, false)], false, - ), + )), Arc::new(StructArray::from(vec![( - Field::new("c121", DataType::Utf8, false), - Arc::new(StringArray::from(vec![Some("e"), Some("f"), Some("g")])) - as ArrayRef, + Arc::new(Field::new("c121", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![Some("e"), Some("f"), Some("g")])) as ArrayRef, )])) as ArrayRef, ), ]); let c2 = StringArray::from(vec![Some("a"), Some("b"), Some("c")]); let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - r#"+-------------------------------------+----+"#, - r#"| c1 | c2 |"#, - r#"+-------------------------------------+----+"#, - r#"| {"c11": 1, "c12": {"c121": "e"}} | a |"#, - r#"| {"c11": null, "c12": {"c121": "f"}} | b |"#, - r#"| {"c11": 5, "c12": {"c121": "g"}} | c |"#, - r#"+-------------------------------------+----+"#, + "+--------------------------+----+", + "| c1 | c2 |", + "+--------------------------+----+", + "| {c11: 1, c12: {c121: e}} | a |", + "| {c11: , c12: {c121: f}} | b |", + "| {c11: 5, c12: {c121: g}} | c |", + "+--------------------------+----+", ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_dense_union() -> Result<()> { + fn test_pretty_format_dense_union() { let mut builder = 
UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 3.2234).unwrap(); @@ -656,22 +770,18 @@ mod tests { builder.append_null::("a").unwrap(); let union = builder.build().unwrap(); - let schema = Schema::new(vec![Field::new( + let schema = Schema::new(vec![Field::new_union( "Teamsters", - DataType::Union( - vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Float64, false), - ], - vec![0, 1], - UnionMode::Dense, - ), - false, + vec![0, 1], + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ], + UnionMode::Dense, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ "+------------+", @@ -685,11 +795,10 @@ mod tests { ]; assert_eq!(expected, actual); - Ok(()) } #[test] - fn test_pretty_format_sparse_union() -> Result<()> { + fn test_pretty_format_sparse_union() { let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 3.2234).unwrap(); @@ -697,22 +806,18 @@ mod tests { builder.append_null::("a").unwrap(); let union = builder.build().unwrap(); - let schema = Schema::new(vec![Field::new( + let schema = Schema::new(vec![Field::new_union( "Teamsters", - DataType::Union( - vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Float64, false), - ], - vec![0, 1], - UnionMode::Sparse, - ), - false, + vec![0, 1], + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ], + UnionMode::Sparse, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let 
batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ "+------------+", @@ -726,11 +831,10 @@ mod tests { ]; assert_eq!(expected, actual); - Ok(()) } #[test] - fn test_pretty_format_nested_union() -> Result<()> { + fn test_pretty_format_nested_union() { //Inner UnionArray let mut builder = UnionBuilder::new_dense(); builder.append::("b", 1).unwrap(); @@ -740,43 +844,40 @@ mod tests { builder.append_null::("c").unwrap(); let inner = builder.build().unwrap(); - let inner_field = Field::new( + let inner_field = Field::new_union( "European Union", - DataType::Union( - vec![ - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Float64, false), - ], - vec![0, 1], - UnionMode::Dense, - ), - false, + vec![0, 1], + vec![ + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Float64, false), + ], + UnionMode::Dense, ); // Can't use UnionBuilder with non-primitive types, so manually build outer UnionArray let a_array = Int32Array::from(vec![None, None, None, Some(1234), Some(23)]); - let type_ids = Buffer::from_slice_ref(&[1_i8, 1, 0, 0, 1]); + let type_ids = [1, 1, 0, 0, 1].into_iter().collect::>(); - let children: Vec<(Field, Arc)> = vec![ - (Field::new("a", DataType::Int32, true), Arc::new(a_array)), - (inner_field.clone(), Arc::new(inner)), - ]; + let children = vec![Arc::new(a_array) as Arc, Arc::new(inner)]; - let outer = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap(); + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, true))), + (1, Arc::new(inner_field.clone())), + ] + .into_iter() + .collect(); - let schema = Schema::new(vec![Field::new( + let outer = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + + let schema = Schema::new(vec![Field::new_union( "Teamsters", - DataType::Union( - vec![Field::new("a", 
DataType::Int32, true), inner_field], - vec![0, 1], - UnionMode::Sparse, - ), - false, + vec![0, 1], + vec![Field::new("a", DataType::Int32, true), inner_field], + UnionMode::Sparse, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(outer)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(outer)]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ "+-----------------------------+", @@ -790,11 +891,10 @@ mod tests { "+-----------------------------+", ]; assert_eq!(expected, actual); - Ok(()) } #[test] - fn test_writing_formatted_batches() -> Result<()> { + fn test_writing_formatted_batches() { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, true), @@ -818,12 +918,13 @@ mod tests { Some(100), ])), ], - )?; + ) + .unwrap(); let mut buf = String::new(); - write!(&mut buf, "{}", pretty_format_batches(&[batch])?).unwrap(); + write!(&mut buf, "{}", pretty_format_batches(&[batch]).unwrap()).unwrap(); - let s = vec![ + let s = [ "+---+-----+", "| a | b |", "+---+-----+", @@ -835,12 +936,10 @@ mod tests { ]; let expected = s.join("\n"); assert_eq!(expected, buf); - - Ok(()) } #[test] - fn test_float16_display() -> Result<()> { + fn test_float16_display() { let values = vec![ Some(f16::from_f32(f32::NAN)), Some(f16::from_f32(4.0)), @@ -854,18 +953,144 @@ mod tests { true, )])); - let batch = RecordBatch::try_new(schema, vec![array])?; + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - "+------+", "| f16 |", "+------+", "| NaN |", "| 4 |", "| -inf |", - "+------+", + "+------+", "| f16 |", "+------+", "| NaN |", "| 4 |", "| -inf |", "+------+", ]; let actual: 
Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); + assert_eq!(expected, actual, "Actual result:\n{table}"); + } + + #[test] + fn test_pretty_format_interval_day_time() { + let arr = Arc::new(arrow_array::IntervalDayTimeArray::from(vec![ + Some(IntervalDayTime::new(-1, -600_000)), + Some(IntervalDayTime::new(0, -1001)), + Some(IntervalDayTime::new(0, -1)), + Some(IntervalDayTime::new(0, 1)), + Some(IntervalDayTime::new(0, 10)), + Some(IntervalDayTime::new(0, 100)), + ])); + + let schema = Arc::new(Schema::new(vec![Field::new( + "IntervalDayTime", + arr.data_type().clone(), + true, + )])); + + let batch = RecordBatch::try_new(schema, vec![arr]).unwrap(); + + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + + let expected = vec![ + "+------------------+", + "| IntervalDayTime |", + "+------------------+", + "| -1 days -10 mins |", + "| -1.001 secs |", + "| -0.001 secs |", + "| 0.001 secs |", + "| 0.010 secs |", + "| 0.100 secs |", + "+------------------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{table}"); + } + + #[test] + fn test_pretty_format_interval_month_day_nano_array() { + let arr = Arc::new(arrow_array::IntervalMonthDayNanoArray::from(vec![ + Some(IntervalMonthDayNano::new(-1, -1, -600_000_000_000)), + Some(IntervalMonthDayNano::new(0, 0, -1_000_000_001)), + Some(IntervalMonthDayNano::new(0, 0, -1)), + Some(IntervalMonthDayNano::new(0, 0, 1)), + Some(IntervalMonthDayNano::new(0, 0, 10)), + Some(IntervalMonthDayNano::new(0, 0, 100)), + Some(IntervalMonthDayNano::new(0, 0, 1_000)), + Some(IntervalMonthDayNano::new(0, 0, 10_000)), + Some(IntervalMonthDayNano::new(0, 0, 100_000)), + Some(IntervalMonthDayNano::new(0, 0, 1_000_000)), + Some(IntervalMonthDayNano::new(0, 0, 10_000_000)), + Some(IntervalMonthDayNano::new(0, 0, 100_000_000)), + Some(IntervalMonthDayNano::new(0, 0, 1_000_000_000)), + ])); + + let schema = 
Arc::new(Schema::new(vec![Field::new( + "IntervalMonthDayNano", + arr.data_type().clone(), + true, + )])); + + let batch = RecordBatch::try_new(schema, vec![arr]).unwrap(); + + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + + let expected = vec![ + "+--------------------------+", + "| IntervalMonthDayNano |", + "+--------------------------+", + "| -1 mons -1 days -10 mins |", + "| -1.000000001 secs |", + "| -0.000000001 secs |", + "| 0.000000001 secs |", + "| 0.000000010 secs |", + "| 0.000000100 secs |", + "| 0.000001000 secs |", + "| 0.000010000 secs |", + "| 0.000100000 secs |", + "| 0.001000000 secs |", + "| 0.010000000 secs |", + "| 0.100000000 secs |", + "| 1.000000000 secs |", + "+--------------------------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{table}"); + } + + #[test] + fn test_format_options() { + let options = FormatOptions::default().with_null("null"); + let array = Int32Array::from(vec![Some(1), Some(2), None, Some(3), Some(4)]); + let batch = RecordBatch::try_from_iter([("my_column_name", Arc::new(array) as _)]).unwrap(); + + let column = pretty_format_columns_with_options( + "my_column_name", + &[batch.column(0).clone()], + &options, + ) + .unwrap() + .to_string(); + + let batch = pretty_format_batches_with_options(&[batch], &options) + .unwrap() + .to_string(); + + let expected = vec![ + "+----------------+", + "| my_column_name |", + "+----------------+", + "| 1 |", + "| 2 |", + "| null |", + "| 3 |", + "| 4 |", + "+----------------+", + ]; + + let actual: Vec<&str> = column.lines().collect(); + assert_eq!(expected, actual, "Actual result:\n{column}"); - Ok(()) + let actual: Vec<&str> = batch.lines().collect(); + assert_eq!(expected, actual, "Actual result:\n{batch}"); } } diff --git a/arrow-csv/Cargo.toml b/arrow-csv/Cargo.toml new file mode 100644 index 000000000000..be213c9363c2 --- /dev/null +++ b/arrow-csv/Cargo.toml @@ -0,0 +1,53 @@ +# Licensed to 
the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-csv" +version = { workspace = true } +description = "Support for parsing CSV format to and from the Arrow format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_csv" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +chrono = { workspace = true } +csv = { version = "1.1", default-features = false } +csv-core = { version = "0.1" } +lazy_static = { version = "1.4", default-features = false } +lexical-core = { version = "1.0", default-features = false } +regex = { version = "1.7.0", default-features = false, features = ["std", "unicode", "perf"] } + +[dev-dependencies] +tempfile = "3.3" +futures = "0.3" +tokio = { version = "1.27", default-features = false, features = ["io-util"] } +bytes = "1.4" diff --git a/arrow-csv/examples/README.md 
b/arrow-csv/examples/README.md new file mode 100644 index 000000000000..340413e76d94 --- /dev/null +++ b/arrow-csv/examples/README.md @@ -0,0 +1,21 @@ + + +# Examples +- [`csv_calculation.rs`](csv_calculation.rs): performs a simple calculation using the CSV reader \ No newline at end of file diff --git a/arrow-csv/examples/csv_calculation.rs b/arrow-csv/examples/csv_calculation.rs new file mode 100644 index 000000000000..6ce963e2b012 --- /dev/null +++ b/arrow-csv/examples/csv_calculation.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::cast::AsArray; +use arrow_array::types::Int16Type; +use arrow_csv::ReaderBuilder; + +use arrow_schema::{DataType, Field, Schema}; +use std::fs::File; +use std::sync::Arc; + +fn main() { + // read csv from file + let file = File::open("arrow-csv/test/data/example.csv").unwrap(); + let csv_schema = Schema::new(vec![ + Field::new("c1", DataType::Int16, true), + Field::new("c2", DataType::Float32, true), + Field::new("c3", DataType::Utf8, true), + Field::new("c4", DataType::Boolean, true), + ]); + let mut reader = ReaderBuilder::new(Arc::new(csv_schema)) + .with_header(true) + .build(file) + .unwrap(); + + match reader.next() { + Some(r) => match r { + Ok(r) => { + // get the column(0) max value + let col = r.column(0).as_primitive::(); + let max = col.iter().max().flatten(); + println!("max value column(0): {max:?}") + } + Err(e) => { + println!("{e:?}"); + } + }, + None => { + println!("csv is empty"); + } + } +} diff --git a/arrow-csv/src/lib.rs b/arrow-csv/src/lib.rs new file mode 100644 index 000000000000..28c0d6ebdbb8 --- /dev/null +++ b/arrow-csv/src/lib.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Transfer data between the Arrow memory format and CSV (comma-separated values). + +#![warn(missing_docs)] + +pub mod reader; +pub mod writer; + +pub use self::reader::infer_schema_from_files; +pub use self::reader::Reader; +pub use self::reader::ReaderBuilder; +pub use self::writer::Writer; +pub use self::writer::WriterBuilder; +use arrow_schema::ArrowError; + +fn map_csv_error(error: csv::Error) -> ArrowError { + match error.kind() { + csv::ErrorKind::Io(error) => ArrowError::CsvError(error.to_string()), + csv::ErrorKind::Utf8 { pos: _, err } => ArrowError::CsvError(format!( + "Encountered UTF-8 error while reading CSV file: {err}" + )), + csv::ErrorKind::UnequalLengths { + expected_len, len, .. + } => ArrowError::CsvError(format!( + "Encountered unequal lengths between records on CSV file. Expected {len} \ + records, found {expected_len} records" + )), + _ => ArrowError::CsvError("Error reading CSV file".to_string()), + } +} diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs new file mode 100644 index 000000000000..36f80ec90a95 --- /dev/null +++ b/arrow-csv/src/reader/mod.rs @@ -0,0 +1,2682 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CSV Reader +//! +//! # Basic Usage +//! +//! 
This CSV reader allows CSV files to be read into the Arrow memory model. Records are +//! loaded in batches and are then converted from row-based data to columnar data. +//! +//! Example: +//! +//! ``` +//! # use arrow_schema::*; +//! # use arrow_csv::{Reader, ReaderBuilder}; +//! # use std::fs::File; +//! # use std::sync::Arc; +//! +//! let schema = Schema::new(vec![ +//! Field::new("city", DataType::Utf8, false), +//! Field::new("lat", DataType::Float64, false), +//! Field::new("lng", DataType::Float64, false), +//! ]); +//! +//! let file = File::open("test/data/uk_cities.csv").unwrap(); +//! +//! let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap(); +//! let batch = csv.next().unwrap().unwrap(); +//! ``` +//! +//! # Async Usage +//! +//! The lower-level [`Decoder`] can be integrated with various forms of async data streams, +//! and is designed to be agnostic to the various different kinds of async IO primitives found +//! within the Rust ecosystem. +//! +//! For example, see below for how it can be used with an arbitrary `Stream` of `Bytes` +//! +//! ``` +//! # use std::task::{Poll, ready}; +//! # use bytes::{Buf, Bytes}; +//! # use arrow_schema::ArrowError; +//! # use futures::stream::{Stream, StreamExt}; +//! # use arrow_array::RecordBatch; +//! # use arrow_csv::reader::Decoder; +//! # +//! fn decode_stream + Unpin>( +//! mut decoder: Decoder, +//! mut input: S, +//! ) -> impl Stream> { +//! let mut buffered = Bytes::new(); +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! if buffered.is_empty() { +//! if let Some(b) = ready!(input.poll_next_unpin(cx)) { +//! buffered = b; +//! } +//! // Note: don't break on `None` as the decoder needs +//! // to be called with an empty array to delimit the +//! // final record +//! } +//! let decoded = match decoder.decode(buffered.as_ref()) { +//! Ok(0) => break, +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! buffered.advance(decoded); +//! } +//! +//! 
Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! +//! ``` +//! +//! In a similar vein, it can also be used with tokio-based IO primitives +//! +//! ``` +//! # use std::pin::Pin; +//! # use std::task::{Poll, ready}; +//! # use futures::Stream; +//! # use tokio::io::AsyncBufRead; +//! # use arrow_array::RecordBatch; +//! # use arrow_csv::reader::Decoder; +//! # use arrow_schema::ArrowError; +//! fn decode_stream( +//! mut decoder: Decoder, +//! mut reader: R, +//! ) -> impl Stream> { +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) { +//! Ok(b) => b, +//! Err(e) => return Poll::Ready(Some(Err(e.into()))), +//! }; +//! let decoded = match decoder.decode(b) { +//! // Note: the decoder needs to be called with an empty +//! // array to delimit the final record +//! Ok(0) => break, +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! Pin::new(&mut reader).consume(decoded); +//! } +//! +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! ``` +//! + +mod records; + +use arrow_array::builder::{NullBuilder, PrimitiveBuilder}; +use arrow_array::types::*; +use arrow_array::*; +use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser}; +use arrow_schema::*; +use chrono::{TimeZone, Utc}; +use csv::StringRecord; +use lazy_static::lazy_static; +use regex::{Regex, RegexSet}; +use std::fmt::{self, Debug}; +use std::fs::File; +use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom}; +use std::sync::Arc; + +use crate::map_csv_error; +use crate::reader::records::{RecordDecoder, StringRecords}; +use arrow_array::timezone::Tz; + +lazy_static! 
{ + /// Order should match [`InferredDataType`] + static ref REGEX_SET: RegexSet = RegexSet::new([ + r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN + r"^-?(\d+)$", //INTEGER + r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", //DECIMAL + r"^\d{4}-\d\d-\d\d$", //DATE32 + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", //Timestamp(Second) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", //Timestamp(Millisecond) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", //Timestamp(Microsecond) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", //Timestamp(Nanosecond) + ]).unwrap(); +} + +/// A wrapper over `Option` to check if the value is `NULL`. +#[derive(Debug, Clone, Default)] +struct NullRegex(Option); + +impl NullRegex { + /// Returns true if the value should be considered as `NULL` according to + /// the provided regular expression. + #[inline] + fn is_null(&self, s: &str) -> bool { + match &self.0 { + Some(r) => r.is_match(s), + None => s.is_empty(), + } + } +} + +#[derive(Default, Copy, Clone)] +struct InferredDataType { + /// Packed booleans indicating type + /// + /// 0 - Boolean + /// 1 - Integer + /// 2 - Float64 + /// 3 - Date32 + /// 4 - Timestamp(Second) + /// 5 - Timestamp(Millisecond) + /// 6 - Timestamp(Microsecond) + /// 7 - Timestamp(Nanosecond) + /// 8 - Utf8 + packed: u16, +} + +impl InferredDataType { + /// Returns the inferred data type + fn get(&self) -> DataType { + match self.packed { + 0 => DataType::Null, + 1 => DataType::Boolean, + 2 => DataType::Int64, + 4 | 6 => DataType::Float64, // Promote Int64 to Float64 + b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() { + // Promote to highest precision temporal type + 8 => DataType::Timestamp(TimeUnit::Nanosecond, None), + 9 => DataType::Timestamp(TimeUnit::Microsecond, None), + 10 => DataType::Timestamp(TimeUnit::Millisecond, None), + 11 => DataType::Timestamp(TimeUnit::Second, None), + 12 => DataType::Date32, + _ => unreachable!(), 
+ }, + _ => DataType::Utf8, + } + } + + /// Updates the [`InferredDataType`] with the given string + fn update(&mut self, string: &str) { + self.packed |= if string.starts_with('"') { + 1 << 8 // Utf8 + } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() { + 1 << m + } else { + 1 << 8 // Utf8 + } + } +} + +/// The format specification for the CSV file +#[derive(Debug, Clone, Default)] +pub struct Format { + header: bool, + delimiter: Option, + escape: Option, + quote: Option, + terminator: Option, + comment: Option, + null_regex: NullRegex, + truncated_rows: bool, +} + +impl Format { + /// Specify whether the CSV file has a header, defaults to `true` + /// + /// When `true`, the first row of the CSV file is treated as a header row + pub fn with_header(mut self, has_header: bool) -> Self { + self.header = has_header; + self + } + + /// Specify a custom delimiter character, defaults to comma `','` + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.delimiter = Some(delimiter); + self + } + + /// Specify an escape character, defaults to `None` + pub fn with_escape(mut self, escape: u8) -> Self { + self.escape = Some(escape); + self + } + + /// Specify a custom quote character, defaults to double quote `'"'` + pub fn with_quote(mut self, quote: u8) -> Self { + self.quote = Some(quote); + self + } + + /// Specify a custom terminator character, defaults to CRLF + pub fn with_terminator(mut self, terminator: u8) -> Self { + self.terminator = Some(terminator); + self + } + + /// Specify a comment character, defaults to `None` + /// + /// Lines starting with this character will be ignored + pub fn with_comment(mut self, comment: u8) -> Self { + self.comment = Some(comment); + self + } + + /// Provide a regex to match null values, defaults to `^$` + pub fn with_null_regex(mut self, null_regex: Regex) -> Self { + self.null_regex = NullRegex(Some(null_regex)); + self + } + + /// Whether to allow truncated rows when parsing. 
+ /// + /// By default this is set to `false` and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns + /// and fill the missing columns with nulls. If the record's schema is not nullable, then it + /// will still return an error. + pub fn with_truncated_rows(mut self, allow: bool) -> Self { + self.truncated_rows = allow; + self + } + + /// Infer schema of CSV records from the provided `reader` + /// + /// If `max_records` is `None`, all records will be read, otherwise up to `max_records` + /// records are read to infer the schema + /// + /// Returns inferred schema and number of records read + pub fn infer_schema( + &self, + reader: R, + max_records: Option, + ) -> Result<(Schema, usize), ArrowError> { + let mut csv_reader = self.build_reader(reader); + + // get or create header names + // when has_header is false, creates default column names with column_ prefix + let headers: Vec = if self.header { + let headers = &csv_reader.headers().map_err(map_csv_error)?.clone(); + headers.iter().map(|s| s.to_string()).collect() + } else { + let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len(); + (0..*first_record_count) + .map(|i| format!("column_{}", i + 1)) + .collect() + }; + + let header_length = headers.len(); + // keep track of inferred field types + let mut column_types: Vec = vec![Default::default(); header_length]; + + let mut records_count = 0; + + let mut record = StringRecord::new(); + let max_records = max_records.unwrap_or(usize::MAX); + while records_count < max_records { + if !csv_reader.read_record(&mut record).map_err(map_csv_error)? 
{ + break; + } + records_count += 1; + + // Note since we may be looking at a sample of the data, we make the safe assumption that + // they could be nullable + for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) { + if let Some(string) = record.get(i) { + if !self.null_regex.is_null(string) { + column_type.update(string) + } + } + } + } + + // build schema from inference results + let fields: Fields = column_types + .iter() + .zip(&headers) + .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true)) + .collect(); + + Ok((Schema::new(fields), records_count)) + } + + /// Build a [`csv::Reader`] for this [`Format`] + fn build_reader(&self, reader: R) -> csv::Reader { + let mut builder = csv::ReaderBuilder::new(); + builder.has_headers(self.header); + builder.flexible(self.truncated_rows); + + if let Some(c) = self.delimiter { + builder.delimiter(c); + } + builder.escape(self.escape); + if let Some(c) = self.quote { + builder.quote(c); + } + if let Some(t) = self.terminator { + builder.terminator(csv::Terminator::Any(t)); + } + if let Some(comment) = self.comment { + builder.comment(Some(comment)); + } + builder.from_reader(reader) + } + + /// Build a [`csv_core::Reader`] for this [`Format`] + fn build_parser(&self) -> csv_core::Reader { + let mut builder = csv_core::ReaderBuilder::new(); + builder.escape(self.escape); + builder.comment(self.comment); + + if let Some(c) = self.delimiter { + builder.delimiter(c); + } + if let Some(c) = self.quote { + builder.quote(c); + } + if let Some(t) = self.terminator { + builder.terminator(csv_core::Terminator::Any(t)); + } + builder.build() + } +} + +/// Infer the schema of a CSV file by reading through the first n records of the file, +/// with `max_read_records` controlling the maximum number of records to read. +/// +/// If `max_read_records` is not set, the whole file is read to infer its schema. +/// +/// Return inferred schema and number of records used for inference. 
This function does not change +/// reader cursor offset. +/// +/// The inferred schema will always have each field set as nullable. +#[deprecated(note = "Use Format::infer_schema")] +#[allow(deprecated)] +pub fn infer_file_schema( + mut reader: R, + delimiter: u8, + max_read_records: Option, + has_header: bool, +) -> Result<(Schema, usize), ArrowError> { + let saved_offset = reader.stream_position()?; + let r = infer_reader_schema(&mut reader, delimiter, max_read_records, has_header)?; + // return the reader seek back to the start + reader.seek(SeekFrom::Start(saved_offset))?; + Ok(r) +} + +/// Infer schema of CSV records provided by struct that implements `Read` trait. +/// +/// `max_read_records` controlling the maximum number of records to read. If `max_read_records` is +/// not set, all records are read to infer the schema. +/// +/// Return inferred schema and number of records used for inference. +#[deprecated(note = "Use Format::infer_schema")] +pub fn infer_reader_schema( + reader: R, + delimiter: u8, + max_read_records: Option, + has_header: bool, +) -> Result<(Schema, usize), ArrowError> { + let format = Format { + delimiter: Some(delimiter), + header: has_header, + ..Default::default() + }; + format.infer_schema(reader, max_read_records) +} + +/// Infer schema from a list of CSV files by reading through first n records +/// with `max_read_records` controlling the maximum number of records to read. +/// +/// Files will be read in the given order until n records have been reached. +/// +/// If `max_read_records` is not set, all files will be read fully to infer the schema. 
+pub fn infer_schema_from_files( + files: &[String], + delimiter: u8, + max_read_records: Option, + has_header: bool, +) -> Result { + let mut schemas = vec![]; + let mut records_to_read = max_read_records.unwrap_or(usize::MAX); + let format = Format { + delimiter: Some(delimiter), + header: has_header, + ..Default::default() + }; + + for fname in files.iter() { + let f = File::open(fname)?; + let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?; + if records_read == 0 { + continue; + } + schemas.push(schema.clone()); + records_to_read -= records_read; + if records_to_read == 0 { + break; + } + } + + Schema::try_merge(schemas) +} + +// optional bounds of the reader, of the form (min line, max line). +type Bounds = Option<(usize, usize)>; + +/// CSV file reader using [`std::io::BufReader`] +pub type Reader = BufReader>; + +/// CSV file reader +pub struct BufReader { + /// File reader + reader: R, + + /// The decoder + decoder: Decoder, +} + +impl fmt::Debug for BufReader +where + R: BufRead, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Reader") + .field("decoder", &self.decoder) + .finish() + } +} + +impl Reader { + /// Returns the schema of the reader, useful for getting the schema without reading + /// record batches + pub fn schema(&self) -> SchemaRef { + match &self.decoder.projection { + Some(projection) => { + let fields = self.decoder.schema.fields(); + let projected = projection.iter().map(|i| fields[*i].clone()); + Arc::new(Schema::new(projected.collect::())) + } + None => self.decoder.schema.clone(), + } + } +} + +impl BufReader { + fn read(&mut self) -> Result, ArrowError> { + loop { + let buf = self.reader.fill_buf()?; + let decoded = self.decoder.decode(buf)?; + self.reader.consume(decoded); + // Yield if decoded no bytes or the decoder is full + // + // The capacity check avoids looping around and potentially + // blocking reading data in fill_buf that isn't needed + // to flush the next batch 
+ if decoded == 0 || self.decoder.capacity() == 0 { + break; + } + } + + self.decoder.flush() + } +} + +impl Iterator for BufReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.read().transpose() + } +} + +impl RecordBatchReader for BufReader { + fn schema(&self) -> SchemaRef { + self.decoder.schema.clone() + } +} + +/// A push-based interface for decoding CSV data from an arbitrary byte stream +/// +/// See [`Reader`] for a higher-level interface for interface with [`Read`] +/// +/// The push-based interface facilitates integration with sources that yield arbitrarily +/// delimited bytes ranges, such as [`BufRead`], or a chunked byte stream received from +/// object storage +/// +/// ``` +/// # use std::io::BufRead; +/// # use arrow_array::RecordBatch; +/// # use arrow_csv::ReaderBuilder; +/// # use arrow_schema::{ArrowError, SchemaRef}; +/// # +/// fn read_from_csv( +/// mut reader: R, +/// schema: SchemaRef, +/// batch_size: usize, +/// ) -> Result>, ArrowError> { +/// let mut decoder = ReaderBuilder::new(schema) +/// .with_batch_size(batch_size) +/// .build_decoder(); +/// +/// let mut next = move || { +/// loop { +/// let buf = reader.fill_buf()?; +/// let decoded = decoder.decode(buf)?; +/// if decoded == 0 { +/// break; +/// } +/// +/// // Consume the number of bytes read +/// reader.consume(decoded); +/// } +/// decoder.flush() +/// }; +/// Ok(std::iter::from_fn(move || next().transpose())) +/// } +/// ``` +#[derive(Debug)] +pub struct Decoder { + /// Explicit schema for the CSV file + schema: SchemaRef, + + /// Optional projection for which columns to load (zero-based column indices) + projection: Option>, + + /// Number of records per batch + batch_size: usize, + + /// Rows to skip + to_skip: usize, + + /// Current line number + line_number: usize, + + /// End line number + end: usize, + + /// A decoder for [`StringRecords`] + record_decoder: RecordDecoder, + + /// Check if the string matches this pattern for `NULL`. 
+ null_regex: NullRegex, +} + +impl Decoder { + /// Decode records from `buf` returning the number of bytes read + /// + /// This method returns once `batch_size` objects have been parsed since the + /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes + /// should be included in the next call to [`Self::decode`] + /// + /// There is no requirement that `buf` contains a whole number of records, facilitating + /// integration with arbitrary byte streams, such as that yielded by [`BufRead`] or + /// network sources such as object storage + pub fn decode(&mut self, buf: &[u8]) -> Result { + if self.to_skip != 0 { + // Skip in units of `to_read` to avoid over-allocating buffers + let to_skip = self.to_skip.min(self.batch_size); + let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?; + self.to_skip -= skipped; + self.record_decoder.clear(); + return Ok(bytes); + } + + let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len(); + let (_, bytes) = self.record_decoder.decode(buf, to_read)?; + Ok(bytes) + } + + /// Flushes the currently buffered data to a [`RecordBatch`] + /// + /// This should only be called after [`Self::decode`] has returned `Ok(0)`, + /// otherwise may return an error if part way through decoding a record + /// + /// Returns `Ok(None)` if no buffered data + pub fn flush(&mut self) -> Result, ArrowError> { + if self.record_decoder.is_empty() { + return Ok(None); + } + + let rows = self.record_decoder.flush()?; + let batch = parse( + &rows, + self.schema.fields(), + Some(self.schema.metadata.clone()), + self.projection.as_ref(), + self.line_number, + &self.null_regex, + )?; + self.line_number += rows.len(); + Ok(Some(batch)) + } + + /// Returns the number of records that can be read before requiring a call to [`Self::flush`] + pub fn capacity(&self) -> usize { + self.batch_size - self.record_decoder.len() + } +} + +/// Parses a slice of [`StringRecords`] into a [RecordBatch] +fn parse( + 
rows: &StringRecords<'_>, + fields: &Fields, + metadata: Option>, + projection: Option<&Vec>, + line_number: usize, + null_regex: &NullRegex, +) -> Result { + let projection: Vec = match projection { + Some(v) => v.clone(), + None => fields.iter().enumerate().map(|(i, _)| i).collect(), + }; + + let arrays: Result, _> = projection + .iter() + .map(|i| { + let i = *i; + let field = &fields[i]; + match field.data_type() { + DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex), + DataType::Decimal128(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), + DataType::Decimal256(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), + DataType::Int8 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Int16 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Int32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Int64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt8 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt16 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Float32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Float64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Date32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Date64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time32(TimeUnit::Second) => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time32(TimeUnit::Millisecond) => { + 
build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time64(TimeUnit::Microsecond) => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time64(TimeUnit::Nanosecond) => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Timestamp(TimeUnit::Second, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Null => Ok(Arc::new({ + let mut builder = NullBuilder::new(); + builder.append_nulls(rows.len()); + builder.finish() + }) as ArrayRef), + DataType::Utf8 => Ok(Arc::new( + rows.iter() + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) + .collect::(), + ) as ArrayRef), + DataType::Utf8View => Ok(Arc::new( + rows.iter() + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) + .collect::(), + ) as ArrayRef), + DataType::Dictionary(key_type, value_type) + if value_type.as_ref() == &DataType::Utf8 => + { + match key_type.as_ref() { + DataType::Int8 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::Int16 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::Int32 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::Int64 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::UInt8 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as 
ArrayRef), + DataType::UInt16 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::UInt32 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::UInt64 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + _ => Err(ArrowError::ParseError(format!( + "Unsupported dictionary key type {key_type:?}" + ))), + } + } + other => Err(ArrowError::ParseError(format!( + "Unsupported data type {other:?}" + ))), + } + }) + .collect(); + + let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect(); + + let projected_schema = Arc::new(match metadata { + None => Schema::new(projected_fields), + Some(metadata) => Schema::new_with_metadata(projected_fields, metadata), + }); + + arrays.and_then(|arr| { + RecordBatch::try_new_with_options( + projected_schema, + arr, + &RecordBatchOptions::new() + .with_match_field_names(true) + .with_row_count(Some(rows.len())), + ) + }) +} + +fn parse_bool(string: &str) -> Option { + if string.eq_ignore_ascii_case("false") { + Some(false) + } else if string.eq_ignore_ascii_case("true") { + Some(true) + } else { + None + } +} + +// parse the column string to an Arrow Array +fn build_decimal_array( + _line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + precision: u8, + scale: i8, + null_regex: &NullRegex, +) -> Result { + let mut decimal_builder = PrimitiveBuilder::::with_capacity(rows.len()); + for row in rows.iter() { + let s = row.get(col_idx); + if null_regex.is_null(s) { + // append null + decimal_builder.append_null(); + } else { + let decimal_value: Result = parse_decimal::(s, precision, scale); + match decimal_value { + Ok(v) => { + decimal_builder.append_value(v); + } + Err(e) => { + return Err(e); + } + } + } + } + Ok(Arc::new( + decimal_builder + .finish() + .with_precision_and_scale(precision, scale)?, + )) +} + +// parses a specific column (col_idx) into an Arrow 
Array. +fn build_primitive_array( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + null_regex: &NullRegex, +) -> Result { + rows.iter() + .enumerate() + .map(|(row_index, row)| { + let s = row.get(col_idx); + if null_regex.is_null(s) { + return Ok(None); + } + + match T::parse(s) { + Some(e) => Ok(Some(e)), + None => Err(ArrowError::ParseError(format!( + // TODO: we should surface the underlying error here. + "Error while parsing value {} for column {} at line {}", + s, + col_idx, + line_number + row_index + ))), + } + }) + .collect::, ArrowError>>() + .map(|e| Arc::new(e) as ArrayRef) +} + +fn build_timestamp_array( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + timezone: Option<&str>, + null_regex: &NullRegex, +) -> Result { + Ok(Arc::new(match timezone { + Some(timezone) => { + let tz: Tz = timezone.parse()?; + build_timestamp_array_impl::(line_number, rows, col_idx, &tz, null_regex)? + .with_timezone(timezone) + } + None => build_timestamp_array_impl::(line_number, rows, col_idx, &Utc, null_regex)?, + })) +} + +fn build_timestamp_array_impl( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + timezone: &Tz, + null_regex: &NullRegex, +) -> Result, ArrowError> { + rows.iter() + .enumerate() + .map(|(row_index, row)| { + let s = row.get(col_idx); + if null_regex.is_null(s) { + return Ok(None); + } + + let date = string_to_datetime(timezone, s) + .and_then(|date| match T::UNIT { + TimeUnit::Second => Ok(date.timestamp()), + TimeUnit::Millisecond => Ok(date.timestamp_millis()), + TimeUnit::Microsecond => Ok(date.timestamp_micros()), + TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| { + ArrowError::ParseError(format!( + "{} would overflow 64-bit signed nanoseconds", + date.to_rfc3339(), + )) + }), + }) + .map_err(|e| { + ArrowError::ParseError(format!( + "Error parsing column {col_idx} at line {}: {}", + line_number + row_index, + e + )) + })?; + Ok(Some(date)) + }) + .collect() +} + +// 
parses a specific column (col_idx) into an Arrow Array. +fn build_boolean_array( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + null_regex: &NullRegex, +) -> Result { + rows.iter() + .enumerate() + .map(|(row_index, row)| { + let s = row.get(col_idx); + if null_regex.is_null(s) { + return Ok(None); + } + let parsed = parse_bool(s); + match parsed { + Some(e) => Ok(Some(e)), + None => Err(ArrowError::ParseError(format!( + // TODO: we should surface the underlying error here. + "Error while parsing value {} for column {} at line {}", + s, + col_idx, + line_number + row_index + ))), + } + }) + .collect::>() + .map(|e| Arc::new(e) as ArrayRef) +} + +/// CSV file reader builder +#[derive(Debug)] +pub struct ReaderBuilder { + /// Schema of the CSV file + schema: SchemaRef, + /// Format of the CSV file + format: Format, + /// Batch size (number of records to load each time) + /// + /// The default batch size when using the `ReaderBuilder` is 1024 records + batch_size: usize, + /// The bounds over which to scan the reader. `None` starts from 0 and runs until EOF. + bounds: Bounds, + /// Optional projection for which columns to load (zero-based column indices) + projection: Option>, +} + +impl ReaderBuilder { + /// Create a new builder for configuring CSV parsing options. 
+ /// + /// To convert a builder into a reader, call `ReaderBuilder::build` + /// + /// # Example + /// + /// ``` + /// # use arrow_csv::{Reader, ReaderBuilder}; + /// # use std::fs::File; + /// # use std::io::Seek; + /// # use std::sync::Arc; + /// # use arrow_csv::reader::Format; + /// # + /// let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + /// // Infer the schema with the first 100 records + /// let (schema, _) = Format::default().infer_schema(&mut file, Some(100)).unwrap(); + /// file.rewind().unwrap(); + /// + /// // create a builder + /// ReaderBuilder::new(Arc::new(schema)).build(file).unwrap(); + /// ``` + pub fn new(schema: SchemaRef) -> ReaderBuilder { + Self { + schema, + format: Format::default(), + batch_size: 1024, + bounds: None, + projection: None, + } + } + + /// Set whether the CSV file has headers + #[deprecated(note = "Use with_header")] + #[doc(hidden)] + pub fn has_header(mut self, has_header: bool) -> Self { + self.format.header = has_header; + self + } + + /// Set whether the CSV file has a header + pub fn with_header(mut self, has_header: bool) -> Self { + self.format.header = has_header; + self + } + + /// Overrides the [Format] of this [ReaderBuilder] + pub fn with_format(mut self, format: Format) -> Self { + self.format = format; + self + } + + /// Set the CSV file's column delimiter as a byte character + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.format.delimiter = Some(delimiter); + self + } + + /// Set the given character as the CSV file's escape character + pub fn with_escape(mut self, escape: u8) -> Self { + self.format.escape = Some(escape); + self + } + + /// Set the given character as the CSV file's quote character, by default it is double quote + pub fn with_quote(mut self, quote: u8) -> Self { + self.format.quote = Some(quote); + self + } + + /// Provide a custom terminator character, defaults to CRLF + pub fn with_terminator(mut self, terminator: u8) -> Self { + 
self.format.terminator = Some(terminator); + self + } + + /// Provide a comment character, lines starting with this character will be ignored + pub fn with_comment(mut self, comment: u8) -> Self { + self.format.comment = Some(comment); + self + } + + /// Provide a regex to match null values, defaults to `^$` + pub fn with_null_regex(mut self, null_regex: Regex) -> Self { + self.format.null_regex = NullRegex(Some(null_regex)); + self + } + + /// Set the batch size (number of records to load at one time) + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the bounds over which to scan the reader. + /// `start` and `end` are line numbers. + pub fn with_bounds(mut self, start: usize, end: usize) -> Self { + self.bounds = Some((start, end)); + self + } + + /// Set the reader's column projection + pub fn with_projection(mut self, projection: Vec) -> Self { + self.projection = Some(projection); + self + } + + /// Whether to allow truncated rows when parsing. + /// + /// By default this is set to `false` and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns + /// and fill the missing columns with nulls. If the record's schema is not nullable, then it + /// will still return an error. 
+ pub fn with_truncated_rows(mut self, allow: bool) -> Self { + self.format.truncated_rows = allow; + self + } + + /// Create a new `Reader` from a non-buffered reader + /// + /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional + /// buffering, as internally this method wraps `reader` in [`std::io::BufReader`] + pub fn build(self, reader: R) -> Result, ArrowError> { + self.build_buffered(StdBufReader::new(reader)) + } + + /// Create a new `BufReader` from a buffered reader + pub fn build_buffered(self, reader: R) -> Result, ArrowError> { + Ok(BufReader { + reader, + decoder: self.build_decoder(), + }) + } + + /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream + pub fn build_decoder(self) -> Decoder { + let delimiter = self.format.build_parser(); + let record_decoder = RecordDecoder::new( + delimiter, + self.schema.fields().len(), + self.format.truncated_rows, + ); + + let header = self.format.header as usize; + + let (start, end) = match self.bounds { + Some((start, end)) => (start + header, end + header), + None => (header, usize::MAX), + }; + + Decoder { + schema: self.schema, + to_skip: start, + record_decoder, + line_number: start, + end, + projection: self.projection, + batch_size: self.batch_size, + null_regex: self.format.null_regex, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Cursor, Write}; + use tempfile::NamedTempFile; + + use arrow_array::cast::AsArray; + + #[test] + fn test_csv() { + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ])); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap(); + assert_eq!(schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + 
+ // access data from a primitive array + let lat = batch.column(1).as_primitive::(); + assert_eq!(57.653484, lat.value(0)); + + // access data from a string array (ListArray) + let city = batch.column(0).as_string::(); + + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + } + + #[test] + fn test_csv_schema_metadata() { + let mut metadata = std::collections::HashMap::new(); + metadata.insert("foo".to_owned(), "bar".to_owned()); + let schema = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ], + metadata.clone(), + )); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap(); + assert_eq!(schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + assert_eq!(&metadata, batch.schema().metadata()); + } + + #[test] + fn test_csv_reader_with_decimal() { + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Decimal128(38, 6), false), + Field::new("lng", DataType::Decimal256(76, 6), false), + ])); + + let file = File::open("test/data/decimal_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema).build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("57.653484", lat.value_as_string(0)); + assert_eq!("53.002666", lat.value_as_string(1)); + assert_eq!("52.412811", lat.value_as_string(2)); + assert_eq!("51.481583", lat.value_as_string(3)); + assert_eq!("12.123456", lat.value_as_string(4)); + assert_eq!("50.760000", lat.value_as_string(5)); + assert_eq!("0.123000", lat.value_as_string(6)); + assert_eq!("123.000000", lat.value_as_string(7)); + 
assert_eq!("123.000000", lat.value_as_string(8)); + assert_eq!("-50.760000", lat.value_as_string(9)); + + let lng = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("-3.335724", lng.value_as_string(0)); + assert_eq!("-2.179404", lng.value_as_string(1)); + assert_eq!("-1.778197", lng.value_as_string(2)); + assert_eq!("-3.179090", lng.value_as_string(3)); + assert_eq!("-3.179090", lng.value_as_string(4)); + assert_eq!("0.290472", lng.value_as_string(5)); + assert_eq!("0.290472", lng.value_as_string(6)); + assert_eq!("0.290472", lng.value_as_string(7)); + assert_eq!("0.290472", lng.value_as_string(8)); + assert_eq!("0.290472", lng.value_as_string(9)); + } + + #[test] + fn test_csv_from_buf_reader() { + let schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ]); + + let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + let file_without_headers = File::open("test/data/uk_cities.csv").unwrap(); + let both_files = file_with_headers + .chain(Cursor::new("\n".to_string())) + .chain(file_without_headers); + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_header(true) + .build(both_files) + .unwrap(); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(74, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + } + + #[test] + fn test_csv_with_schema_inference() { + let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + + let (schema, _) = Format::default() + .with_header(true) + .infer_schema(&mut file, None) + .unwrap(); + + file.rewind().unwrap(); + let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true); + + let mut csv = builder.build(file).unwrap(); + let expected_schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, true), + Field::new("lat", DataType::Float64, true), + Field::new("lng", DataType::Float64, true), + ]); + 
assert_eq!(Arc::new(expected_schema), csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(57.653484, lat.value(0)); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + } + + #[test] + fn test_csv_with_schema_inference_no_headers() { + let mut file = File::open("test/data/uk_cities.csv").unwrap(); + + let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap(); + + // csv field names should be 'column_{number}' + let schema = csv.schema(); + assert_eq!("column_1", schema.field(0).name()); + assert_eq!("column_2", schema.field(1).name()); + assert_eq!("column_3", schema.field(2).name()); + let batch = csv.next().unwrap().unwrap(); + let batch_schema = batch.schema(); + + assert_eq!(schema, batch_schema); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(57.653484, lat.value(0)); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + } + + #[test] + fn test_csv_builder_with_bounds() { + let mut file = File::open("test/data/uk_cities.csv").unwrap(); + + // Set the bounds to the lines 0, 1 and 2. 
+ let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_bounds(0, 2) + .build(file) + .unwrap(); + let batch = csv.next().unwrap().unwrap(); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // The value on line 0 is within the bounds + assert_eq!("Elgin, Scotland, the UK", city.value(0)); + + // The value on line 13 is outside of the bounds. Therefore + // the call to .value() will panic. + let result = std::panic::catch_unwind(|| city.value(13)); + assert!(result.is_err()); + } + + #[test] + fn test_csv_with_projection() { + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ])); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_projection(vec![0, 1]) + .build(file) + .unwrap(); + + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + ])); + assert_eq!(projected_schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(projected_schema, batch.schema()); + assert_eq!(37, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + } + + #[test] + fn test_csv_with_dictionary() { + let schema = Arc::new(Schema::new(vec![ + Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ])); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_projection(vec![0, 1]) + .build(file) + .unwrap(); + + let projected_schema = Arc::new(Schema::new(vec![ + Field::new_dictionary("city", DataType::Int32, DataType::Utf8, 
false), + Field::new("lat", DataType::Float64, false), + ])); + assert_eq!(projected_schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(projected_schema, batch.schema()); + assert_eq!(37, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + + let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap(); + let strings = strings.as_string::(); + + assert_eq!(strings.value(0), "Elgin, Scotland, the UK"); + assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK"); + assert_eq!(strings.value(29), "Uckfield, East Sussex, UK"); + } + + #[test] + fn test_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, false), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, false), + ])); + + let file = File::open("test/data/null_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(!batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_init_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, true), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + Field::new("c_null", DataType::Null, true), + ])); + let file = File::open("test/data/init_null_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); 
+ } + + #[test] + fn test_init_nulls_with_inference() { + let format = Format::default().with_header(true).with_delimiter(b','); + + let mut file = File::open("test/data/init_null_test.csv").unwrap(); + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c_int", DataType::Int64, true), + Field::new("c_float", DataType::Float64, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + Field::new("c_null", DataType::Null, true), + ]); + assert_eq!(schema, expected_schema); + + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_custom_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, true), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + ])); + + let file = File::open("test/data/custom_null_test.csv").unwrap(); + + let null_regex = Regex::new("^nil$").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .with_null_regex(null_regex) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + // "nil"s should be NULL + assert!(batch.column(0).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(batch.column(3).is_null(4)); + assert!(batch.column(2).is_null(3)); + assert!(!batch.column(2).is_null(4)); + } + + #[test] + fn test_nulls_with_inference() { + let mut file = File::open("test/data/various_types.csv").unwrap(); + let format = Format::default().with_header(true).with_delimiter(b'|'); + + let (schema, _) = 
format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .with_batch_size(512) + .with_projection(vec![0, 1, 2, 3, 4, 5]); + + let mut csv = builder.build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + + assert_eq!(7, batch.num_rows()); + assert_eq!(6, batch.num_columns()); + + let schema = batch.schema(); + + assert_eq!(&DataType::Int64, schema.field(0).data_type()); + assert_eq!(&DataType::Float64, schema.field(1).data_type()); + assert_eq!(&DataType::Float64, schema.field(2).data_type()); + assert_eq!(&DataType::Boolean, schema.field(3).data_type()); + assert_eq!(&DataType::Date32, schema.field(4).data_type()); + assert_eq!( + &DataType::Timestamp(TimeUnit::Second, None), + schema.field(5).data_type() + ); + + let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect(); + assert_eq!( + names, + vec![ + "c_int", + "c_float", + "c_string", + "c_bool", + "c_date", + "c_datetime" + ] + ); + + assert!(schema.field(0).is_nullable()); + assert!(schema.field(1).is_nullable()); + assert!(schema.field(2).is_nullable()); + assert!(schema.field(3).is_nullable()); + assert!(schema.field(4).is_nullable()); + assert!(schema.field(5).is_nullable()); + + assert!(!batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_custom_nulls_with_inference() { + let mut file = File::open("test/data/custom_null_test.csv").unwrap(); + + let null_regex = Regex::new("^nil$").unwrap(); + + let format = Format::default() + .with_header(true) + .with_null_regex(null_regex); + + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c_int", DataType::Int64, true), + Field::new("c_float", DataType::Float64, true), + 
Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + ]); + + assert_eq!(schema, expected_schema); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .with_batch_size(512) + .with_projection(vec![0, 1, 2, 3]); + + let mut csv = builder.build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + + assert_eq!(5, batch.num_rows()); + assert_eq!(4, batch.num_columns()); + + assert_eq!(batch.schema().as_ref(), &expected_schema); + } + + #[test] + fn test_scientific_notation_with_inference() { + let mut file = File::open("test/data/scientific_notation_test.csv").unwrap(); + let format = Format::default().with_header(false).with_delimiter(b','); + + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .with_batch_size(512) + .with_projection(vec![0, 1]); + + let mut csv = builder.build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + + let schema = batch.schema(); + + assert_eq!(&DataType::Float64, schema.field(0).data_type()); + } + + #[test] + fn test_parse_invalid_csv() { + let file = File::open("test/data/various_types_invalid.csv").unwrap(); + + let schema = Schema::new(vec![ + Field::new("c_int", DataType::UInt64, false), + Field::new("c_float", DataType::Float32, false), + Field::new("c_string", DataType::Utf8, false), + Field::new("c_bool", DataType::Boolean, false), + ]); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_header(true) + .with_delimiter(b'|') + .with_batch_size(512) + .with_projection(vec![0, 1, 2, 3]); + + let mut csv = builder.build(file).unwrap(); + match csv.next() { + Some(e) => match e { + Err(e) => assert_eq!( + "ParseError(\"Error while parsing value 4.x4 for column 1 at line 4\")", + format!("{e:?}") + ), + Ok(_) => panic!("should have failed"), + }, + None => panic!("should have failed"), + } + } + + /// Infer the data 
type of a record + fn infer_field_schema(string: &str) -> DataType { + let mut v = InferredDataType::default(); + v.update(string); + v.get() + } + + #[test] + fn test_infer_field_schema() { + assert_eq!(infer_field_schema("A"), DataType::Utf8); + assert_eq!(infer_field_schema("\"123\""), DataType::Utf8); + assert_eq!(infer_field_schema("10"), DataType::Int64); + assert_eq!(infer_field_schema("10.2"), DataType::Float64); + assert_eq!(infer_field_schema(".2"), DataType::Float64); + assert_eq!(infer_field_schema("2."), DataType::Float64); + assert_eq!(infer_field_schema("true"), DataType::Boolean); + assert_eq!(infer_field_schema("trUe"), DataType::Boolean); + assert_eq!(infer_field_schema("false"), DataType::Boolean); + assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32); + assert_eq!( + infer_field_schema("2020-11-08T14:20:01"), + DataType::Timestamp(TimeUnit::Second, None) + ); + assert_eq!( + infer_field_schema("2020-11-08 14:20:01"), + DataType::Timestamp(TimeUnit::Second, None) + ); + assert_eq!( + infer_field_schema("2020-11-08 14:20:01"), + DataType::Timestamp(TimeUnit::Second, None) + ); + assert_eq!(infer_field_schema("-5.13"), DataType::Float64); + assert_eq!(infer_field_schema("0.1300"), DataType::Float64); + assert_eq!( + infer_field_schema("2021-12-19 13:12:30.921"), + DataType::Timestamp(TimeUnit::Millisecond, None) + ); + assert_eq!( + infer_field_schema("2021-12-19T13:12:30.123456789"), + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + } + + #[test] + fn parse_date32() { + assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0); + assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336); + assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004); + } + + #[test] + fn parse_time() { + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 AM"), + Some(601_123_456_789) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 am"), + Some(601_123_456) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 
PM"), + Some(51_001_120) + ); + assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001)); + } + + #[test] + fn parse_date64() { + assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0); + assert_eq!( + Date64Type::parse("2018-11-13T17:11:10").unwrap(), + 1542129070000 + ); + assert_eq!( + Date64Type::parse("2018-11-13T17:11:10.011").unwrap(), + 1542129070011 + ); + assert_eq!( + Date64Type::parse("1900-02-28T12:34:56").unwrap(), + -2203932304000 + ); + assert_eq!( + Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(), + -2203932304000 + ); + assert_eq!( + Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(), + -2203932304000 - (30 * 60 * 1000) + ); + } + + fn test_parse_timestamp_impl( + timezone: Option>, + expected: &[i64], + ) { + let csv = [ + "1970-01-01T00:00:00", + "1970-01-01T00:00:00Z", + "1970-01-01T00:00:00+02:00", + ] + .join("\n"); + let schema = Arc::new(Schema::new(vec![Field::new( + "field", + DataType::Timestamp(T::UNIT, timezone.clone()), + true, + )])); + + let mut decoder = ReaderBuilder::new(schema).build_decoder(); + + let decoded = decoder.decode(csv.as_bytes()).unwrap(); + assert_eq!(decoded, csv.len()); + decoder.decode(&[]).unwrap(); + + let batch = decoder.flush().unwrap().unwrap(); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.num_rows(), 3); + let col = batch.column(0).as_primitive::(); + assert_eq!(col.values(), expected); + assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone)); + } + + #[test] + fn test_parse_timestamp() { + test_parse_timestamp_impl::(None, &[0, 0, -7_200_000_000_000]); + test_parse_timestamp_impl::( + Some("+00:00".into()), + &[0, 0, -7_200_000_000_000], + ); + test_parse_timestamp_impl::( + Some("-05:00".into()), + &[18_000_000_000_000, 0, -7_200_000_000_000], + ); + test_parse_timestamp_impl::( + Some("-03".into()), + &[10_800_000_000, 0, -7_200_000_000], + ); + test_parse_timestamp_impl::( + 
Some("-03".into()), + &[10_800_000, 0, -7_200_000], + ); + test_parse_timestamp_impl::(Some("-03".into()), &[10_800, 0, -7_200]); + } + + #[test] + fn test_infer_schema_from_multiple_files() { + let mut csv1 = NamedTempFile::new().unwrap(); + let mut csv2 = NamedTempFile::new().unwrap(); + let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped + let mut csv4 = NamedTempFile::new().unwrap(); + writeln!(csv1, "c1,c2,c3").unwrap(); + writeln!(csv1, "1,\"foo\",0.5").unwrap(); + writeln!(csv1, "3,\"bar\",1").unwrap(); + writeln!(csv1, "3,\"bar\",2e-06").unwrap(); + // reading csv2 will set c2 to optional + writeln!(csv2, "c1,c2,c3,c4").unwrap(); + writeln!(csv2, "10,,3.14,true").unwrap(); + // reading csv4 will set c3 to optional + writeln!(csv4, "c1,c2,c3").unwrap(); + writeln!(csv4, "10,\"foo\",").unwrap(); + + let schema = infer_schema_from_files( + &[ + csv3.path().to_str().unwrap().to_string(), + csv1.path().to_str().unwrap().to_string(), + csv2.path().to_str().unwrap().to_string(), + csv4.path().to_str().unwrap().to_string(), + ], + b',', + Some(4), // only csv1 and csv2 should be read + true, + ) + .unwrap(); + + assert_eq!(schema.fields().len(), 4); + assert!(schema.field(0).is_nullable()); + assert!(schema.field(1).is_nullable()); + assert!(schema.field(2).is_nullable()); + assert!(schema.field(3).is_nullable()); + + assert_eq!(&DataType::Int64, schema.field(0).data_type()); + assert_eq!(&DataType::Utf8, schema.field(1).data_type()); + assert_eq!(&DataType::Float64, schema.field(2).data_type()); + assert_eq!(&DataType::Boolean, schema.field(3).data_type()); + } + + #[test] + fn test_bounded() { + let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]); + let data = [ + vec!["0"], + vec!["1"], + vec!["2"], + vec!["3"], + vec!["4"], + vec!["5"], + vec!["6"], + ]; + + let data = data + .iter() + .map(|x| x.join(",")) + .collect::>() + .join("\n"); + let data = data.as_bytes(); + + let reader = std::io::Cursor::new(data); + 
+ let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_batch_size(2) + .with_projection(vec![0]) + .with_bounds(2, 6) + .build_buffered(reader) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + let a = batch.column(0); + let a = a.as_any().downcast_ref::().unwrap(); + assert_eq!(a, &UInt32Array::from(vec![2, 3])); + + let batch = csv.next().unwrap().unwrap(); + let a = batch.column(0); + let a = a.as_any().downcast_ref::().unwrap(); + assert_eq!(a, &UInt32Array::from(vec![4, 5])); + + assert!(csv.next().is_none()); + } + + #[test] + fn test_empty_projection() { + let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]); + let data = [vec!["0"], vec!["1"]]; + + let data = data + .iter() + .map(|x| x.join(",")) + .collect::>() + .join("\n"); + + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_batch_size(2) + .with_projection(vec![]) + .build_buffered(Cursor::new(data.as_bytes())) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + assert_eq!(batch.columns().len(), 0); + assert_eq!(batch.num_rows(), 2); + + assert!(csv.next().is_none()); + } + + #[test] + fn test_parsing_bool() { + // Encode the expected behavior of boolean parsing + assert_eq!(Some(true), parse_bool("true")); + assert_eq!(Some(true), parse_bool("tRUe")); + assert_eq!(Some(true), parse_bool("True")); + assert_eq!(Some(true), parse_bool("TRUE")); + assert_eq!(None, parse_bool("t")); + assert_eq!(None, parse_bool("T")); + assert_eq!(None, parse_bool("")); + + assert_eq!(Some(false), parse_bool("false")); + assert_eq!(Some(false), parse_bool("fALse")); + assert_eq!(Some(false), parse_bool("False")); + assert_eq!(Some(false), parse_bool("FALSE")); + assert_eq!(None, parse_bool("f")); + assert_eq!(None, parse_bool("F")); + assert_eq!(None, parse_bool("")); + } + + #[test] + fn test_parsing_float() { + assert_eq!(Some(12.34), Float64Type::parse("12.34")); + assert_eq!(Some(-12.34), Float64Type::parse("-12.34")); + assert_eq!(Some(12.0), 
Float64Type::parse("12")); + assert_eq!(Some(0.0), Float64Type::parse("0")); + assert_eq!(Some(2.0), Float64Type::parse("2.")); + assert_eq!(Some(0.2), Float64Type::parse(".2")); + assert!(Float64Type::parse("nan").unwrap().is_nan()); + assert!(Float64Type::parse("NaN").unwrap().is_nan()); + assert!(Float64Type::parse("inf").unwrap().is_infinite()); + assert!(Float64Type::parse("inf").unwrap().is_sign_positive()); + assert!(Float64Type::parse("-inf").unwrap().is_infinite()); + assert!(Float64Type::parse("-inf").unwrap().is_sign_negative()); + assert_eq!(None, Float64Type::parse("")); + assert_eq!(None, Float64Type::parse("dd")); + assert_eq!(None, Float64Type::parse("12.34.56")); + } + + #[test] + fn test_non_std_quote() { + let schema = Schema::new(vec![ + Field::new("text1", DataType::Utf8, false), + Field::new("text2", DataType::Utf8, false), + ]); + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_header(false) + .with_quote(b'~'); // default is ", change to ~ + + let mut csv_text = Vec::new(); + let mut csv_writer = std::io::Cursor::new(&mut csv_text); + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value{index:}"); + csv_writer + .write_fmt(format_args!("~{text1}~,~{text2}~\r\n")) + .unwrap(); + } + let mut csv_reader = std::io::Cursor::new(&csv_text); + let mut reader = builder.build(&mut csv_reader).unwrap(); + let batch = reader.next().unwrap().unwrap(); + let col0 = batch.column(0); + assert_eq!(col0.len(), 10); + let col0_arr = col0.as_any().downcast_ref::().unwrap(); + assert_eq!(col0_arr.value(0), "id0"); + let col1 = batch.column(1); + assert_eq!(col1.len(), 10); + let col1_arr = col1.as_any().downcast_ref::().unwrap(); + assert_eq!(col1_arr.value(5), "value5"); + } + + #[test] + fn test_non_std_escape() { + let schema = Schema::new(vec![ + Field::new("text1", DataType::Utf8, false), + Field::new("text2", DataType::Utf8, false), + ]); + let builder = ReaderBuilder::new(Arc::new(schema)) + 
.with_header(false) + .with_escape(b'\\'); // default is None, change to \ + + let mut csv_text = Vec::new(); + let mut csv_writer = std::io::Cursor::new(&mut csv_text); + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value\\\"{index:}"); + csv_writer + .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n")) + .unwrap(); + } + let mut csv_reader = std::io::Cursor::new(&csv_text); + let mut reader = builder.build(&mut csv_reader).unwrap(); + let batch = reader.next().unwrap().unwrap(); + let col0 = batch.column(0); + assert_eq!(col0.len(), 10); + let col0_arr = col0.as_any().downcast_ref::().unwrap(); + assert_eq!(col0_arr.value(0), "id0"); + let col1 = batch.column(1); + assert_eq!(col1.len(), 10); + let col1_arr = col1.as_any().downcast_ref::().unwrap(); + assert_eq!(col1_arr.value(5), "value\"5"); + } + + #[test] + fn test_non_std_terminator() { + let schema = Schema::new(vec![ + Field::new("text1", DataType::Utf8, false), + Field::new("text2", DataType::Utf8, false), + ]); + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_header(false) + .with_terminator(b'\n'); // default is CRLF, change to LF + + let mut csv_text = Vec::new(); + let mut csv_writer = std::io::Cursor::new(&mut csv_text); + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value{index:}"); + csv_writer + .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n")) + .unwrap(); + } + let mut csv_reader = std::io::Cursor::new(&csv_text); + let mut reader = builder.build(&mut csv_reader).unwrap(); + let batch = reader.next().unwrap().unwrap(); + let col0 = batch.column(0); + assert_eq!(col0.len(), 10); + let col0_arr = col0.as_any().downcast_ref::().unwrap(); + assert_eq!(col0_arr.value(0), "id0"); + let col1 = batch.column(1); + assert_eq!(col1.len(), 10); + let col1_arr = col1.as_any().downcast_ref::().unwrap(); + assert_eq!(col1_arr.value(5), "value5"); + } + + #[test] + fn test_header_bounds() { + let csv = 
"a,b\na,b\na,b\na,b\na,b\n"; + let tests = [ + (None, false, 5), + (None, true, 4), + (Some((0, 4)), false, 4), + (Some((1, 4)), false, 3), + (Some((0, 4)), true, 4), + (Some((1, 4)), true, 3), + ]; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("a", DataType::Utf8, false), + ])); + + for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() { + let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header); + if let Some((start, end)) = bounds { + reader = reader.with_bounds(start, end); + } + let b = reader + .build_buffered(Cursor::new(csv.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap(); + assert_eq!(b.num_rows(), expected, "{idx}"); + } + } + + #[test] + fn test_null_boolean() { + let csv = "true,false\nFalse,True\n,True\nFalse,"; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Boolean, true), + Field::new("a", DataType::Boolean, true), + ])); + + let b = ReaderBuilder::new(schema) + .build_buffered(Cursor::new(csv.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap(); + + assert_eq!(b.num_rows(), 4); + assert_eq!(b.num_columns(), 2); + + let c = b.column(0).as_boolean(); + assert_eq!(c.null_count(), 1); + assert!(c.value(0)); + assert!(!c.value(1)); + assert!(c.is_null(2)); + assert!(!c.value(3)); + + let c = b.column(1).as_boolean(); + assert_eq!(c.null_count(), 1); + assert!(!c.value(0)); + assert!(c.value(1)); + assert!(c.value(2)); + assert!(c.is_null(3)); + } + + #[test] + fn test_truncated_rows() { + let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8"; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_truncated_rows(true) + .build(Cursor::new(data)) + .unwrap(); + + let batches = reader.collect::, _>>(); + assert!(batches.is_ok()); + let batch = 
batches.unwrap().into_iter().next().unwrap(); + // Empty rows are skipped by the underlying csv parser + assert_eq!(batch.num_rows(), 3); + + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_truncated_rows(false) + .build(Cursor::new(data)) + .unwrap(); + + let batches = reader.collect::, _>>(); + assert!(match batches { + Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"), + _ => false, + }); + } + + #[test] + fn test_truncated_rows_csv() { + let file = File::open("test/data/truncated_rows.csv").unwrap(); + let schema = Arc::new(Schema::new(vec![ + Field::new("Name", DataType::Utf8, true), + Field::new("Age", DataType::UInt32, true), + Field::new("Occupation", DataType::Utf8, true), + Field::new("DOB", DataType::Date32, true), + ])); + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_batch_size(24) + .with_truncated_rows(true); + let csv = reader.build(file).unwrap(); + let batches = csv.collect::, _>>().unwrap(); + + assert_eq!(batches.len(), 1); + let batch = &batches[0]; + assert_eq!(batch.num_rows(), 6); + assert_eq!(batch.num_columns(), 4); + let name = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let age = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let occupation = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + let dob = batch + .column(3) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(name.value(0), "A1"); + assert_eq!(name.value(1), "B2"); + assert!(name.is_null(2)); + assert_eq!(name.value(3), "C3"); + assert_eq!(name.value(4), "D4"); + assert_eq!(name.value(5), "E5"); + + assert_eq!(age.value(0), 34); + assert_eq!(age.value(1), 29); + assert!(age.is_null(2)); + assert_eq!(age.value(3), 45); + assert!(age.is_null(4)); + assert_eq!(age.value(5), 31); + + assert_eq!(occupation.value(0), "Engineer"); + assert_eq!(occupation.value(1), "Doctor"); + assert!(occupation.is_null(2)); + 
assert_eq!(occupation.value(3), "Artist"); + assert!(occupation.is_null(4)); + assert!(occupation.is_null(5)); + + assert_eq!(dob.value(0), 5675); + assert!(dob.is_null(1)); + assert!(dob.is_null(2)); + assert_eq!(dob.value(3), -1858); + assert!(dob.is_null(4)); + assert!(dob.is_null(5)); + } + + #[test] + fn test_truncated_rows_not_nullable_error() { + let data = "a,b,c\n1,2,3\n4,5"; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_truncated_rows(true) + .build(Cursor::new(data)) + .unwrap(); + + let batches = reader.collect::, _>>(); + assert!(match batches { + Err(ArrowError::InvalidArgumentError(e)) => + e.to_string().contains("contains null values"), + _ => false, + }); + } + + #[test] + fn test_buffered() { + let tests = [ + ("test/data/uk_cities.csv", false, 37), + ("test/data/various_types.csv", true, 7), + ("test/data/decimal_test.csv", false, 10), + ]; + + for (path, has_header, expected_rows) in tests { + let (schema, _) = Format::default() + .infer_schema(File::open(path).unwrap(), None) + .unwrap(); + let schema = Arc::new(schema); + + for batch_size in [1, 4] { + for capacity in [1, 3, 7, 100] { + let reader = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .with_header(has_header) + .build(File::open(path).unwrap()) + .unwrap(); + + let expected = reader.collect::, _>>().unwrap(); + + assert_eq!( + expected.iter().map(|x| x.num_rows()).sum::(), + expected_rows + ); + + let buffered = + std::io::BufReader::with_capacity(capacity, File::open(path).unwrap()); + + let reader = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .with_header(has_header) + .build_buffered(buffered) + .unwrap(); + + let actual = reader.collect::, _>>().unwrap(); + assert_eq!(expected, actual) + } + } + } + } + + fn err_test(csv: 
&[u8], expected: &str) { + fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc) { + let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv)); + let b = ReaderBuilder::new(schema) + .with_batch_size(2) + .build_buffered(buffer) + .unwrap(); + let err = b.collect::, _>>().unwrap_err().to_string(); + assert_eq!(err, expected) + } + + let schema_utf8 = Arc::new(Schema::new(vec![ + Field::new("text1", DataType::Utf8, true), + Field::new("text2", DataType::Utf8, true), + ])); + err_test_with_schema(csv, expected, schema_utf8); + + let schema_utf8view = Arc::new(Schema::new(vec![ + Field::new("text1", DataType::Utf8View, true), + Field::new("text2", DataType::Utf8View, true), + ])); + err_test_with_schema(csv, expected, schema_utf8view); + } + + #[test] + fn test_invalid_utf8() { + err_test( + b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,", + "Csv error: Encountered invalid UTF-8 data for line 2 and field 2", + ); + + err_test( + b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,", + "Csv error: Encountered invalid UTF-8 data for line 3 and field 1", + ); + + err_test( + b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF", + "Csv error: Encountered invalid UTF-8 data for line 5 and field 2", + ); + + err_test( + b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF", + "Csv error: Encountered invalid UTF-8 data for line 1 and field 1", + ); + } + + struct InstrumentedRead { + r: R, + fill_count: usize, + fill_sizes: Vec, + } + + impl InstrumentedRead { + fn new(r: R) -> Self { + Self { + r, + fill_count: 0, + fill_sizes: vec![], + } + } + } + + impl Seek for InstrumentedRead { + fn seek(&mut self, pos: SeekFrom) -> std::io::Result { + self.r.seek(pos) + } + } + + impl Read for InstrumentedRead { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.r.read(buf) + } + } + + impl BufRead for InstrumentedRead { + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + self.fill_count += 1; + let buf = self.r.fill_buf()?; + 
self.fill_sizes.push(buf.len()); + Ok(buf) + } + + fn consume(&mut self, amt: usize) { + self.r.consume(amt) + } + } + + #[test] + fn test_io() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let csv = "foo,bar\nbaz,foo\na,b\nc,d"; + let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes())); + let reader = ReaderBuilder::new(schema) + .with_batch_size(3) + .build_buffered(&mut read) + .unwrap(); + + let batches = reader.collect::, _>>().unwrap(); + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 3); + assert_eq!(batches[1].num_rows(), 1); + + // Expect 4 calls to fill_buf + // 1. Read first 3 rows + // 2. Read final row + // 3. Delimit and flush final row + // 4. Iterator finished + assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]); + assert_eq!(read.fill_count, 4); + } + + #[test] + fn test_inference() { + let cases: &[(&[&str], DataType)] = &[ + (&[], DataType::Null), + (&["false", "12"], DataType::Utf8), + (&["12", "cupcakes"], DataType::Utf8), + (&["12", "12.4"], DataType::Float64), + (&["14050", "24332"], DataType::Int64), + (&["14050.0", "true"], DataType::Utf8), + (&["14050", "2020-03-19 00:00:00"], DataType::Utf8), + (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8), + ( + &["2020-03-19 02:00:00", "2020-03-19 00:00:00"], + DataType::Timestamp(TimeUnit::Second, None), + ), + (&["2020-03-19", "2020-03-20"], DataType::Date32), + ( + &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"], + DataType::Timestamp(TimeUnit::Second, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00", + "2020-03-19 00:00:00.000", + ], + DataType::Timestamp(TimeUnit::Millisecond, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00", + "2020-03-19 00:00:00.000000", + ], + DataType::Timestamp(TimeUnit::Microsecond, None), + ), + ( + &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"], + DataType::Timestamp(TimeUnit::Second, None), + ), + ( 
+ &[ + "2020-03-19", + "2020-03-19 02:00:00+02:00", + "2020-03-19 02:00:00Z", + "2020-03-19 02:00:00.12Z", + ], + DataType::Timestamp(TimeUnit::Millisecond, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00.000000000", + "2020-03-19 00:00:00.000000", + ], + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ]; + + for (values, expected) in cases { + let mut t = InferredDataType::default(); + for v in *values { + t.update(v) + } + assert_eq!(&t.get(), expected, "{values:?}") + } + } + + #[test] + fn test_comment() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int8, false), + Field::new("b", DataType::Int8, false), + ]); + + let csv = "# comment1 \n1,2\n#comment2\n11,22"; + let mut read = Cursor::new(csv.as_bytes()); + let reader = ReaderBuilder::new(Arc::new(schema)) + .with_comment(b'#') + .build(&mut read) + .unwrap(); + + let batches = reader.collect::, _>>().unwrap(); + assert_eq!(batches.len(), 1); + let b = batches.first().unwrap(); + assert_eq!(b.num_columns(), 2); + assert_eq!( + b.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &vec![1, 11] + ); + assert_eq!( + b.column(1) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &vec![2, 22] + ); + } + + #[test] + fn test_parse_string_view_single_column() { + let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n"); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Utf8View, + true, + )])); + + let mut decoder = ReaderBuilder::new(schema).build_decoder(); + + let decoded = decoder.decode(csv.as_bytes()).unwrap(); + assert_eq!(decoded, csv.len()); + decoder.decode(&[]).unwrap(); + + let batch = decoder.flush().unwrap().unwrap(); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.num_rows(), 3); + let col = batch.column(0).as_string_view(); + assert_eq!(col.data_type(), &DataType::Utf8View); + assert_eq!(col.value(0), "foo"); + assert_eq!(col.value(1), "something_cannot_be_inlined"); + assert_eq!(col.value(2), 
"foobar"); + } + + #[test] + fn test_parse_string_view_multi_column() { + let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n"); + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8View, true), + Field::new("c2", DataType::Utf8View, true), + ])); + + let mut decoder = ReaderBuilder::new(schema).build_decoder(); + + let decoded = decoder.decode(csv.as_bytes()).unwrap(); + assert_eq!(decoded, csv.len()); + decoder.decode(&[]).unwrap(); + + let batch = decoder.flush().unwrap().unwrap(); + assert_eq!(batch.num_columns(), 2); + assert_eq!(batch.num_rows(), 3); + let c1 = batch.column(0).as_string_view(); + let c2 = batch.column(1).as_string_view(); + assert_eq!(c1.data_type(), &DataType::Utf8View); + assert_eq!(c2.data_type(), &DataType::Utf8View); + + assert!(!c1.is_null(0)); + assert!(c1.is_null(1)); + assert!(!c1.is_null(2)); + assert_eq!(c1.value(0), "foo"); + assert_eq!(c1.value(2), "foobarfoobar"); + + assert!(c2.is_null(0)); + assert!(!c2.is_null(1)); + assert!(!c2.is_null(2)); + assert_eq!(c2.value(1), "something_cannot_be_inlined"); + assert_eq!(c2.value(2), "bar"); + } +} diff --git a/arrow-csv/src/reader/records.rs b/arrow-csv/src/reader/records.rs new file mode 100644 index 000000000000..a07fc9c94ffa --- /dev/null +++ b/arrow-csv/src/reader/records.rs @@ -0,0 +1,387 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::ArrowError; +use csv_core::{ReadRecordResult, Reader}; + +/// The estimated length of a field in bytes +const AVERAGE_FIELD_SIZE: usize = 8; + +/// The minimum amount of data in a single read +const MIN_CAPACITY: usize = 1024; + +/// [`RecordDecoder`] provides a push-based interface to decode [`StringRecords`] +#[derive(Debug)] +pub struct RecordDecoder { + delimiter: Reader, + + /// The expected number of fields per row + num_columns: usize, + + /// The current line number + line_number: usize, + + /// Offsets delimiting field start positions + offsets: Vec<usize>, + + /// The current offset into `self.offsets` + /// + /// We track this independently of Vec to avoid re-zeroing memory + offsets_len: usize, + + /// The number of fields read for the current record + current_field: usize, + + /// The number of rows buffered + num_rows: usize, + + /// Decoded field data + data: Vec<u8>, + + /// The current offset into `self.data` + /// + /// We track this independently of Vec to avoid re-zeroing memory + data_len: usize, + + /// Whether rows with fewer columns than expected are considered valid + /// + /// Default value is false + /// When enabled, missing columns are filled in with nulls + truncated_rows: bool, +} + +impl RecordDecoder { + pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self { + Self { + delimiter, + num_columns, + line_number: 1, + offsets: vec![], + offsets_len: 1, // The first offset is always 0 + current_field: 0, + data_len: 0, + data: vec![], + num_rows: 0, + truncated_rows, + } + } + + /// Decodes records from
`input` returning the number of records and bytes read + /// + /// Note: this expects to be called with an empty `input` to signal EOF + pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> { + if to_read == 0 { + return Ok((0, 0)); + } + + // Reserve sufficient capacity in offsets + self.offsets + .resize(self.offsets_len + to_read * self.num_columns, 0); + + // The current offset into `input` + let mut input_offset = 0; + + // The number of rows decoded in this pass + let mut read = 0; + + loop { + // Reserve necessary space in output data based on best estimate + let remaining_rows = to_read - read; + let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE; + let estimated_data = capacity.max(MIN_CAPACITY); + self.data.resize(self.data_len + estimated_data, 0); + + // Try to read a record + loop { + let (result, bytes_read, bytes_written, end_positions) = + self.delimiter.read_record( + &input[input_offset..], + &mut self.data[self.data_len..], + &mut self.offsets[self.offsets_len..], + ); + + self.current_field += end_positions; + self.offsets_len += end_positions; + input_offset += bytes_read; + self.data_len += bytes_written; + + match result { + ReadRecordResult::End | ReadRecordResult::InputEmpty => { + // Reached end of input + return Ok((read, input_offset)); + } + // Need to allocate more capacity + ReadRecordResult::OutputFull => break, + ReadRecordResult::OutputEndsFull => { + return Err(ArrowError::CsvError(format!( + "incorrect number of fields for line {}, expected {} got more than {}", + self.line_number, self.num_columns, self.current_field + ))); + } + ReadRecordResult::Record => { + if self.current_field != self.num_columns { + if self.truncated_rows && self.current_field < self.num_columns { + // If the number of fields is less than expected, pad with nulls + let fill_count = self.num_columns - self.current_field; + let fill_value = self.offsets[self.offsets_len - 1]; +
self.offsets[self.offsets_len..self.offsets_len + fill_count] + .fill(fill_value); + self.offsets_len += fill_count; + } else { + return Err(ArrowError::CsvError(format!( + "incorrect number of fields for line {}, expected {} got {}", + self.line_number, self.num_columns, self.current_field + ))); + } + } + read += 1; + self.current_field = 0; + self.line_number += 1; + self.num_rows += 1; + + if read == to_read { + // Read sufficient rows + return Ok((read, input_offset)); + } + + if input.len() == input_offset { + // Input exhausted, need to read more + // Without this read_record will interpret the empty input + // byte array as indicating the end of the file + return Ok((read, input_offset)); + } + } + } + } + } + } + + /// Returns the current number of buffered records + pub fn len(&self) -> usize { + self.num_rows + } + + /// Returns true if the decoder is empty + pub fn is_empty(&self) -> bool { + self.num_rows == 0 + } + + /// Clears the current contents of the decoder + pub fn clear(&mut self) { + // This does not reset current_field to allow clearing part way through a record + self.offsets_len = 1; + self.data_len = 0; + self.num_rows = 0; + } + + /// Flushes the current contents of the reader + pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> { + if self.current_field != 0 { + return Err(ArrowError::CsvError( + "Cannot flush part way through record".to_string(), + )); + } + + // csv_core::Reader writes end offsets relative to the start of the row + // Therefore scan through and offset these based on the cumulative row offsets + let mut row_offset = 0; + self.offsets[1..self.offsets_len] + .chunks_exact_mut(self.num_columns) + .for_each(|row| { + let offset = row_offset; + row.iter_mut().for_each(|x| { + *x += offset; + row_offset = *x; + }); + }); + + // Need to truncate data to the actual amount of data read + let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| { + let valid_up_to = e.valid_up_to(); + + // We can't use binary search
because of empty fields + let idx = self.offsets[..self.offsets_len] + .iter() + .rposition(|x| *x <= valid_up_to) + .unwrap(); + + let field = idx % self.num_columns + 1; + let line_offset = self.line_number - self.num_rows; + let line = line_offset + idx / self.num_columns; + + ArrowError::CsvError(format!( + "Encountered invalid UTF-8 data for line {line} and field {field}" + )) + })?; + + let offsets = &self.offsets[..self.offsets_len]; + let num_rows = self.num_rows; + + // Reset state + self.offsets_len = 1; + self.data_len = 0; + self.num_rows = 0; + + Ok(StringRecords { + num_rows, + num_columns: self.num_columns, + offsets, + data, + }) + } +} + +/// A collection of parsed, UTF-8 CSV records +#[derive(Debug)] +pub struct StringRecords<'a> { + num_columns: usize, + num_rows: usize, + offsets: &'a [usize], + data: &'a str, +} + +impl<'a> StringRecords<'a> { + fn get(&self, index: usize) -> StringRecord<'a> { + let field_idx = index * self.num_columns; + StringRecord { + data: self.data, + offsets: &self.offsets[field_idx..field_idx + self.num_columns + 1], + } + } + + pub fn len(&self) -> usize { + self.num_rows + } + + pub fn iter(&self) -> impl Iterator<Item = StringRecord<'a>> + '_ { + (0..self.num_rows).map(|x| self.get(x)) + } +} + +/// A single parsed, UTF-8 CSV record +#[derive(Debug, Clone, Copy)] +pub struct StringRecord<'a> { + data: &'a str, + offsets: &'a [usize], +} + +impl<'a> StringRecord<'a> { + pub fn get(&self, index: usize) -> &'a str { + let end = self.offsets[index + 1]; + let start = self.offsets[index]; + + // SAFETY: + // Parsing produces offsets at valid byte boundaries + unsafe { self.data.get_unchecked(start..end) } + } +} + +#[cfg(test)] +mod tests { + use crate::reader::records::RecordDecoder; + use csv_core::Reader; + use std::io::{BufRead, BufReader, Cursor}; + + #[test] + fn test_basic() { + let csv = [ + "foo,bar,baz", + "a,b,c", + "12,3,5", + "\"asda\"\"asas\",\"sdffsnsd\", as", + ] + .join("\n"); + + let mut expected = vec![ + vec!["foo", "bar",
"baz"], + vec!["a", "b", "c"], + vec!["12", "3", "5"], + vec!["asda\"asas", "sdffsnsd", " as"], + ] + .into_iter(); + + let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes())); + let mut decoder = RecordDecoder::new(Reader::new(), 3, false); + + loop { + let to_read = 3; + let mut read = 0; + loop { + let buf = reader.fill_buf().unwrap(); + let (records, bytes) = decoder.decode(buf, to_read - read).unwrap(); + + reader.consume(bytes); + read += records; + + if read == to_read || bytes == 0 { + break; + } + } + if read == 0 { + break; + } + + let b = decoder.flush().unwrap(); + b.iter().zip(&mut expected).for_each(|(record, expected)| { + let actual = (0..3) + .map(|field_idx| record.get(field_idx)) + .collect::>(); + assert_eq!(actual, expected) + }); + } + assert!(expected.next().is_none()); + } + + #[test] + fn test_invalid_fields() { + let csv = "a,b\nb,c\na\n"; + let mut decoder = RecordDecoder::new(Reader::new(), 2, false); + let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string(); + + let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1"; + + assert_eq!(err, expected); + + // Test with initial skip + let mut decoder = RecordDecoder::new(Reader::new(), 2, false); + let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap(); + assert_eq!(skipped, 1); + decoder.clear(); + + let remaining = &csv.as_bytes()[bytes..]; + let err = decoder.decode(remaining, 3).unwrap_err().to_string(); + assert_eq!(err, expected); + } + + #[test] + fn test_skip_insufficient_rows() { + let csv = "a\nv\n"; + let mut decoder = RecordDecoder::new(Reader::new(), 1, false); + let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap(); + assert_eq!(read, 2); + assert_eq!(bytes, csv.len()); + } + + #[test] + fn test_truncated_rows() { + let csv = "a,b\nv\n,1\n,2\n,3\n"; + let mut decoder = RecordDecoder::new(Reader::new(), 2, true); + let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap(); + assert_eq!(read, 5); 
+ assert_eq!(bytes, csv.len()); + } +} diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs new file mode 100644 index 000000000000..eae2133a4623 --- /dev/null +++ b/arrow-csv/src/writer.rs @@ -0,0 +1,866 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CSV Writer +//! +//! This CSV writer allows Arrow data (in record batches) to be written as CSV files. +//! The writer does not support writing `ListArray` and `StructArray`. +//! +//! Example: +//! +//! ``` +//! # use arrow_array::*; +//! # use arrow_array::types::*; +//! # use arrow_csv::Writer; +//! # use arrow_schema::*; +//! # use std::sync::Arc; +//! +//! let schema = Schema::new(vec![ +//! Field::new("c1", DataType::Utf8, false), +//! Field::new("c2", DataType::Float64, true), +//! Field::new("c3", DataType::UInt32, false), +//! Field::new("c4", DataType::Boolean, true), +//! ]); +//! let c1 = StringArray::from(vec![ +//! "Lorem ipsum dolor sit amet", +//! "consectetur adipiscing elit", +//! "sed do eiusmod tempor", +//! ]); +//! let c2 = PrimitiveArray::::from(vec![ +//! Some(123.564532), +//! None, +//! Some(-556132.25), +//! ]); +//! let c3 = PrimitiveArray::::from(vec![3, 2, 1]); +//! 
let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); +//! +//! let batch = RecordBatch::try_new( +//! Arc::new(schema), +//! vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], +//! ) +//! .unwrap(); +//! +//! let mut output = Vec::with_capacity(1024); +//! +//! let mut writer = Writer::new(&mut output); +//! let batches = vec![&batch, &batch]; +//! for batch in batches { +//! writer.write(batch).unwrap(); +//! } +//! ``` + +use arrow_array::*; +use arrow_cast::display::*; +use arrow_schema::*; +use csv::ByteRecord; +use std::io::Write; + +use crate::map_csv_error; +const DEFAULT_NULL_VALUE: &str = ""; + +/// A CSV writer +#[derive(Debug)] +pub struct Writer { + /// The object to write to + writer: csv::Writer, + /// Whether file should be written with headers, defaults to `true` + has_headers: bool, + /// The date format for date arrays, defaults to RFC3339 + date_format: Option, + /// The datetime format for datetime arrays, defaults to RFC3339 + datetime_format: Option, + /// The timestamp format for timestamp arrays, defaults to RFC3339 + timestamp_format: Option, + /// The timestamp format for timestamp (with timezone) arrays, defaults to RFC3339 + timestamp_tz_format: Option, + /// The time format for time arrays, defaults to RFC3339 + time_format: Option, + /// Is the beginning-of-writer + beginning: bool, + /// The value to represent null entries, defaults to [`DEFAULT_NULL_VALUE`] + null_value: Option, +} + +impl Writer { + /// Create a new CsvWriter from a writable object, with default options + pub fn new(writer: W) -> Self { + let delimiter = b','; + WriterBuilder::new().with_delimiter(delimiter).build(writer) + } + + /// Write a vector of record batches to a writable object + pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + let num_columns = batch.num_columns(); + if self.beginning { + if self.has_headers { + let mut headers: Vec = Vec::with_capacity(num_columns); + batch + .schema() + .fields() + 
.iter() + .for_each(|field| headers.push(field.name().to_string())); + self.writer + .write_record(&headers[..]) + .map_err(map_csv_error)?; + } + self.beginning = false; + } + + let options = FormatOptions::default() + .with_null(self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE)) + .with_date_format(self.date_format.as_deref()) + .with_datetime_format(self.datetime_format.as_deref()) + .with_timestamp_format(self.timestamp_format.as_deref()) + .with_timestamp_tz_format(self.timestamp_tz_format.as_deref()) + .with_time_format(self.time_format.as_deref()); + + let converters = batch + .columns() + .iter() + .map(|a| { + if a.data_type().is_nested() { + Err(ArrowError::CsvError(format!( + "Nested type {} is not supported in CSV", + a.data_type() + ))) + } else { + ArrayFormatter::try_new(a.as_ref(), &options) + } + }) + .collect::, ArrowError>>()?; + + let mut buffer = String::with_capacity(1024); + let mut byte_record = ByteRecord::with_capacity(1024, converters.len()); + + for row_idx in 0..batch.num_rows() { + byte_record.clear(); + for (col_idx, converter) in converters.iter().enumerate() { + buffer.clear(); + converter.value(row_idx).write(&mut buffer).map_err(|e| { + ArrowError::CsvError(format!( + "Error processing row {}, col {}: {e}", + row_idx + 1, + col_idx + 1 + )) + })?; + byte_record.push_field(buffer.as_bytes()); + } + + self.writer + .write_byte_record(&byte_record) + .map_err(map_csv_error)?; + } + self.writer.flush()?; + + Ok(()) + } + + /// Unwraps this `Writer`, returning the underlying writer. + pub fn into_inner(self) -> W { + // Safe to call `unwrap` since `write` always flushes the writer. + self.writer.into_inner().unwrap() + } +} + +impl RecordBatchWriter for Writer { + fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + self.write(batch) + } + + fn close(self) -> Result<(), ArrowError> { + Ok(()) + } +} + +/// A CSV writer builder +#[derive(Clone, Debug)] +pub struct WriterBuilder { + /// Optional column delimiter. 
Defaults to `b','` + delimiter: u8, + /// Whether to write column names as file headers. Defaults to `true` + has_header: bool, + /// Optional quote character. Defaults to `b'"'` + quote: u8, + /// Optional escape character. Defaults to `b'\\'` + escape: u8, + /// Enable double quote escapes. Defaults to `true` + double_quote: bool, + /// Optional date format for date arrays + date_format: Option, + /// Optional datetime format for datetime arrays + datetime_format: Option, + /// Optional timestamp format for timestamp arrays + timestamp_format: Option, + /// Optional timestamp format for timestamp with timezone arrays + timestamp_tz_format: Option, + /// Optional time format for time arrays + time_format: Option, + /// Optional value to represent null + null_value: Option, +} + +impl Default for WriterBuilder { + fn default() -> Self { + WriterBuilder { + delimiter: b',', + has_header: true, + quote: b'"', + escape: b'\\', + double_quote: true, + date_format: None, + datetime_format: None, + timestamp_format: None, + timestamp_tz_format: None, + time_format: None, + null_value: None, + } + } +} + +impl WriterBuilder { + /// Create a new builder for configuring CSV writing options. 
+ /// + /// To convert a builder into a writer, call `WriterBuilder::build` + /// + /// # Example + /// + /// ``` + /// # use arrow_csv::{Writer, WriterBuilder}; + /// # use std::fs::File; + /// + /// fn example() -> Writer { + /// let file = File::create("target/out.csv").unwrap(); + /// + /// // create a builder that doesn't write headers + /// let builder = WriterBuilder::new().with_header(false); + /// let writer = builder.build(file); + /// + /// writer + /// } + /// ``` + pub fn new() -> Self { + Self::default() + } + + /// Set whether to write headers + #[deprecated(note = "Use Self::with_header")] + #[doc(hidden)] + pub fn has_headers(mut self, has_headers: bool) -> Self { + self.has_header = has_headers; + self + } + + /// Set whether to write the CSV file with a header + pub fn with_header(mut self, header: bool) -> Self { + self.has_header = header; + self + } + + /// Returns `true` if this writer is configured to write a header + pub fn header(&self) -> bool { + self.has_header + } + + /// Set the CSV file's column delimiter as a byte character + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.delimiter = delimiter; + self + } + + /// Get the CSV file's column delimiter as a byte character + pub fn delimiter(&self) -> u8 { + self.delimiter + } + + /// Set the CSV file's quote character as a byte character + pub fn with_quote(mut self, quote: u8) -> Self { + self.quote = quote; + self + } + + /// Get the CSV file's quote character as a byte character + pub fn quote(&self) -> u8 { + self.quote + } + + /// Set the CSV file's escape character as a byte character + /// + /// In some variants of CSV, quotes are escaped using a special escape + /// character like `\` (instead of escaping quotes by doubling them). + /// + /// By default, writing these idiosyncratic escapes is disabled, and is + /// only used when `double_quote` is disabled. 
+ pub fn with_escape(mut self, escape: u8) -> Self { + self.escape = escape; + self + } + + /// Get the CSV file's escape character as a byte character + pub fn escape(&self) -> u8 { + self.escape + } + + /// Set whether to enable double quote escapes + /// + /// When enabled (which is the default), quotes are escaped by doubling + /// them. e.g., `"` escapes to `""`. + /// + /// When disabled, quotes are escaped with the escape character (which + /// is `\\` by default). + pub fn with_double_quote(mut self, double_quote: bool) -> Self { + self.double_quote = double_quote; + self + } + + /// Get whether double quote escapes are enabled + pub fn double_quote(&self) -> bool { + self.double_quote + } + + /// Set the CSV file's date format + pub fn with_date_format(mut self, format: String) -> Self { + self.date_format = Some(format); + self + } + + /// Get the CSV file's date format if set, defaults to RFC3339 + pub fn date_format(&self) -> Option<&str> { + self.date_format.as_deref() + } + + /// Set the CSV file's datetime format + pub fn with_datetime_format(mut self, format: String) -> Self { + self.datetime_format = Some(format); + self + } + + /// Get the CSV file's datetime format if set, defaults to RFC3339 + pub fn datetime_format(&self) -> Option<&str> { + self.datetime_format.as_deref() + } + + /// Set the CSV file's time format + pub fn with_time_format(mut self, format: String) -> Self { + self.time_format = Some(format); + self + } + + /// Get the CSV file's datetime time if set, defaults to RFC3339 + pub fn time_format(&self) -> Option<&str> { + self.time_format.as_deref() + } + + /// Set the CSV file's timestamp format + pub fn with_timestamp_format(mut self, format: String) -> Self { + self.timestamp_format = Some(format); + self + } + + /// Get the CSV file's timestamp format if set, defaults to RFC3339 + pub fn timestamp_format(&self) -> Option<&str> { + self.timestamp_format.as_deref() + } + + /// Set the CSV file's timestamp tz format + pub fn 
with_timestamp_tz_format(mut self, tz_format: String) -> Self { + self.timestamp_tz_format = Some(tz_format); + self + } + + /// Get the CSV file's timestamp tz format if set, defaults to RFC3339 + pub fn timestamp_tz_format(&self) -> Option<&str> { + self.timestamp_tz_format.as_deref() + } + + /// Set the value to represent null in output + pub fn with_null(mut self, null_value: String) -> Self { + self.null_value = Some(null_value); + self + } + + /// Get the value to represent null in output + pub fn null(&self) -> &str { + self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE) + } + + /// Use RFC3339 format for date/time/timestamps (default) + #[deprecated(note = "Use WriterBuilder::default()")] + pub fn with_rfc3339(mut self) -> Self { + self.date_format = None; + self.datetime_format = None; + self.time_format = None; + self.timestamp_format = None; + self.timestamp_tz_format = None; + self + } + + /// Create a new `Writer` + pub fn build(self, writer: W) -> Writer { + let mut builder = csv::WriterBuilder::new(); + let writer = builder + .delimiter(self.delimiter) + .quote(self.quote) + .double_quote(self.double_quote) + .escape(self.escape) + .from_writer(writer); + Writer { + writer, + beginning: true, + has_headers: self.has_header, + date_format: self.date_format, + datetime_format: self.datetime_format, + time_format: self.time_format, + timestamp_format: self.timestamp_format, + timestamp_tz_format: self.timestamp_tz_format, + null_value: self.null_value, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::ReaderBuilder; + use arrow_array::builder::{ + BinaryBuilder, Decimal128Builder, Decimal256Builder, FixedSizeBinaryBuilder, + LargeBinaryBuilder, + }; + use arrow_array::types::*; + use arrow_buffer::i256; + use core::str; + use std::io::{Cursor, Read, Seek}; + use std::sync::Arc; + + #[test] + fn test_write_csv() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, 
true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("c6", DataType::Time32(TimeUnit::Second), false), + Field::new_dictionary("c7", DataType::Int32, DataType::Utf8, false), + ]); + + let c1 = StringArray::from(vec![ + "Lorem ipsum dolor sit amet", + "consectetur adipiscing elit", + "sed do eiusmod tempor", + ]); + let c2 = + PrimitiveArray::::from(vec![Some(123.564532), None, Some(-556132.25)]); + let c3 = PrimitiveArray::::from(vec![3, 2, 1]); + let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); + let c5 = + TimestampMillisecondArray::from(vec![None, Some(1555584887378), Some(1555555555555)]); + let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); + let c7: DictionaryArray = + vec!["cupcakes", "cupcakes", "foo"].into_iter().collect(); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(c1), + Arc::new(c2), + Arc::new(c3), + Arc::new(c4), + Arc::new(c5), + Arc::new(c6), + Arc::new(c7), + ], + ) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let mut writer = Writer::new(&mut file); + let batches = vec![&batch, &batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + // check that file was written successfully + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + let expected = r#"c1,c2,c3,c4,c5,c6,c7 +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo +"#; + assert_eq!(expected, 
str::from_utf8(&buffer).unwrap()); + } + + #[test] + fn test_write_csv_decimal() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Decimal128(38, 6), true), + Field::new("c2", DataType::Decimal256(76, 6), true), + ]); + + let mut c1_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6)); + c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]); + let c1 = c1_builder.finish(); + + let mut c2_builder = Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6)); + c2_builder.extend(vec![ + Some(i256::from_i128(-3335724)), + Some(i256::from_i128(2179404)), + None, + Some(i256::from_i128(290472)), + ]); + let c2 = c2_builder.finish(); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let mut writer = Writer::new(&mut file); + let batches = vec![&batch, &batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + // check that file was written successfully + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + let expected = r#"c1,c2 +-3.335724,-3.335724 +2.179404,2.179404 +, +0.290472,0.290472 +-3.335724,-3.335724 +2.179404,2.179404 +, +0.290472,0.290472 +"#; + assert_eq!(expected, str::from_utf8(&buffer).unwrap()); + } + + #[test] + fn test_write_csv_custom_options() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + Field::new("c6", DataType::Time32(TimeUnit::Second), false), + ]); + + let c1 = StringArray::from(vec![ + "Lorem ipsum \ndolor sit amet", + "consectetur \"adipiscing\" elit", + "sed do eiusmod tempor", + ]); + let c2 = + PrimitiveArray::::from(vec![Some(123.564532), None, Some(-556132.25)]); + let c3 = PrimitiveArray::::from(vec![3, 2, 1]); + let c4 
= BooleanArray::from(vec![Some(true), Some(false), None]); + let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(c1), + Arc::new(c2), + Arc::new(c3), + Arc::new(c4), + Arc::new(c6), + ], + ) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let builder = WriterBuilder::new() + .with_header(false) + .with_delimiter(b'|') + .with_quote(b'\'') + .with_null("NULL".to_string()) + .with_time_format("%r".to_string()); + let mut writer = builder.build(&mut file); + let batches = vec![&batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + // check that file was written successfully + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + "'Lorem ipsum \ndolor sit amet'|123.564532|3|true|12:20:34 AM\nconsectetur \"adipiscing\" elit|NULL|2|false|06:51:20 AM\nsed do eiusmod tempor|-556132.25|1|NULL|11:46:03 PM\n" + .to_string(), + String::from_utf8(buffer).unwrap() + ); + + let mut file = tempfile::tempfile().unwrap(); + + let builder = WriterBuilder::new() + .with_header(true) + .with_double_quote(false) + .with_escape(b'$'); + let mut writer = builder.build(&mut file); + let batches = vec![&batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + "c1,c2,c3,c4,c6\n\"Lorem ipsum \ndolor sit amet\",123.564532,3,true,00:20:34\n\"consectetur $\"adipiscing$\" elit\",,2,false,06:51:20\nsed do eiusmod tempor,-556132.25,1,,23:46:03\n" + .to_string(), + String::from_utf8(buffer).unwrap() + ); + } + + #[test] + fn test_conversion_consistency() { + // test if we can serialize and deserialize whilst retaining the same type information/ precision + + let schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, false), + Field::new("c2", 
DataType::Date64, false), + Field::new("c3", DataType::Timestamp(TimeUnit::Nanosecond, None), false), + ]); + + let nanoseconds = vec![ + 1599566300000000000, + 1599566200000000000, + 1599566100000000000, + ]; + let c1 = Date32Array::from(vec![3, 2, 1]); + let c2 = Date64Array::from(vec![3, 2, 1]); + let c3 = TimestampNanosecondArray::from(nanoseconds.clone()); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)], + ) + .unwrap(); + + let builder = WriterBuilder::new().with_header(false); + + let mut buf: Cursor> = Default::default(); + // drop the writer early to release the borrow. + { + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + } + buf.set_position(0); + + let mut reader = ReaderBuilder::new(Arc::new(schema)) + .with_batch_size(3) + .build_buffered(buf) + .unwrap(); + + let rb = reader.next().unwrap().unwrap(); + let c1 = rb.column(0).as_any().downcast_ref::().unwrap(); + let c2 = rb.column(1).as_any().downcast_ref::().unwrap(); + let c3 = rb + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + let actual = c1.into_iter().collect::>(); + let expected = vec![Some(3), Some(2), Some(1)]; + assert_eq!(actual, expected); + let actual = c2.into_iter().collect::>(); + let expected = vec![Some(3), Some(2), Some(1)]; + assert_eq!(actual, expected); + let actual = c3.into_iter().collect::>(); + let expected = nanoseconds.into_iter().map(Some).collect::>(); + assert_eq!(actual, expected); + } + + #[test] + fn test_write_csv_invalid_cast() { + let schema = Schema::new(vec![ + Field::new("c0", DataType::UInt32, false), + Field::new("c1", DataType::Date64, false), + ]); + + let c0 = UInt32Array::from(vec![Some(123), Some(234)]); + let c1 = Date64Array::from(vec![Some(1926632005177), Some(1926632005177685347)]); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c0), Arc::new(c1)]).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + let mut 
writer = Writer::new(&mut file); + let batches = vec![&batch, &batch]; + + for batch in batches { + let err = writer.write(batch).unwrap_err().to_string(); + assert_eq!(err, "Csv error: Error processing row 2, col 2: Cast error: Failed to convert 1926632005177685347 to temporal for Date64") + } + drop(writer); + } + + #[test] + fn test_write_csv_using_rfc3339() { + let schema = Schema::new(vec![ + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("c3", DataType::Date32, false), + Field::new("c4", DataType::Time32(TimeUnit::Second), false), + ]); + + let c1 = TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]) + .with_timezone("+00:00".to_string()); + let c2 = TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]); + let c3 = Date32Array::from(vec![3, 2]); + let c4 = Time32SecondArray::from(vec![1234, 24680]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], + ) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let builder = WriterBuilder::new(); + let mut writer = builder.build(&mut file); + let batches = vec![&batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + "c1,c2,c3,c4 +2019-04-18T10:54:47.378Z,2019-04-18T10:54:47.378,1970-01-04,00:20:34 +2021-10-30T06:59:07Z,2021-10-30T06:59:07,1970-01-03,06:51:20\n", + String::from_utf8(buffer).unwrap() + ); + } + + #[test] + fn test_write_csv_tz_format() { + let schema = Schema::new(vec![ + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Millisecond, Some("+02:00".into())), + true, + ), + Field::new( + "c2", + DataType::Timestamp(TimeUnit::Second, Some("+04:00".into())), + true, + ), + ]); 
+ let c1 = TimestampMillisecondArray::from(vec![Some(1_000), Some(2_000)]) + .with_timezone("+02:00".to_string()); + let c2 = TimestampSecondArray::from(vec![Some(1_000_000), None]) + .with_timezone("+04:00".to_string()); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + let mut writer = WriterBuilder::new() + .with_timestamp_tz_format("%M:%H".to_string()) + .build(&mut file); + writer.write(&batch).unwrap(); + + drop(writer); + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + "c1,c2\n00:02,46:17\n00:02,\n", + String::from_utf8(buffer).unwrap() + ); + } + + #[test] + fn test_write_csv_binary() { + let fixed_size = 8; + let schema = SchemaRef::new(Schema::new(vec![ + Field::new("c1", DataType::Binary, true), + Field::new("c2", DataType::FixedSizeBinary(fixed_size), true), + Field::new("c3", DataType::LargeBinary, true), + ])); + let mut c1_builder = BinaryBuilder::new(); + c1_builder.append_value(b"Homer"); + c1_builder.append_value(b"Bart"); + c1_builder.append_null(); + c1_builder.append_value(b"Ned"); + let mut c2_builder = FixedSizeBinaryBuilder::new(fixed_size); + c2_builder.append_value(b"Simpson ").unwrap(); + c2_builder.append_value(b"Simpson ").unwrap(); + c2_builder.append_null(); + c2_builder.append_value(b"Flanders").unwrap(); + let mut c3_builder = LargeBinaryBuilder::new(); + c3_builder.append_null(); + c3_builder.append_null(); + c3_builder.append_value(b"Comic Book Guy"); + c3_builder.append_null(); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(c1_builder.finish()) as ArrayRef, + Arc::new(c2_builder.finish()) as ArrayRef, + Arc::new(c3_builder.finish()) as ArrayRef, + ], + ) + .unwrap(); + + let mut buf = Vec::new(); + let builder = WriterBuilder::new(); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "\ 
+ c1,c2,c3\n\ + 486f6d6572,53696d70736f6e20,\n\ + 42617274,53696d70736f6e20,\n\ + ,,436f6d696320426f6f6b20477579\n\ + 4e6564,466c616e64657273,\n\ + ", + String::from_utf8(buf).unwrap() + ); + } +} diff --git a/arrow-csv/test/data/custom_null_test.csv b/arrow-csv/test/data/custom_null_test.csv new file mode 100644 index 000000000000..39f9fc4b3eff --- /dev/null +++ b/arrow-csv/test/data/custom_null_test.csv @@ -0,0 +1,6 @@ +c_int,c_float,c_string,c_bool +1,1.1,"1.11",True +nil,2.2,"2.22",TRUE +3,nil,"3.33",true +4,4.4,nil,False +5,6.6,"",nil diff --git a/arrow/test/data/decimal_test.csv b/arrow-csv/test/data/decimal_test.csv similarity index 100% rename from arrow/test/data/decimal_test.csv rename to arrow-csv/test/data/decimal_test.csv diff --git a/arrow-csv/test/data/example.csv b/arrow-csv/test/data/example.csv new file mode 100644 index 000000000000..0c03cee84528 --- /dev/null +++ b/arrow-csv/test/data/example.csv @@ -0,0 +1,4 @@ +c1,c2,c3,c4 +1,1.1,"hong kong",true +3,323.12,"XiAn",false +10,131323.12,"cheng du",false \ No newline at end of file diff --git a/arrow-csv/test/data/init_null_test.csv b/arrow-csv/test/data/init_null_test.csv new file mode 100644 index 000000000000..f7d8a299645d --- /dev/null +++ b/arrow-csv/test/data/init_null_test.csv @@ -0,0 +1,6 @@ +c_int,c_float,c_string,c_bool,c_null +,,,, +2,2.2,"a",TRUE, +3,,"b",true, +4,4.4,,False, +5,6.6,"",FALSE, \ No newline at end of file diff --git a/arrow/test/data/null_test.csv b/arrow-csv/test/data/null_test.csv similarity index 100% rename from arrow/test/data/null_test.csv rename to arrow-csv/test/data/null_test.csv diff --git a/arrow-csv/test/data/scientific_notation_test.csv b/arrow-csv/test/data/scientific_notation_test.csv new file mode 100644 index 000000000000..632c3ef8bc51 --- /dev/null +++ b/arrow-csv/test/data/scientific_notation_test.csv @@ -0,0 +1,19 @@ +1.439e+04, positive_exponent +1.31e+04, positive_exponent +1.2711e+0, positive_exponent +1.44e+04, positive_exponent +2.22e+04, 
positive_exponent +1.149e+04, positive_exponent +2.139e+04, positive_exponent +7.322e+04, positive_exponent +1.531e+04, positive_exponent +2.206e-04, negative_exponent +1.517e-04, negative_exponent +2.332e-04, negative_exponent +2.19e-04, negative_exponent +2.087e-04, negative_exponent +12683.18, no_exponent +7134.6, no_exponent +8540.17, no_exponent +21462.36, no_exponent +1120.76, no_exponent \ No newline at end of file diff --git a/arrow-csv/test/data/truncated_rows.csv b/arrow-csv/test/data/truncated_rows.csv new file mode 100644 index 000000000000..0b2af5740095 --- /dev/null +++ b/arrow-csv/test/data/truncated_rows.csv @@ -0,0 +1,8 @@ +Name,Age,Occupation,DOB +A1,34,Engineer,1985-07-16 +B2,29,Doctor +, +C3,45,Artist,1964-11-30 + +D4 +E5,31,, diff --git a/arrow/test/data/uk_cities.csv b/arrow-csv/test/data/uk_cities.csv similarity index 100% rename from arrow/test/data/uk_cities.csv rename to arrow-csv/test/data/uk_cities.csv diff --git a/arrow/test/data/uk_cities_with_headers.csv b/arrow-csv/test/data/uk_cities_with_headers.csv similarity index 100% rename from arrow/test/data/uk_cities_with_headers.csv rename to arrow-csv/test/data/uk_cities_with_headers.csv diff --git a/arrow/test/data/various_types.csv b/arrow-csv/test/data/various_types.csv similarity index 100% rename from arrow/test/data/various_types.csv rename to arrow-csv/test/data/various_types.csv diff --git a/arrow/test/data/various_types_invalid.csv b/arrow-csv/test/data/various_types_invalid.csv similarity index 100% rename from arrow/test/data/various_types_invalid.csv rename to arrow-csv/test/data/various_types_invalid.csv diff --git a/arrow-data/Cargo.toml b/arrow-data/Cargo.toml new file mode 100644 index 000000000000..c83f867523d5 --- /dev/null +++ b/arrow-data/Cargo.toml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-data" +version = { workspace = true } +description = "Array data abstractions for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_data" +path = "src/lib.rs" +bench = false + +[features] +# force_validate runs full data validation for all arrays that are created +# this is not enabled by default as it is too computationally expensive +# but is run as part of our CI checks +force_validate = [] +# Enable ffi support +ffi = ["arrow-schema/ffi"] + +[package.metadata.docs.rs] +features = ["ffi"] + +[dependencies] + +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } + +num = { version = "0.4", default-features = false, features = ["std"] } +half = { version = "2.1", default-features = false } + +[dev-dependencies] + +[build-dependencies] diff --git a/arrow-data/src/byte_view.rs b/arrow-data/src/byte_view.rs new file mode 100644 index 000000000000..6f6d6d175689 --- /dev/null +++ b/arrow-data/src/byte_view.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or 
more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::Buffer; +use arrow_schema::ArrowError; + +/// Helper to access views of [`GenericByteViewArray`] (`StringViewArray` and +/// `BinaryViewArray`) where the length is greater than 12 bytes. +/// +/// See the documentation on [`GenericByteViewArray`] for more information on +/// the layout of the views. +/// +/// [`GenericByteViewArray`]: https://docs.rs/arrow/latest/arrow/array/struct.GenericByteViewArray.html +#[derive(Debug, Copy, Clone, Default)] +#[repr(C)] +pub struct ByteView { + /// The length of the string/bytes. + pub length: u32, + /// First 4 bytes of string/bytes data. + pub prefix: u32, + /// The buffer index. + pub buffer_index: u32, + /// The offset into the buffer. 
+ pub offset: u32, +} + +impl ByteView { + #[inline(always)] + /// Convert `ByteView` to `u128` by concatenating the fields + pub fn as_u128(self) -> u128 { + (self.length as u128) + | ((self.prefix as u128) << 32) + | ((self.buffer_index as u128) << 64) + | ((self.offset as u128) << 96) + } +} + +impl From for ByteView { + #[inline] + fn from(value: u128) -> Self { + Self { + length: value as u32, + prefix: (value >> 32) as u32, + buffer_index: (value >> 64) as u32, + offset: (value >> 96) as u32, + } + } +} + +impl From for u128 { + #[inline] + fn from(value: ByteView) -> Self { + value.as_u128() + } +} + +/// Validates the combination of `views` and `buffers` is a valid BinaryView +pub fn validate_binary_view(views: &[u128], buffers: &[Buffer]) -> Result<(), ArrowError> { + validate_view_impl(views, buffers, |_, _| Ok(())) +} + +/// Validates the combination of `views` and `buffers` is a valid StringView +pub fn validate_string_view(views: &[u128], buffers: &[Buffer]) -> Result<(), ArrowError> { + validate_view_impl(views, buffers, |idx, b| { + std::str::from_utf8(b).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Encountered non-UTF-8 data at index {idx}: {e}" + )) + })?; + Ok(()) + }) +} + +fn validate_view_impl(views: &[u128], buffers: &[Buffer], f: F) -> Result<(), ArrowError> +where + F: Fn(usize, &[u8]) -> Result<(), ArrowError>, +{ + for (idx, v) in views.iter().enumerate() { + let len = *v as u32; + if len <= 12 { + if len < 12 && (v >> (32 + len * 8)) != 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "View at index {idx} contained non-zero padding for string of length {len}", + ))); + } + f(idx, &v.to_le_bytes()[4..4 + len as usize])?; + } else { + let view = ByteView::from(*v); + let data = buffers.get(view.buffer_index as usize).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Invalid buffer index at {idx}: got index {} but only has {} buffers", + view.buffer_index, + buffers.len() + )) + })?; + + let start = 
view.offset as usize; + let end = start + len as usize; + let b = data.get(start..end).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Invalid buffer slice at {idx}: got {start}..{end} but buffer {} has length {}", + view.buffer_index, + data.len() + )) + })?; + + if !b.starts_with(&view.prefix.to_le_bytes()) { + return Err(ArrowError::InvalidArgumentError( + "Mismatch between embedded prefix and data".to_string(), + )); + } + + f(idx, b)?; + } + } + Ok(()) +} diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs new file mode 100644 index 000000000000..8c9e002e219b --- /dev/null +++ b/arrow-data/src/data.rs @@ -0,0 +1,2254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains [`ArrayData`], a generic representation of Arrow array data which encapsulates +//! common attributes and operations for Arrow array. 
+ +use crate::bit_iterator::BitSliceIterator; +use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; +use arrow_buffer::{ + bit_util, i256, ArrowNativeType, Buffer, IntervalDayTime, IntervalMonthDayNano, MutableBuffer, +}; +use arrow_schema::{ArrowError, DataType, UnionMode}; +use std::mem; +use std::ops::Range; +use std::sync::Arc; + +use crate::{equal, validate_binary_view, validate_string_view}; + +/// A collection of [`Buffer`] +#[doc(hidden)] +#[deprecated(note = "Use [Buffer]")] +pub type Buffers<'a> = &'a [Buffer]; + +#[inline] +pub(crate) fn contains_nulls( + null_bit_buffer: Option<&NullBuffer>, + offset: usize, + len: usize, +) -> bool { + match null_bit_buffer { + Some(buffer) => { + match BitSliceIterator::new(buffer.validity(), buffer.offset() + offset, len).next() { + Some((start, end)) => start != 0 || end != len, + None => len != 0, // No non-null values + } + } + None => false, // No null buffer + } +} + +#[inline] +pub(crate) fn count_nulls( + null_bit_buffer: Option<&NullBuffer>, + offset: usize, + len: usize, +) -> usize { + if let Some(buf) = null_bit_buffer { + let buffer = buf.buffer(); + len - buffer.count_set_bits_offset(offset + buf.offset(), len) + } else { + 0 + } +} + +/// creates 2 [`MutableBuffer`]s with a given `capacity` (in slots). 
+#[inline] +pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuffer; 2] { + let empty_buffer = MutableBuffer::new(0); + match data_type { + DataType::Null => [empty_buffer, MutableBuffer::new(0)], + DataType::Boolean => { + let bytes = bit_util::ceil(capacity, 8); + let buffer = MutableBuffer::new(bytes); + [buffer, empty_buffer] + } + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Date64 + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::Interval(_) => [ + MutableBuffer::new(capacity * data_type.primitive_width().unwrap()), + empty_buffer, + ], + DataType::Utf8 | DataType::Binary => { + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element + buffer.push(0i32); + [buffer, MutableBuffer::new(capacity * mem::size_of::())] + } + DataType::LargeUtf8 | DataType::LargeBinary => { + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element + buffer.push(0i64); + [buffer, MutableBuffer::new(capacity * mem::size_of::())] + } + DataType::BinaryView | DataType::Utf8View => [ + MutableBuffer::new(capacity * mem::size_of::()), + empty_buffer, + ], + DataType::List(_) | DataType::Map(_, _) => { + // offset buffer always starts with a zero + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + buffer.push(0i32); + [buffer, empty_buffer] + } + DataType::ListView(_) => [ + MutableBuffer::new(capacity * mem::size_of::()), + MutableBuffer::new(capacity * mem::size_of::()), + ], + DataType::LargeList(_) => { + // offset buffer always starts with a zero + let mut 
buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + buffer.push(0i64); + [buffer, empty_buffer] + } + DataType::LargeListView(_) => [ + MutableBuffer::new(capacity * mem::size_of::()), + MutableBuffer::new(capacity * mem::size_of::()), + ], + DataType::FixedSizeBinary(size) => { + [MutableBuffer::new(capacity * *size as usize), empty_buffer] + } + DataType::Dictionary(k, _) => [ + MutableBuffer::new(capacity * k.primitive_width().unwrap()), + empty_buffer, + ], + DataType::FixedSizeList(_, _) | DataType::Struct(_) | DataType::RunEndEncoded(_, _) => { + [empty_buffer, MutableBuffer::new(0)] + } + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => [ + MutableBuffer::new(capacity * mem::size_of::()), + empty_buffer, + ], + DataType::Union(_, mode) => { + let type_ids = MutableBuffer::new(capacity * mem::size_of::()); + match mode { + UnionMode::Sparse => [type_ids, empty_buffer], + UnionMode::Dense => { + let offsets = MutableBuffer::new(capacity * mem::size_of::()); + [type_ids, offsets] + } + } + } + } +} + +/// A generic representation of Arrow array data which encapsulates common attributes +/// and operations for Arrow array. +/// +/// Specific operations for different arrays types (e.g., primitive, list, struct) +/// are implemented in `Array`. +/// +/// # Memory Layout +/// +/// `ArrayData` has references to one or more underlying data buffers +/// and optional child ArrayData, depending on type as illustrated +/// below. Bitmaps are not shown for simplicity but they are stored +/// similarly to the buffers. +/// +/// ```text +/// offset +/// points to +/// ┌───────────────────┐ start of ┌───────┐ Different +/// │ │ data │ │ ArrayData may +/// │ArrayData { │ │.... │ also refers to +/// │ data_type: ... │ ─ ─ ─ ─▶│1234 │ ┌ ─ the same +/// │ offset: ... ─ ─ ─│─ ┘ │4372 │ underlying +/// │ len: ... ─ ─ ─│─ ┐ │4888 │ │ buffer with different offset/len +/// │ buffers: [ │ │5882 │◀─ +/// │ ... 
│ │ │4323 │ +/// │ ] │ ─ ─ ─ ─▶│4859 │ +/// │ child_data: [ │ │.... │ +/// │ ... │ │ │ +/// │ ] │ └───────┘ +/// │} │ +/// │ │ Shared Buffer uses +/// │ │ │ bytes::Bytes to hold +/// └───────────────────┘ actual data values +/// ┌ ─ ─ ┘ +/// +/// ▼ +/// ┌───────────────────┐ +/// │ArrayData { │ +/// │ ... │ +/// │} │ +/// │ │ +/// └───────────────────┘ +/// +/// Child ArrayData may also have its own buffers and children +/// ``` + +#[derive(Debug, Clone)] +pub struct ArrayData { + /// The data type for this array data + data_type: DataType, + + /// The number of elements in this array data + len: usize, + + /// The offset into this array data, in number of items + offset: usize, + + /// The buffers for this array data. Note that depending on the array types, this + /// could hold different kinds of buffers (e.g., value buffer, value offset buffer) + /// at different positions. + buffers: Vec, + + /// The child(ren) of this array. Only non-empty for nested types, currently + /// `ListArray` and `StructArray`. + child_data: Vec, + + /// The null bitmap. A `None` value for this indicates all values are non-null in + /// this array. + nulls: Option, +} + +/// A thread-safe, shared reference to the Arrow array data. +pub type ArrayDataRef = Arc; + +impl ArrayData { + /// Create a new ArrayData instance; + /// + /// If `null_count` is not specified, the number of nulls in + /// null_bit_buffer is calculated. + /// + /// If the number of nulls is 0 then the null_bit_buffer + /// is set to `None`. + /// + /// # Safety + /// + /// The input values *must* form a valid Arrow array for + /// `data_type`, or undefined behavior can result. + /// + /// Note: This is a low level API and most users of the arrow + /// crate should create arrays using the methods in the `array` + /// module. 
+ pub unsafe fn new_unchecked( + data_type: DataType, + len: usize, + null_count: Option, + null_bit_buffer: Option, + offset: usize, + buffers: Vec, + child_data: Vec, + ) -> Self { + ArrayDataBuilder { + data_type, + len, + null_count, + null_bit_buffer, + nulls: None, + offset, + buffers, + child_data, + } + .build_unchecked() + } + + /// Create a new ArrayData, validating that the provided buffers form a valid + /// Arrow array of the specified data type. + /// + /// If the number of nulls in `null_bit_buffer` is 0 then the null_bit_buffer + /// is set to `None`. + /// + /// Internally this calls through to [`Self::validate_data`] + /// + /// Note: This is a low level API and most users of the arrow crate should create + /// arrays using the builders found in [arrow_array](https://docs.rs/arrow-array) + pub fn try_new( + data_type: DataType, + len: usize, + null_bit_buffer: Option, + offset: usize, + buffers: Vec, + child_data: Vec, + ) -> Result { + // we must check the length of `null_bit_buffer` first + // because we use this buffer to calculate `null_count` + // in `Self::new_unchecked`. + if let Some(null_bit_buffer) = null_bit_buffer.as_ref() { + let needed_len = bit_util::ceil(len + offset, 8); + if null_bit_buffer.len() < needed_len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_bit_buffer size too small. 
got {} needed {}", + null_bit_buffer.len(), + needed_len + ))); + } + } + // Safety justification: `validate_full` is called below + let new_self = unsafe { + Self::new_unchecked( + data_type, + len, + None, + null_bit_buffer, + offset, + buffers, + child_data, + ) + }; + + // As the data is not trusted, do a full validation of its contents + // We don't need to validate children as we can assume that the + // [`ArrayData`] in `child_data` have already been validated through + // a call to `ArrayData::try_new` or created using unsafe + new_self.validate_data()?; + Ok(new_self) + } + + /// Returns a builder to construct a [`ArrayData`] instance of the same [`DataType`] + #[inline] + pub const fn builder(data_type: DataType) -> ArrayDataBuilder { + ArrayDataBuilder::new(data_type) + } + + /// Returns a reference to the [`DataType`] of this [`ArrayData`] + #[inline] + pub const fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the [`Buffer`] storing data for this [`ArrayData`] + pub fn buffers(&self) -> &[Buffer] { + &self.buffers + } + + /// Returns a slice of children [`ArrayData`]. This will be non + /// empty for type such as lists and structs. + pub fn child_data(&self) -> &[ArrayData] { + &self.child_data[..] + } + + /// Returns whether the element at index `i` is null + #[inline] + pub fn is_null(&self, i: usize) -> bool { + match &self.nulls { + Some(v) => v.is_null(i), + None => false, + } + } + + /// Returns a reference to the null buffer of this [`ArrayData`] if any + /// + /// Note: [`ArrayData::offset`] does NOT apply to the returned [`NullBuffer`] + #[inline] + pub fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + /// Returns whether the element at index `i` is not null + #[inline] + pub fn is_valid(&self, i: usize) -> bool { + !self.is_null(i) + } + + /// Returns the length (i.e., number of elements) of this [`ArrayData`]. 
+ #[inline] + pub const fn len(&self) -> usize { + self.len + } + + /// Returns whether this [`ArrayData`] is empty + #[inline] + pub const fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the offset of this [`ArrayData`] + #[inline] + pub const fn offset(&self) -> usize { + self.offset + } + + /// Returns the total number of nulls in this array + #[inline] + pub fn null_count(&self) -> usize { + self.nulls + .as_ref() + .map(|x| x.null_count()) + .unwrap_or_default() + } + + /// Returns the total number of bytes of memory occupied by the + /// buffers owned by this [`ArrayData`] and all of its + /// children. (See also diagram on [`ArrayData`]). + /// + /// Note that this [`ArrayData`] may only refer to a subset of the + /// data in the underlying [`Buffer`]s (due to `offset` and + /// `length`), but the size returned includes the entire size of + /// the buffers. + /// + /// If multiple [`ArrayData`]s refer to the same underlying + /// [`Buffer`]s they will both report the same size. + pub fn get_buffer_memory_size(&self) -> usize { + let mut size = 0; + for buffer in &self.buffers { + size += buffer.capacity(); + } + if let Some(bitmap) = &self.nulls { + size += bitmap.buffer().capacity() + } + for child in &self.child_data { + size += child.get_buffer_memory_size(); + } + size + } + + /// Returns the total number of the bytes of memory occupied by + /// the buffers by this slice of [`ArrayData`] (See also diagram on [`ArrayData`]). + /// + /// This is approximately the number of bytes if a new + /// [`ArrayData`] was formed by creating new [`Buffer`]s with + /// exactly the data needed. + /// + /// For example, a [`DataType::Int64`] with `100` elements, + /// [`Self::get_slice_memory_size`] would return `100 * 8 = 800`. If + /// the [`ArrayData`] was then [`Self::slice`]ed to refer to its + /// first `20` elements, then [`Self::get_slice_memory_size`] on the + /// sliced [`ArrayData`] would return `20 * 8 = 160`. 
+ pub fn get_slice_memory_size(&self) -> Result { + let mut result: usize = 0; + let layout = layout(&self.data_type); + + for spec in layout.buffers.iter() { + match spec { + BufferSpec::FixedWidth { byte_width, .. } => { + let buffer_size = self.len.checked_mul(*byte_width).ok_or_else(|| { + ArrowError::ComputeError( + "Integer overflow computing buffer size".to_string(), + ) + })?; + result += buffer_size; + } + BufferSpec::VariableWidth => { + let buffer_len: usize; + match self.data_type { + DataType::Utf8 | DataType::Binary => { + let offsets = self.typed_offsets::()?; + buffer_len = (offsets[self.len] - offsets[0] ) as usize; + } + DataType::LargeUtf8 | DataType::LargeBinary => { + let offsets = self.typed_offsets::()?; + buffer_len = (offsets[self.len] - offsets[0]) as usize; + } + _ => { + return Err(ArrowError::NotYetImplemented(format!( + "Invalid data type for VariableWidth buffer. Expected Utf8, LargeUtf8, Binary or LargeBinary. Got {}", + self.data_type + ))) + } + }; + result += buffer_len; + } + BufferSpec::BitMap => { + let buffer_size = bit_util::ceil(self.len, 8); + result += buffer_size; + } + BufferSpec::AlwaysNull => { + // Nothing to do + } + } + } + + if self.nulls().is_some() { + result += bit_util::ceil(self.len, 8); + } + + for child in &self.child_data { + result += child.get_slice_memory_size()?; + } + Ok(result) + } + + /// Returns the total number of bytes of memory occupied + /// physically by this [`ArrayData`] and all its [`Buffer`]s and + /// children. (See also diagram on [`ArrayData`]). 
+ /// + /// Equivalent to: + /// `size_of_val(self)` + + /// [`Self::get_buffer_memory_size`] + + /// `size_of_val(child)` for all children + pub fn get_array_memory_size(&self) -> usize { + let mut size = mem::size_of_val(self); + + // Calculate rest of the fields top down which contain actual data + for buffer in &self.buffers { + size += mem::size_of::(); + size += buffer.capacity(); + } + if let Some(nulls) = &self.nulls { + size += nulls.buffer().capacity(); + } + for child in &self.child_data { + size += child.get_array_memory_size(); + } + + size + } + + /// Creates a zero-copy slice of itself. This creates a new + /// [`ArrayData`] pointing at the same underlying [`Buffer`]s with a + /// different offset and len + /// + /// # Panics + /// + /// Panics if `offset + length > self.len()`. + pub fn slice(&self, offset: usize, length: usize) -> ArrayData { + assert!((offset + length) <= self.len()); + + if let DataType::Struct(_) = self.data_type() { + // Slice into children + let new_offset = self.offset + offset; + let new_data = ArrayData { + data_type: self.data_type().clone(), + len: length, + offset: new_offset, + buffers: self.buffers.clone(), + // Slice child data, to propagate offsets down to them + child_data: self + .child_data() + .iter() + .map(|data| data.slice(offset, length)) + .collect(), + nulls: self.nulls.as_ref().map(|x| x.slice(offset, length)), + }; + + new_data + } else { + let mut new_data = self.clone(); + + new_data.len = length; + new_data.offset = offset + self.offset; + new_data.nulls = self.nulls.as_ref().map(|x| x.slice(offset, length)); + + new_data + } + } + + /// Returns the `buffer` as a slice of type `T` starting at self.offset + /// # Panics + /// This function panics if: + /// * the buffer is not byte-aligned with type T, or + /// * the datatype is `Boolean` (it corresponds to a bit-packed buffer where the offset is not applicable) + pub fn buffer(&self, buffer: usize) -> &[T] { + 
&self.buffers()[buffer].typed_data()[self.offset..] + } + + /// Returns a new [`ArrayData`] valid for `data_type` containing `len` null values + pub fn new_null(data_type: &DataType, len: usize) -> Self { + let bit_len = bit_util::ceil(len, 8); + let zeroed = |len: usize| Buffer::from(MutableBuffer::from_len_zeroed(len)); + + let (buffers, child_data, has_nulls) = match data_type.primitive_width() { + Some(width) => (vec![zeroed(width * len)], vec![], true), + None => match data_type { + DataType::Null => (vec![], vec![], false), + DataType::Boolean => (vec![zeroed(bit_len)], vec![], true), + DataType::Binary | DataType::Utf8 => { + (vec![zeroed((len + 1) * 4), zeroed(0)], vec![], true) + } + DataType::BinaryView | DataType::Utf8View => (vec![zeroed(len * 16)], vec![], true), + DataType::LargeBinary | DataType::LargeUtf8 => { + (vec![zeroed((len + 1) * 8), zeroed(0)], vec![], true) + } + DataType::FixedSizeBinary(i) => (vec![zeroed(*i as usize * len)], vec![], true), + DataType::List(f) | DataType::Map(f, _) => ( + vec![zeroed((len + 1) * 4)], + vec![ArrayData::new_empty(f.data_type())], + true, + ), + DataType::LargeList(f) => ( + vec![zeroed((len + 1) * 8)], + vec![ArrayData::new_empty(f.data_type())], + true, + ), + DataType::FixedSizeList(f, list_len) => ( + vec![], + vec![ArrayData::new_null(f.data_type(), *list_len as usize * len)], + true, + ), + DataType::Struct(fields) => ( + vec![], + fields + .iter() + .map(|f| Self::new_null(f.data_type(), len)) + .collect(), + true, + ), + DataType::Dictionary(k, v) => ( + vec![zeroed(k.primitive_width().unwrap() * len)], + vec![ArrayData::new_empty(v.as_ref())], + true, + ), + DataType::Union(f, mode) => { + let (id, _) = f.iter().next().unwrap(); + let ids = Buffer::from_iter(std::iter::repeat(id).take(len)); + let buffers = match mode { + UnionMode::Sparse => vec![ids], + UnionMode::Dense => { + let end_offset = i32::from_usize(len).unwrap(); + vec![ids, Buffer::from_iter(0_i32..end_offset)] + } + }; + + let 
children = f + .iter() + .enumerate() + .map(|(idx, (_, f))| { + if idx == 0 || *mode == UnionMode::Sparse { + Self::new_null(f.data_type(), len) + } else { + Self::new_empty(f.data_type()) + } + }) + .collect(); + + (buffers, children, false) + } + DataType::RunEndEncoded(r, v) => { + let runs = match r.data_type() { + DataType::Int16 => { + let i = i16::from_usize(len).expect("run overflow"); + Buffer::from_slice_ref([i]) + } + DataType::Int32 => { + let i = i32::from_usize(len).expect("run overflow"); + Buffer::from_slice_ref([i]) + } + DataType::Int64 => { + let i = i64::from_usize(len).expect("run overflow"); + Buffer::from_slice_ref([i]) + } + dt => unreachable!("Invalid run ends data type {dt}"), + }; + + let builder = ArrayData::builder(r.data_type().clone()) + .len(1) + .buffers(vec![runs]); + + // SAFETY: + // Valid by construction + let runs = unsafe { builder.build_unchecked() }; + ( + vec![], + vec![runs, ArrayData::new_null(v.data_type(), 1)], + false, + ) + } + d => unreachable!("{d}"), + }, + }; + + let mut builder = ArrayDataBuilder::new(data_type.clone()) + .len(len) + .buffers(buffers) + .child_data(child_data); + + if has_nulls { + builder = builder.nulls(Some(NullBuffer::new_null(len))) + } + + // SAFETY: + // Data valid by construction + unsafe { builder.build_unchecked() } + } + + /// Returns a new empty [ArrayData] valid for `data_type`. + pub fn new_empty(data_type: &DataType) -> Self { + Self::new_null(data_type, 0) + } + + /// Verifies that the buffers meet the minimum alignment requirements for the data type + /// + /// Buffers that are not adequately aligned will be copied to a new aligned allocation + /// + /// This can be useful for when interacting with data sent over IPC or FFI, that may + /// not meet the minimum alignment requirements + pub fn align_buffers(&mut self) { + let layout = layout(&self.data_type); + for (buffer, spec) in self.buffers.iter_mut().zip(&layout.buffers) { + if let BufferSpec::FixedWidth { alignment, .. 
} = spec { + if buffer.as_ptr().align_offset(*alignment) != 0 { + *buffer = Buffer::from_slice_ref(buffer.as_ref()) + } + } + } + } + + /// "cheap" validation of an `ArrayData`. Ensures buffers are + /// sufficiently sized to store `len` + `offset` total elements of + /// `data_type` and performs other inexpensive consistency checks. + /// + /// This check is "cheap" in the sense that it does not validate the + /// contents of the buffers (e.g. that all offsets for UTF8 arrays + /// are within the bounds of the values buffer). + /// + /// See [ArrayData::validate_data] to validate fully the offset content + /// and the validity of utf8 data + pub fn validate(&self) -> Result<(), ArrowError> { + // Need at least this mich space in each buffer + let len_plus_offset = self.len + self.offset; + + // Check that the data layout conforms to the spec + let layout = layout(&self.data_type); + + if !layout.can_contain_null_mask && self.nulls.is_some() { + return Err(ArrowError::InvalidArgumentError(format!( + "Arrays of type {:?} cannot contain a null bitmask", + self.data_type, + ))); + } + + // Check data buffers length for view types and other types + if self.buffers.len() < layout.buffers.len() + || (!layout.variadic && self.buffers.len() != layout.buffers.len()) + { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected {} buffers in array of type {:?}, got {}", + layout.buffers.len(), + self.data_type, + self.buffers.len(), + ))); + } + + for (i, (buffer, spec)) in self.buffers.iter().zip(layout.buffers.iter()).enumerate() { + match spec { + BufferSpec::FixedWidth { + byte_width, + alignment, + } => { + let min_buffer_size = len_plus_offset.saturating_mul(*byte_width); + + if buffer.len() < min_buffer_size { + return Err(ArrowError::InvalidArgumentError(format!( + "Need at least {} bytes in buffers[{}] in array of type {:?}, but got {}", + min_buffer_size, i, self.data_type, buffer.len() + ))); + } + + let align_offset = 
buffer.as_ptr().align_offset(*alignment); + if align_offset != 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Misaligned buffers[{i}] in array of type {:?}, offset from expected alignment of {alignment} by {}", + self.data_type, align_offset.min(alignment - align_offset) + ))); + } + } + BufferSpec::VariableWidth => { + // not cheap to validate (need to look at the + // data). Partially checked in validate_offsets + // called below. Can check with `validate_full` + } + BufferSpec::BitMap => { + let min_buffer_size = bit_util::ceil(len_plus_offset, 8); + if buffer.len() < min_buffer_size { + return Err(ArrowError::InvalidArgumentError(format!( + "Need at least {} bytes for bitmap in buffers[{}] in array of type {:?}, but got {}", + min_buffer_size, i, self.data_type, buffer.len() + ))); + } + } + BufferSpec::AlwaysNull => { + // Nothing to validate + } + } + } + + // check null bit buffer size + if let Some(nulls) = self.nulls() { + if nulls.null_count() > self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count {} for an array exceeds length of {} elements", + nulls.null_count(), + self.len + ))); + } + + let actual_len = nulls.validity().len(); + let needed_len = bit_util::ceil(len_plus_offset, 8); + if actual_len < needed_len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_bit_buffer size too small. got {actual_len} needed {needed_len}", + ))); + } + + if nulls.len() != self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "null buffer incorrect size. 
got {} expected {}", + nulls.len(), + self.len + ))); + } + } + + self.validate_child_data()?; + + // Additional Type specific checks + match &self.data_type { + DataType::Utf8 | DataType::Binary => { + self.validate_offsets::(self.buffers[1].len())?; + } + DataType::LargeUtf8 | DataType::LargeBinary => { + self.validate_offsets::(self.buffers[1].len())?; + } + DataType::Dictionary(key_type, _value_type) => { + // At the moment, constructing a DictionaryArray will also check this + if !DataType::is_dictionary_key_type(key_type) { + return Err(ArrowError::InvalidArgumentError(format!( + "Dictionary key type must be integer, but was {key_type}" + ))); + } + } + DataType::RunEndEncoded(run_ends_type, _) => { + if run_ends_type.is_nullable() { + return Err(ArrowError::InvalidArgumentError( + "The nullable should be set to false for the field defining run_ends array.".to_string() + )); + } + if !DataType::is_run_ends_type(run_ends_type.data_type()) { + return Err(ArrowError::InvalidArgumentError(format!( + "RunArray run_ends types must be Int16, Int32 or Int64, but was {}", + run_ends_type.data_type() + ))); + } + } + _ => {} + }; + + Ok(()) + } + + /// Returns a reference to the data in `buffer` as a typed slice + /// (typically `&[i32]` or `&[i64]`) after validating. The + /// returned slice is guaranteed to have at least `self.len + 1` + /// entries. + /// + /// For an empty array, the `buffer` can also be empty. 
+ fn typed_offsets(&self) -> Result<&[T], ArrowError> { + // An empty list-like array can have 0 offsets + if self.len == 0 && self.buffers[0].is_empty() { + return Ok(&[]); + } + + self.typed_buffer(0, self.len + 1) + } + + /// Returns a reference to the data in `buffers[idx]` as a typed slice after validating + fn typed_buffer( + &self, + idx: usize, + len: usize, + ) -> Result<&[T], ArrowError> { + let buffer = &self.buffers[idx]; + + let required_len = (len + self.offset) * mem::size_of::(); + + if buffer.len() < required_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Buffer {} of {} isn't large enough. Expected {} bytes got {}", + idx, + self.data_type, + required_len, + buffer.len() + ))); + } + + Ok(&buffer.typed_data::()[self.offset..self.offset + len]) + } + + /// Does a cheap sanity check that the `self.len` values in `buffer` are valid + /// offsets (of type T) into some other buffer of `values_length` bytes long + fn validate_offsets( + &self, + values_length: usize, + ) -> Result<(), ArrowError> { + // Justification: buffer size was validated above + let offsets = self.typed_offsets::()?; + if offsets.is_empty() { + return Ok(()); + } + + let first_offset = offsets[0].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[0] ({}) to usize for {}", + offsets[0], self.data_type + )) + })?; + + let last_offset = offsets[self.len].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[{}] ({}) to usize for {}", + self.len, offsets[self.len], self.data_type + )) + })?; + + if first_offset > values_length { + return Err(ArrowError::InvalidArgumentError(format!( + "First offset {} of {} is larger than values length {}", + first_offset, self.data_type, values_length, + ))); + } + + if last_offset > values_length { + return Err(ArrowError::InvalidArgumentError(format!( + "Last offset {} of {} is larger than values length {}", + last_offset, self.data_type, 
values_length, + ))); + } + + if first_offset > last_offset { + return Err(ArrowError::InvalidArgumentError(format!( + "First offset {} in {} is smaller than last offset {}", + first_offset, self.data_type, last_offset, + ))); + } + + Ok(()) + } + + /// Does a cheap sanity check that the `self.len` values in `buffer` are valid + /// offsets and sizes (of type T) into some other buffer of `values_length` bytes long + fn validate_offsets_and_sizes( + &self, + values_length: usize, + ) -> Result<(), ArrowError> { + let offsets: &[T] = self.typed_buffer(0, self.len)?; + let sizes: &[T] = self.typed_buffer(1, self.len)?; + for i in 0..values_length { + let size = sizes[i].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting size[{}] ({}) to usize for {}", + i, sizes[i], self.data_type + )) + })?; + let offset = offsets[i].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[{}] ({}) to usize for {}", + i, offsets[i], self.data_type + )) + })?; + if size + .checked_add(offset) + .expect("Offset and size have exceeded the usize boundary") + > values_length + { + return Err(ArrowError::InvalidArgumentError(format!( + "Size {} at index {} is larger than the remaining values for {}", + size, i, self.data_type + ))); + } + } + Ok(()) + } + + /// Validates the layout of `child_data` ArrayData structures + fn validate_child_data(&self) -> Result<(), ArrowError> { + match &self.data_type { + DataType::List(field) | DataType::Map(field, _) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets::(values_data.len)?; + Ok(()) + } + DataType::LargeList(field) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets::(values_data.len)?; + Ok(()) + } + DataType::ListView(field) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets_and_sizes::(values_data.len)?; + 
Ok(()) + } + DataType::LargeListView(field) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets_and_sizes::(values_data.len)?; + Ok(()) + } + DataType::FixedSizeList(field, list_size) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + + let list_size: usize = (*list_size).try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "{} has a negative list_size {}", + self.data_type, list_size + )) + })?; + + let expected_values_len = self.len + .checked_mul(list_size) + .expect("integer overflow computing expected number of expected values in FixedListSize"); + + if values_data.len < expected_values_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Values length {} is less than the length ({}) multiplied by the value size ({}) for {}", + values_data.len, list_size, list_size, self.data_type + ))); + } + + Ok(()) + } + DataType::Struct(fields) => { + self.validate_num_child_data(fields.len())?; + for (i, field) in fields.iter().enumerate() { + let field_data = self.get_valid_child_data(i, field.data_type())?; + + // Ensure child field has sufficient size + if field_data.len < self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "{} child array #{} for field {} has length smaller than expected for struct array ({} < {})", + self.data_type, i, field.name(), field_data.len, self.len + ))); + } + } + Ok(()) + } + DataType::RunEndEncoded(run_ends_field, values_field) => { + self.validate_num_child_data(2)?; + let run_ends_data = self.get_valid_child_data(0, run_ends_field.data_type())?; + let values_data = self.get_valid_child_data(1, values_field.data_type())?; + if run_ends_data.len != values_data.len { + return Err(ArrowError::InvalidArgumentError(format!( + "The run_ends array length should be the same as values array length. 
Run_ends array length is {}, values array length is {}", + run_ends_data.len, values_data.len + ))); + } + if run_ends_data.nulls.is_some() { + return Err(ArrowError::InvalidArgumentError( + "Found null values in run_ends array. The run_ends array should not have null values.".to_string(), + )); + } + Ok(()) + } + DataType::Union(fields, mode) => { + self.validate_num_child_data(fields.len())?; + + for (i, (_, field)) in fields.iter().enumerate() { + let field_data = self.get_valid_child_data(i, field.data_type())?; + + if mode == &UnionMode::Sparse && field_data.len < (self.len + self.offset) { + return Err(ArrowError::InvalidArgumentError(format!( + "Sparse union child array #{} has length smaller than expected for union array ({} < {})", + i, field_data.len, self.len + self.offset + ))); + } + } + Ok(()) + } + DataType::Dictionary(_key_type, value_type) => { + self.get_single_valid_child_data(value_type)?; + Ok(()) + } + _ => { + // other types do not have child data + if !self.child_data.is_empty() { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected no child arrays for type {} but got {}", + self.data_type, + self.child_data.len() + ))); + } + Ok(()) + } + } + } + + /// Ensures that this array data has a single child_data with the + /// expected type, and calls `validate()` on it. 
Returns a + /// reference to that child_data + fn get_single_valid_child_data( + &self, + expected_type: &DataType, + ) -> Result<&ArrayData, ArrowError> { + self.validate_num_child_data(1)?; + self.get_valid_child_data(0, expected_type) + } + + /// Returns `Err` if self.child_data does not have exactly `expected_len` elements + fn validate_num_child_data(&self, expected_len: usize) -> Result<(), ArrowError> { + if self.child_data.len() != expected_len { + Err(ArrowError::InvalidArgumentError(format!( + "Value data for {} should contain {} child data array(s), had {}", + self.data_type, + expected_len, + self.child_data.len() + ))) + } else { + Ok(()) + } + } + + /// Ensures that `child_data[i]` has the expected type, calls + /// `validate()` on it, and returns a reference to that child_data + fn get_valid_child_data( + &self, + i: usize, + expected_type: &DataType, + ) -> Result<&ArrayData, ArrowError> { + let values_data = self.child_data.get(i).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "{} did not have enough child arrays. Expected at least {} but had only {}", + self.data_type, + i + 1, + self.child_data.len() + )) + })?; + + if expected_type != &values_data.data_type { + return Err(ArrowError::InvalidArgumentError(format!( + "Child type mismatch for {}. Expected {} but child data had {}", + self.data_type, expected_type, values_data.data_type + ))); + } + + values_data.validate()?; + Ok(values_data) + } + + /// Validate that the data contained within this [`ArrayData`] is valid + /// + /// 1. Null count is correct + /// 2. All offsets are valid + /// 3. All String data is valid UTF-8 + /// 4. 
All dictionary offsets are valid + /// + /// Internally this calls: + /// + /// * [`Self::validate`] + /// * [`Self::validate_nulls`] + /// * [`Self::validate_values`] + /// + /// Note: this does not recurse into children, for a recursive variant + /// see [`Self::validate_full`] + pub fn validate_data(&self) -> Result<(), ArrowError> { + self.validate()?; + + self.validate_nulls()?; + self.validate_values()?; + Ok(()) + } + + /// Performs a full recursive validation of this [`ArrayData`] and all its children + /// + /// This is equivalent to calling [`Self::validate_data`] on this [`ArrayData`] + /// and all its children recursively + pub fn validate_full(&self) -> Result<(), ArrowError> { + self.validate_data()?; + // validate all children recursively + self.child_data + .iter() + .enumerate() + .try_for_each(|(i, child_data)| { + child_data.validate_full().map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "{} child #{} invalid: {}", + self.data_type, i, e + )) + }) + })?; + Ok(()) + } + + /// Validates the values stored within this [`ArrayData`] are valid + /// without recursing into child [`ArrayData`] + /// + /// Does not (yet) check + /// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) + /// 2. the the null count is correct and that any + /// 3. nullability requirements of its children are correct + /// + /// [#85]: https://github.com/apache/arrow-rs/issues/85 + pub fn validate_nulls(&self) -> Result<(), ArrowError> { + if let Some(nulls) = &self.nulls { + let actual = nulls.len() - nulls.inner().count_set_bits(); + if actual != nulls.null_count() { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count value ({}) doesn't match actual number of nulls in array ({})", + nulls.null_count(), + actual + ))); + } + } + + // In general non-nullable children should not contain nulls, however, for certain + // types, such as StructArray and FixedSizeList, nulls in the parent take up + // space in the child. 
As such we permit nulls in the children in the corresponding + // positions for such types + match &self.data_type { + DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => { + if !f.is_nullable() { + self.validate_non_nullable(None, &self.child_data[0])? + } + } + DataType::FixedSizeList(field, len) => { + let child = &self.child_data[0]; + if !field.is_nullable() { + match &self.nulls { + Some(nulls) => { + let element_len = *len as usize; + let expanded = nulls.expand(element_len); + self.validate_non_nullable(Some(&expanded), child)?; + } + None => self.validate_non_nullable(None, child)?, + } + } + } + DataType::Struct(fields) => { + for (field, child) in fields.iter().zip(&self.child_data) { + if !field.is_nullable() { + self.validate_non_nullable(self.nulls(), child)? + } + } + } + _ => {} + } + + Ok(()) + } + + /// Verifies that `child` contains no nulls not present in `mask` + fn validate_non_nullable( + &self, + mask: Option<&NullBuffer>, + child: &ArrayData, + ) -> Result<(), ArrowError> { + let mask = match mask { + Some(mask) => mask, + None => { + return match child.null_count() { + 0 => Ok(()), + _ => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent {}", + child.data_type, self.data_type + ))), + } + } + }; + + match child.nulls() { + Some(nulls) if !mask.contains(nulls) => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent", + child.data_type + ))), + _ => Ok(()), + } + } + + /// Validates the values stored within this [`ArrayData`] are valid + /// without recursing into child [`ArrayData`] + /// + /// Does not (yet) check + /// 1. 
Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) + pub fn validate_values(&self) -> Result<(), ArrowError> { + match &self.data_type { + DataType::Utf8 => self.validate_utf8::(), + DataType::LargeUtf8 => self.validate_utf8::(), + DataType::Binary => self.validate_offsets_full::(self.buffers[1].len()), + DataType::LargeBinary => self.validate_offsets_full::(self.buffers[1].len()), + DataType::BinaryView => { + let views = self.typed_buffer::(0, self.len)?; + validate_binary_view(views, &self.buffers[1..]) + } + DataType::Utf8View => { + let views = self.typed_buffer::(0, self.len)?; + validate_string_view(views, &self.buffers[1..]) + } + DataType::List(_) | DataType::Map(_, _) => { + let child = &self.child_data[0]; + self.validate_offsets_full::(child.len) + } + DataType::LargeList(_) => { + let child = &self.child_data[0]; + self.validate_offsets_full::(child.len) + } + DataType::Union(_, _) => { + // Validate Union Array as part of implementing new Union semantics + // See comments in `ArrayData::validate()` + // https://github.com/apache/arrow-rs/issues/85 + // + // TODO file follow on ticket for full union validation + Ok(()) + } + DataType::Dictionary(key_type, _value_type) => { + let dictionary_length: i64 = self.child_data[0].len.try_into().unwrap(); + let max_value = dictionary_length - 1; + match key_type.as_ref() { + DataType::UInt8 => self.check_bounds::(max_value), + DataType::UInt16 => self.check_bounds::(max_value), + DataType::UInt32 => self.check_bounds::(max_value), + DataType::UInt64 => self.check_bounds::(max_value), + DataType::Int8 => self.check_bounds::(max_value), + DataType::Int16 => self.check_bounds::(max_value), + DataType::Int32 => self.check_bounds::(max_value), + DataType::Int64 => self.check_bounds::(max_value), + _ => unreachable!(), + } + } + DataType::RunEndEncoded(run_ends, _values) => { + let run_ends_data = self.child_data()[0].clone(); + match run_ends.data_type() { + DataType::Int16 => 
run_ends_data.check_run_ends::(), + DataType::Int32 => run_ends_data.check_run_ends::(), + DataType::Int64 => run_ends_data.check_run_ends::(), + _ => unreachable!(), + } + } + _ => { + // No extra validation check required for other types + Ok(()) + } + } + } + + /// Calls the `validate(item_index, range)` function for each of + /// the ranges specified in the arrow offsets buffer of type + /// `T`. Also validates that each offset is smaller than + /// `offset_limit` + /// + /// For an empty array, the offsets buffer can either be empty + /// or contain a single `0`. + /// + /// For example, the offsets buffer contained `[1, 2, 4]`, this + /// function would call `validate([1,2])`, and `validate([2,4])` + fn validate_each_offset(&self, offset_limit: usize, validate: V) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + V: Fn(usize, Range) -> Result<(), ArrowError>, + { + self.typed_offsets::()? + .iter() + .enumerate() + .map(|(i, x)| { + // check if the offset can be converted to usize + let r = x.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: Could not convert offset {x} to usize at position {i}"))} + ); + // check if the offset exceeds the limit + match r { + Ok(n) if n <= offset_limit => Ok((i, n)), + Ok(_) => Err(ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: offset at position {i} out of bounds: {x} > {offset_limit}")) + ), + Err(e) => Err(e), + } + }) + .scan(0_usize, |start, end| { + // check offsets are monotonically increasing + match end { + Ok((i, end)) if *start <= end => { + let range = Some(Ok((i, *start..end))); + *start = end; + range + } + Ok((i, end)) => Some(Err(ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: non-monotonic offset at slot {}: {} > {}", + i - 1, start, end)) + )), + Err(err) => Some(Err(err)), + } + }) + .skip(1) // the first element is meaningless + .try_for_each(|res: Result<(usize, 
Range), ArrowError>| { + let (item_index, range) = res?; + validate(item_index-1, range) + }) + } + + /// Ensures that all strings formed by the offsets in `buffers[0]` + /// into `buffers[1]` are valid utf8 sequences + fn validate_utf8(&self) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let values_buffer = &self.buffers[1].as_slice(); + if let Ok(values_str) = std::str::from_utf8(values_buffer) { + // Validate Offsets are correct + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + if !values_str.is_char_boundary(range.start) + || !values_str.is_char_boundary(range.end) + { + return Err(ArrowError::InvalidArgumentError(format!( + "incomplete utf-8 byte sequence from index {string_index}" + ))); + } + Ok(()) + }) + } else { + // find specific offset that failed utf8 validation + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Invalid UTF8 sequence at string index {string_index} ({range:?}): {e}" + )) + })?; + Ok(()) + }) + } + } + + /// Ensures that all offsets in `buffers[0]` into `buffers[1]` are + /// between `0` and `offset_limit` + fn validate_offsets_full(&self, offset_limit: usize) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + self.validate_each_offset::(offset_limit, |_string_index, _range| { + // No validation applied to each value, but the iteration + // itself applies bounds checking to each range + Ok(()) + }) + } + + /// Validates that each value in self.buffers (typed as T) + /// is within the range [0, max_value], inclusive + fn check_bounds(&self, max_value: i64) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let required_len = self.len + self.offset; + let buffer = &self.buffers[0]; + + // This should have been checked 
as part of `validate()` prior + // to calling `validate_full()` but double check to be sure + assert!(buffer.len() / mem::size_of::() >= required_len); + + // Justification: buffer size was validated above + let indexes: &[T] = &buffer.typed_data::()[self.offset..self.offset + self.len]; + + indexes.iter().enumerate().try_for_each(|(i, &dict_index)| { + // Do not check the value is null (value can be arbitrary) + if self.is_null(i) { + return Ok(()); + } + let dict_index: i64 = dict_index.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Value at position {i} out of bounds: {dict_index} (can not convert to i64)" + )) + })?; + + if dict_index < 0 || dict_index > max_value { + return Err(ArrowError::InvalidArgumentError(format!( + "Value at position {i} out of bounds: {dict_index} (should be in [0, {max_value}])" + ))); + } + Ok(()) + }) + } + + /// Validates that each value in run_ends array is positive and strictly increasing. + fn check_run_ends(&self) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let values = self.typed_buffer::(0, self.len)?; + let mut prev_value: i64 = 0_i64; + values.iter().enumerate().try_for_each(|(ix, &inp_value)| { + let value: i64 = inp_value.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Value at position {ix} out of bounds: {inp_value} (can not convert to i64)" + )) + })?; + if value <= 0_i64 { + return Err(ArrowError::InvalidArgumentError(format!( + "The values in run_ends array should be strictly positive. Found value {value} at index {ix} that does not match the criteria." + ))); + } + if ix > 0 && value <= prev_value { + return Err(ArrowError::InvalidArgumentError(format!( + "The values in run_ends array should be strictly increasing. Found value {value} at index {ix} with previous value {prev_value} that does not match the criteria." 
+ ))); + } + + prev_value = value; + Ok(()) + })?; + + if prev_value.as_usize() < (self.offset + self.len) { + return Err(ArrowError::InvalidArgumentError(format!( + "The offset + length of array should be less or equal to last value in the run_ends array. The last value of run_ends array is {prev_value} and offset + length of array is {}.", + self.offset + self.len + ))); + } + Ok(()) + } + + /// Returns true if this `ArrayData` is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + pub fn ptr_eq(&self, other: &Self) -> bool { + if self.offset != other.offset + || self.len != other.len + || self.data_type != other.data_type + || self.buffers.len() != other.buffers.len() + || self.child_data.len() != other.child_data.len() + { + return false; + } + + match (&self.nulls, &other.nulls) { + (Some(a), Some(b)) if !a.inner().ptr_eq(b.inner()) => return false, + (Some(_), None) | (None, Some(_)) => return false, + _ => {} + }; + + if !self + .buffers + .iter() + .zip(other.buffers.iter()) + .all(|(a, b)| a.as_ptr() == b.as_ptr()) + { + return false; + } + + self.child_data + .iter() + .zip(other.child_data.iter()) + .all(|(a, b)| a.ptr_eq(b)) + } + + /// Converts this [`ArrayData`] into an [`ArrayDataBuilder`] + pub fn into_builder(self) -> ArrayDataBuilder { + self.into() + } +} + +/// Return the expected [`DataTypeLayout`] Arrays of this data +/// type are expected to have +pub fn layout(data_type: &DataType) -> DataTypeLayout { + // based on C/C++ implementation in + // https://github.com/apache/arrow/blob/661c7d749150905a63dd3b52e0a04dac39030d95/cpp/src/arrow/type.h (and .cc) + use arrow_schema::IntervalUnit::*; + + match data_type { + DataType::Null => DataTypeLayout { + buffers: vec![], + can_contain_null_mask: false, + variadic: false, + }, + DataType::Boolean => DataTypeLayout { + buffers: vec![BufferSpec::BitMap], + 
can_contain_null_mask: true, + variadic: false, + }, + DataType::Int8 => DataTypeLayout::new_fixed_width::(), + DataType::Int16 => DataTypeLayout::new_fixed_width::(), + DataType::Int32 => DataTypeLayout::new_fixed_width::(), + DataType::Int64 => DataTypeLayout::new_fixed_width::(), + DataType::UInt8 => DataTypeLayout::new_fixed_width::(), + DataType::UInt16 => DataTypeLayout::new_fixed_width::(), + DataType::UInt32 => DataTypeLayout::new_fixed_width::(), + DataType::UInt64 => DataTypeLayout::new_fixed_width::(), + DataType::Float16 => DataTypeLayout::new_fixed_width::(), + DataType::Float32 => DataTypeLayout::new_fixed_width::(), + DataType::Float64 => DataTypeLayout::new_fixed_width::(), + DataType::Timestamp(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Date32 => DataTypeLayout::new_fixed_width::(), + DataType::Date64 => DataTypeLayout::new_fixed_width::(), + DataType::Time32(_) => DataTypeLayout::new_fixed_width::(), + DataType::Time64(_) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(YearMonth) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(DayTime) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(MonthDayNano) => { + DataTypeLayout::new_fixed_width::() + } + DataType::Duration(_) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal128(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal256(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::FixedSizeBinary(size) => { + let spec = BufferSpec::FixedWidth { + byte_width: (*size).try_into().unwrap(), + alignment: mem::align_of::(), + }; + DataTypeLayout { + buffers: vec![spec], + can_contain_null_mask: true, + variadic: false, + } + } + DataType::Binary => DataTypeLayout::new_binary::(), + DataType::LargeBinary => DataTypeLayout::new_binary::(), + DataType::Utf8 => DataTypeLayout::new_binary::(), + DataType::LargeUtf8 => DataTypeLayout::new_binary::(), + DataType::BinaryView | DataType::Utf8View => DataTypeLayout::new_view(), + 
DataType::FixedSizeList(_, _) => DataTypeLayout::new_nullable_empty(), // all in child data + DataType::List(_) => DataTypeLayout::new_fixed_width::(), + DataType::ListView(_) => DataTypeLayout::new_list_view::(), + DataType::LargeListView(_) => DataTypeLayout::new_list_view::(), + DataType::LargeList(_) => DataTypeLayout::new_fixed_width::(), + DataType::Map(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Struct(_) => DataTypeLayout::new_nullable_empty(), // all in child data, + DataType::RunEndEncoded(_, _) => DataTypeLayout::new_empty(), // all in child data, + DataType::Union(_, mode) => { + let type_ids = BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }; + + DataTypeLayout { + buffers: match mode { + UnionMode::Sparse => { + vec![type_ids] + } + UnionMode::Dense => { + vec![ + type_ids, + BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }, + ] + } + }, + can_contain_null_mask: false, + variadic: false, + } + } + DataType::Dictionary(key_type, _value_type) => layout(key_type), + } +} + +/// Layout specification for a data type +#[derive(Debug, PartialEq, Eq)] +// Note: Follows structure from C++: https://github.com/apache/arrow/blob/master/cpp/src/arrow/type.h#L91 +pub struct DataTypeLayout { + /// A vector of buffer layout specifications, one for each expected buffer + pub buffers: Vec, + + /// Can contain a null bitmask + pub can_contain_null_mask: bool, + + /// This field only applies to the view type [`DataType::BinaryView`] and [`DataType::Utf8View`] + /// If `variadic` is true, the number of buffers expected is only lower-bounded by + /// buffers.len(). Buffers that exceed the lower bound are legal. 
+ pub variadic: bool, +} + +impl DataTypeLayout { + /// Describes a basic numeric array where each element has type `T` + pub fn new_fixed_width() -> Self { + Self { + buffers: vec![BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }], + can_contain_null_mask: true, + variadic: false, + } + } + + /// Describes arrays which have no data of their own + /// but may still have a Null Bitmap (e.g. FixedSizeList) + pub fn new_nullable_empty() -> Self { + Self { + buffers: vec![], + can_contain_null_mask: true, + variadic: false, + } + } + + /// Describes arrays which have no data of their own + /// (e.g. RunEndEncoded). + pub fn new_empty() -> Self { + Self { + buffers: vec![], + can_contain_null_mask: false, + variadic: false, + } + } + + /// Describes a basic numeric array where each element has a fixed + /// with offset buffer of type `T`, followed by a + /// variable width data buffer + pub fn new_binary() -> Self { + Self { + buffers: vec![ + // offsets + BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }, + // values + BufferSpec::VariableWidth, + ], + can_contain_null_mask: true, + variadic: false, + } + } + + /// Describes a view type + pub fn new_view() -> Self { + Self { + buffers: vec![BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }], + can_contain_null_mask: true, + variadic: true, + } + } + + /// Describes a list view type + pub fn new_list_view() -> Self { + Self { + buffers: vec![ + BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }, + BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }, + ], + can_contain_null_mask: true, + variadic: true, + } + } +} + +/// Layout specification for a single data type buffer +#[derive(Debug, PartialEq, Eq)] +pub enum BufferSpec { + /// Each element is a fixed width primitive, with the given `byte_width` and 
`alignment` + /// + /// `alignment` is the alignment required by Rust for an array of the corresponding primitive, + /// see [`Layout::array`](std::alloc::Layout::array) and [`std::mem::align_of`]. + /// + /// Arrow-rs requires that all buffers have at least this alignment, to allow for + /// [slice](std::slice) based APIs. Alignment in excess of this is not required to allow + /// for array slicing and interoperability with `Vec`, which cannot be over-aligned. + /// + /// Note that these alignment requirements will vary between architectures + FixedWidth { + /// The width of each element in bytes + byte_width: usize, + /// The alignment required by Rust for an array of the corresponding primitive + alignment: usize, + }, + /// Variable width, such as string data for utf8 data + VariableWidth, + /// Buffer holds a bitmap. + /// + /// Note: Unlike the C++ implementation, the null/validity buffer + /// is handled specially rather than as another of the buffers in + /// the spec, so this variant is only used for the Boolean type. + BitMap, + /// Buffer is always null. 
Unused currently in Rust implementation, + /// (used in C++ for Union type) + #[allow(dead_code)] + AlwaysNull, +} + +impl PartialEq for ArrayData { + fn eq(&self, other: &Self) -> bool { + equal::equal(self, other) + } +} + +/// Builder for `ArrayData` type +#[derive(Debug)] +pub struct ArrayDataBuilder { + data_type: DataType, + len: usize, + null_count: Option, + null_bit_buffer: Option, + nulls: Option, + offset: usize, + buffers: Vec, + child_data: Vec, +} + +impl ArrayDataBuilder { + #[inline] + /// Creates a new array data builder + pub const fn new(data_type: DataType) -> Self { + Self { + data_type, + len: 0, + null_count: None, + null_bit_buffer: None, + nulls: None, + offset: 0, + buffers: vec![], + child_data: vec![], + } + } + + /// Creates a new array data builder from an existing one, changing the data type + pub fn data_type(self, data_type: DataType) -> Self { + Self { data_type, ..self } + } + + #[inline] + #[allow(clippy::len_without_is_empty)] + /// Sets the length of the [ArrayData] + pub const fn len(mut self, n: usize) -> Self { + self.len = n; + self + } + + /// Sets the null buffer of the [ArrayData] + pub fn nulls(mut self, nulls: Option) -> Self { + self.nulls = nulls; + self.null_count = None; + self.null_bit_buffer = None; + self + } + + /// Sets the null count of the [ArrayData] + pub fn null_count(mut self, null_count: usize) -> Self { + self.null_count = Some(null_count); + self + } + + /// Sets the `null_bit_buffer` of the [ArrayData] + pub fn null_bit_buffer(mut self, buf: Option) -> Self { + self.nulls = None; + self.null_bit_buffer = buf; + self + } + + /// Sets the offset of the [ArrayData] + #[inline] + pub const fn offset(mut self, n: usize) -> Self { + self.offset = n; + self + } + + /// Sets the buffers of the [ArrayData] + pub fn buffers(mut self, v: Vec) -> Self { + self.buffers = v; + self + } + + /// Adds a single buffer to the [ArrayData]'s buffers + pub fn add_buffer(mut self, b: Buffer) -> Self { + 
self.buffers.push(b); + self + } + + /// Adds multiple buffers to the [ArrayData]'s buffers + pub fn add_buffers>(mut self, bs: I) -> Self { + self.buffers.extend(bs); + self + } + + /// Sets the child data of the [ArrayData] + pub fn child_data(mut self, v: Vec) -> Self { + self.child_data = v; + self + } + + /// Adds a single child data to the [ArrayData]'s child data + pub fn add_child_data(mut self, r: ArrayData) -> Self { + self.child_data.push(r); + self + } + + /// Creates an array data, without any validation + /// + /// # Safety + /// + /// The same caveats as [`ArrayData::new_unchecked`] + /// apply. + #[allow(clippy::let_and_return)] + pub unsafe fn build_unchecked(self) -> ArrayData { + let data = self.build_impl(); + // Provide a force_validate mode + #[cfg(feature = "force_validate")] + data.validate_data().unwrap(); + data + } + + /// Same as [`Self::build_unchecked`] but ignoring `force_validate` feature flag + unsafe fn build_impl(self) -> ArrayData { + let nulls = self + .nulls + .or_else(|| { + let buffer = self.null_bit_buffer?; + let buffer = BooleanBuffer::new(buffer, self.offset, self.len); + Some(match self.null_count { + Some(n) => NullBuffer::new_unchecked(buffer, n), + None => NullBuffer::new(buffer), + }) + }) + .filter(|b| b.null_count() != 0); + + ArrayData { + data_type: self.data_type, + len: self.len, + offset: self.offset, + buffers: self.buffers, + child_data: self.child_data, + nulls, + } + } + + /// Creates an array data, validating all inputs + pub fn build(self) -> Result { + let data = unsafe { self.build_impl() }; + data.validate_data()?; + Ok(data) + } + + /// Creates an array data, validating all inputs, and aligning any buffers + /// + /// Rust requires that arrays are aligned to their corresponding primitive, + /// see [`Layout::array`](std::alloc::Layout::array) and [`std::mem::align_of`]. 
+ /// + /// [`ArrayData`] therefore requires that all buffers have at least this alignment, + /// to allow for [slice](std::slice) based APIs. See [`BufferSpec::FixedWidth`]. + /// + /// As this alignment is architecture specific, and not guaranteed by all arrow implementations, + /// this method is provided to automatically copy buffers to a new correctly aligned allocation + /// when necessary, making it useful when interacting with buffers produced by other systems, + /// e.g. IPC or FFI. + /// + /// This is unlike `[Self::build`] which will instead return an error on encountering + /// insufficiently aligned buffers. + pub fn build_aligned(self) -> Result { + let mut data = unsafe { self.build_impl() }; + data.align_buffers(); + data.validate_data()?; + Ok(data) + } +} + +impl From for ArrayDataBuilder { + fn from(d: ArrayData) -> Self { + Self { + data_type: d.data_type, + len: d.len, + offset: d.offset, + buffers: d.buffers, + child_data: d.child_data, + nulls: d.nulls, + null_bit_buffer: None, + null_count: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::Field; + + // See arrow/tests/array_data_validation.rs for test of array validation + + /// returns a buffer initialized with some constant value for tests + fn make_i32_buffer(n: usize) -> Buffer { + Buffer::from_slice_ref(vec![42i32; n]) + } + + /// returns a buffer initialized with some constant value for tests + fn make_f32_buffer(n: usize) -> Buffer { + Buffer::from_slice_ref(vec![42f32; n]) + } + + #[test] + fn test_builder() { + // Buffer needs to be at least 25 long + let v = (0..25).collect::>(); + let b1 = Buffer::from_slice_ref(&v); + let arr_data = ArrayData::builder(DataType::Int32) + .len(20) + .offset(5) + .add_buffer(b1) + .null_bit_buffer(Some(Buffer::from([ + 0b01011111, 0b10110101, 0b01100011, 0b00011110, + ]))) + .build() + .unwrap(); + + assert_eq!(20, arr_data.len()); + assert_eq!(10, arr_data.null_count()); + assert_eq!(5, arr_data.offset()); + 
assert_eq!(1, arr_data.buffers().len()); + assert_eq!( + Buffer::from_slice_ref(&v).as_slice(), + arr_data.buffers()[0].as_slice() + ); + } + + #[test] + fn test_builder_with_child_data() { + let child_arr_data = ArrayData::try_new( + DataType::Int32, + 5, + None, + 0, + vec![Buffer::from_slice_ref([1i32, 2, 3, 4, 5])], + vec![], + ) + .unwrap(); + + let field = Arc::new(Field::new("x", DataType::Int32, true)); + let data_type = DataType::Struct(vec![field].into()); + + let arr_data = ArrayData::builder(data_type) + .len(5) + .offset(0) + .add_child_data(child_arr_data.clone()) + .build() + .unwrap(); + + assert_eq!(5, arr_data.len()); + assert_eq!(1, arr_data.child_data().len()); + assert_eq!(child_arr_data, arr_data.child_data()[0]); + } + + #[test] + fn test_null_count() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let arr_data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + assert_eq!(13, arr_data.null_count()); + + // Test with offset + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let arr_data = ArrayData::builder(DataType::Int32) + .len(12) + .offset(2) + .add_buffer(make_i32_buffer(14)) // requires at least 14 bytes of space, + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + assert_eq!(10, arr_data.null_count()); + } + + #[test] + fn test_null_buffer_ref() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let arr_data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + assert!(arr_data.nulls().is_some()); + assert_eq!(&bit_v, 
arr_data.nulls().unwrap().validity()); + } + + #[test] + fn test_slice() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + let new_data = data.slice(1, 15); + assert_eq!(data.len() - 1, new_data.len()); + assert_eq!(1, new_data.offset()); + assert_eq!(data.null_count(), new_data.null_count()); + + // slice of a slice (removes one null) + let new_data = new_data.slice(1, 14); + assert_eq!(data.len() - 2, new_data.len()); + assert_eq!(2, new_data.offset()); + assert_eq!(data.null_count() - 1, new_data.null_count()); + } + + #[test] + fn test_equality() { + let int_data = ArrayData::builder(DataType::Int32) + .len(1) + .add_buffer(make_i32_buffer(1)) + .build() + .unwrap(); + + let float_data = ArrayData::builder(DataType::Float32) + .len(1) + .add_buffer(make_f32_buffer(1)) + .build() + .unwrap(); + assert_ne!(int_data, float_data); + assert!(!int_data.ptr_eq(&float_data)); + assert!(int_data.ptr_eq(&int_data)); + + #[allow(clippy::redundant_clone)] + let int_data_clone = int_data.clone(); + assert_eq!(int_data, int_data_clone); + assert!(int_data.ptr_eq(&int_data_clone)); + assert!(int_data_clone.ptr_eq(&int_data)); + + let int_data_slice = int_data_clone.slice(1, 0); + assert!(int_data_slice.ptr_eq(&int_data_slice)); + assert!(!int_data.ptr_eq(&int_data_slice)); + assert!(!int_data_slice.ptr_eq(&int_data)); + + let data_buffer = Buffer::from_slice_ref("abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref([0_i32, 2_i32, 2_i32, 5_i32]); + let string_data = ArrayData::try_new( + DataType::Utf8, + 3, + Some(Buffer::from_iter(vec![true, false, true])), + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + + assert_ne!(float_data, string_data); + 
assert!(!float_data.ptr_eq(&string_data)); + + assert!(string_data.ptr_eq(&string_data)); + + #[allow(clippy::redundant_clone)] + let string_data_cloned = string_data.clone(); + assert!(string_data_cloned.ptr_eq(&string_data)); + assert!(string_data.ptr_eq(&string_data_cloned)); + + let string_data_slice = string_data.slice(1, 2); + assert!(string_data_slice.ptr_eq(&string_data_slice)); + assert!(!string_data_slice.ptr_eq(&string_data)) + } + + #[test] + fn test_slice_memory_size() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + let new_data = data.slice(1, 14); + assert_eq!( + data.get_slice_memory_size().unwrap() - 8, + new_data.get_slice_memory_size().unwrap() + ); + let data_buffer = Buffer::from_slice_ref("abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref([0_i32, 2_i32, 2_i32, 5_i32]); + let string_data = ArrayData::try_new( + DataType::Utf8, + 3, + Some(Buffer::from_iter(vec![true, false, true])), + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + let string_data_slice = string_data.slice(1, 2); + //4 bytes of offset and 2 bytes of data reduced by slicing. 
+ assert_eq!( + string_data.get_slice_memory_size().unwrap() - 6, + string_data_slice.get_slice_memory_size().unwrap() + ); + } + + #[test] + fn test_count_nulls() { + let buffer = Buffer::from([0b00010110, 0b10011111]); + let buffer = NullBuffer::new(BooleanBuffer::new(buffer, 0, 16)); + let count = count_nulls(Some(&buffer), 0, 16); + assert_eq!(count, 7); + + let count = count_nulls(Some(&buffer), 4, 8); + assert_eq!(count, 3); + } + + #[test] + fn test_contains_nulls() { + let buffer: Buffer = + MutableBuffer::from_iter([false, false, false, true, true, false]).into(); + let buffer = NullBuffer::new(BooleanBuffer::new(buffer, 0, 6)); + assert!(contains_nulls(Some(&buffer), 0, 6)); + assert!(contains_nulls(Some(&buffer), 0, 3)); + assert!(!contains_nulls(Some(&buffer), 3, 2)); + assert!(!contains_nulls(Some(&buffer), 0, 0)); + } + + #[test] + fn test_alignment() { + let buffer = Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let sliced = buffer.slice(1); + + let mut data = ArrayData { + data_type: DataType::Int32, + len: 0, + offset: 0, + buffers: vec![buffer], + child_data: vec![], + nulls: None, + }; + data.validate_full().unwrap(); + + data.buffers[0] = sliced; + let err = data.validate().unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Misaligned buffers[0] in array of type Int32, offset from expected alignment of 4 by 1" + ); + + data.align_buffers(); + data.validate_full().unwrap(); + } + + #[test] + fn test_null_view_types() { + let array_len = 32; + let array = ArrayData::new_null(&DataType::BinaryView, array_len); + assert_eq!(array.len(), array_len); + for i in 0..array.len() { + assert!(array.is_null(i)); + } + + let array = ArrayData::new_null(&DataType::Utf8View, array_len); + assert_eq!(array.len(), array_len); + for i in 0..array.len() { + assert!(array.is_null(i)); + } + } +} diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs new file mode 100644 index 000000000000..fe19db641236 --- /dev/null +++ 
b/arrow-data/src/decimal.rs @@ -0,0 +1,900 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines maximum and minimum values for `decimal256` and `decimal128` types for varying precisions. +//! +//! Also provides functions to validate if a given decimal value is within the valid range of the decimal type. + +use arrow_buffer::i256; +use arrow_schema::ArrowError; + +pub use arrow_schema::{ + DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DECIMAL_DEFAULT_SCALE, +}; + +/// MAX decimal256 value of little-endian format for each precision. +/// Each element is the max value of signed 256-bit integer for the specified precision which +/// is encoded to the 32-byte width format of little-endian. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. 
+pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 77] = [ + i256::from_i128(0_i128), // unused first element + i256::from_le_bytes([ + 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 231, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 15, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 159, 134, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, + ]), + i256::from_le_bytes([ + 63, 66, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, + ]), + i256::from_le_bytes([ + 127, 150, 152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 224, 245, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 201, 154, 59, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 227, 11, 84, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 231, 118, 72, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 15, 165, 212, 232, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 159, 114, 78, 24, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 63, 122, 16, 243, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 
0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 127, 198, 164, 126, 141, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 192, 111, 242, 134, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 137, 93, 120, 69, 99, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 99, 167, 179, 182, 224, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 231, 137, 4, 35, 199, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 15, 99, 45, 94, 199, 107, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 159, 222, 197, 173, 201, 53, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 63, 178, 186, 201, 224, 25, 30, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 127, 246, 74, 225, 199, 2, 45, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 160, 237, 204, 206, 27, 194, 211, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 73, 72, 1, 20, 22, 149, 69, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 227, 210, 12, 200, 220, 210, 183, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 231, 60, 128, 208, 159, 60, 46, 59, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 15, 97, 2, 37, 62, 94, 206, 79, 32, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 159, 202, 23, 114, 109, 174, 15, 30, 67, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 63, 234, 237, 116, 70, 208, 156, 44, 159, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 127, 38, 75, 145, 192, 34, 32, 190, 55, 126, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 128, 239, 172, 133, 91, 65, 109, 45, 238, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 9, 91, 193, 56, 147, 141, 68, 198, 77, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 99, 142, 141, 55, 192, 135, 173, 190, 9, 237, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 231, 143, 135, 43, 130, 77, 199, 114, 97, 66, 19, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 15, 159, 75, 179, 21, 7, 201, 123, 206, 151, 192, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 159, 54, 244, 0, 217, 70, 218, 213, 16, 238, 133, 7, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 127, 86, 101, 95, 196, 172, 67, 137, 147, 254, 80, 240, 2, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 96, 245, 185, 171, 191, 164, 92, 195, 241, 41, 99, 29, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 201, 149, 67, 181, 124, 111, 158, 161, 113, 163, 223, 
37, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 227, 217, 163, 20, 223, 90, 48, 80, 112, 98, 188, 122, 11, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 231, 130, 102, 206, 182, 140, 227, 33, 99, 216, 91, 203, 114, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 15, 29, 1, 16, 36, 127, 227, 82, 223, 115, 150, 241, 123, 4, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 159, 34, 11, 160, 104, 247, 226, 60, 185, 134, 224, 111, 215, 44, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 63, 90, 111, 64, 22, 170, 221, 96, 60, 67, 197, 94, 106, 192, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 127, 134, 89, 132, 222, 164, 168, 200, 91, 160, 180, 179, 39, 132, + 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 64, 127, 43, 177, 112, 150, 214, 149, 67, 14, 5, 141, 41, + 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 137, 248, 178, 235, 102, 224, 97, 218, 163, 142, 50, 130, + 159, 215, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 99, 181, 253, 52, 5, 196, 210, 135, 102, 146, 249, 21, 59, + 108, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 231, 21, 233, 17, 52, 168, 59, 78, 1, 184, 191, 219, 78, 58, + 172, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 15, 219, 26, 179, 8, 146, 84, 14, 13, 48, 125, 149, 20, 71, + 186, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 159, 142, 12, 255, 86, 180, 77, 143, 130, 224, 227, 214, 205, + 198, 70, 11, 1, 0, 0, 0, 0, 
0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 63, 146, 125, 246, 101, 11, 9, 153, 25, 197, 230, 100, 10, + 196, 195, 112, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 127, 182, 231, 160, 251, 113, 90, 250, 255, 178, 3, 241, 103, + 168, 165, 103, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 32, 13, 73, 212, 115, 136, 199, 255, 253, 36, 106, 15, + 148, 120, 12, 20, 4, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 73, 131, 218, 74, 134, 84, 203, 253, 235, 113, 37, 154, + 200, 181, 124, 200, 40, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 227, 32, 137, 236, 62, 77, 241, 233, 55, 115, 118, 5, + 214, 25, 223, 212, 151, 1, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 231, 72, 91, 61, 117, 4, 109, 35, 47, 128, 160, 54, 92, + 2, 183, 80, 238, 15, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 15, 217, 144, 101, 148, 44, 66, 98, 215, 1, 69, 34, 154, + 23, 38, 39, 79, 159, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 159, 122, 168, 247, 203, 189, 149, 214, 105, 18, 178, + 86, 5, 236, 124, 135, 23, 57, 6, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 63, 202, 148, 172, 247, 105, 217, 97, 34, 184, 244, 98, + 53, 56, 225, 74, 235, 58, 62, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 127, 230, 207, 189, 172, 35, 126, 210, 87, 49, 143, 221, + 21, 50, 204, 236, 48, 77, 110, 2, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 0, 31, 106, 191, 100, 237, 56, 110, 237, 151, 167, + 218, 244, 249, 63, 233, 3, 79, 24, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 9, 54, 37, 122, 239, 69, 
57, 78, 70, 239, 139, 138, + 144, 195, 127, 28, 39, 22, 243, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 99, 28, 116, 197, 90, 187, 60, 14, 191, 88, 119, + 105, 165, 163, 253, 28, 135, 221, 126, 9, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 231, 27, 137, 182, 139, 81, 95, 142, 118, 119, 169, + 30, 118, 100, 232, 33, 71, 167, 244, 94, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 15, 23, 91, 33, 117, 47, 185, 143, 161, 170, 158, + 50, 157, 236, 19, 83, 199, 136, 142, 181, 3, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 159, 230, 142, 77, 147, 218, 59, 157, 79, 170, 50, + 250, 35, 62, 199, 62, 201, 87, 145, 23, 37, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 63, 2, 149, 7, 193, 137, 86, 36, 28, 167, 250, 197, + 103, 109, 200, 115, 220, 109, 173, 235, 114, 1, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 127, 22, 210, 75, 138, 97, 97, 107, 25, 135, 202, + 187, 13, 70, 212, 133, 156, 74, 198, 52, 125, 14, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 224, 52, 246, 102, 207, 205, 49, 254, 70, 233, + 85, 137, 188, 74, 58, 29, 234, 190, 15, 228, 144, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 201, 16, 158, 5, 26, 10, 242, 237, 197, 28, + 91, 93, 93, 235, 70, 36, 37, 117, 157, 232, 168, 5, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 227, 167, 44, 56, 4, 101, 116, 75, 187, 31, + 143, 165, 165, 49, 197, 106, 115, 147, 38, 22, 153, 56, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 231, 142, 190, 49, 42, 242, 139, 242, 80, 61, + 151, 119, 120, 240, 179, 43, 130, 194, 129, 221, 250, 53, 2, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 15, 149, 113, 241, 165, 117, 119, 121, 41, + 
101, 232, 171, 180, 100, 7, 181, 21, 153, 17, 167, 204, 27, 22, + ]), +]; + +/// MIN decimal256 value of little-endian format for each precision. +/// Each element is the min value of signed 256-bit integer for the specified precision which +/// is encoded to the 32-byte width format of little-endian. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 77] = [ + i256::from_i128(0_i128), // unused first element + i256::from_le_bytes([ + 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 157, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 25, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 241, 216, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 97, 121, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 193, 189, 240, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 129, 105, 103, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 31, 10, 250, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 54, 101, 196, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 28, 244, 171, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 24, 137, 183, 232, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 240, 90, 43, 23, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 96, 141, 177, 231, 246, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 192, 133, 239, 12, 165, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 128, 57, 91, 129, 114, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 63, 144, 13, 121, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 118, 162, 135, 186, 156, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 156, 88, 76, 73, 31, 242, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 
255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 24, 118, 251, 220, 56, 117, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 240, 156, 210, 161, 56, 148, 250, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 96, 33, 58, 82, 54, 202, 201, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 192, 77, 69, 54, 31, 230, 225, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 128, 9, 181, 30, 56, 253, 210, 234, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 95, 18, 51, 49, 228, 61, 44, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 182, 183, 254, 235, 233, 106, 186, 247, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 28, 45, 243, 55, 35, 45, 72, 173, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 24, 195, 127, 47, 96, 195, 209, 196, 252, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 240, 158, 253, 218, 193, 161, 49, 176, 223, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 96, 53, 232, 141, 146, 81, 240, 225, 188, 254, 255, 255, 255, 255, 255, 255, 255, + 
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 192, 21, 18, 139, 185, 47, 99, 211, 96, 243, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 128, 217, 180, 110, 63, 221, 223, 65, 200, 129, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 127, 16, 83, 122, 164, 190, 146, 210, 17, 251, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 246, 164, 62, 199, 108, 114, 187, 57, 178, 206, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 156, 113, 114, 200, 63, 120, 82, 65, 246, 18, 254, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 24, 112, 120, 212, 125, 178, 56, 141, 158, 189, 236, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 240, 96, 180, 76, 234, 248, 54, 132, 49, 104, 63, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 96, 201, 11, 255, 38, 185, 37, 42, 239, 17, 122, 248, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 192, 221, 117, 246, 133, 59, 121, 165, 87, 179, 196, 180, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 128, 169, 154, 160, 59, 83, 188, 118, 108, 1, 175, 15, 253, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 159, 10, 70, 84, 64, 91, 163, 60, 14, 214, 156, 226, 255, 255, 
255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 54, 106, 188, 74, 131, 144, 97, 94, 142, 92, 32, 218, 254, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 28, 38, 92, 235, 32, 165, 207, 175, 143, 157, 67, 133, 244, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 24, 125, 153, 49, 73, 115, 28, 222, 156, 39, 164, 52, 141, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 240, 226, 254, 239, 219, 128, 28, 173, 32, 140, 105, 14, 132, 251, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 96, 221, 244, 95, 151, 8, 29, 195, 70, 121, 31, 144, 40, 211, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 192, 165, 144, 191, 233, 85, 34, 159, 195, 188, 58, 161, 149, 63, 254, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 128, 121, 166, 123, 33, 91, 87, 55, 164, 95, 75, 76, 216, 123, 238, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 191, 128, 212, 78, 143, 105, 41, 106, 188, 241, 250, 114, 214, 80, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 118, 7, 77, 20, 153, 31, 158, 37, 92, 113, 205, 125, 96, 40, 249, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 156, 74, 2, 203, 250, 59, 45, 120, 153, 109, 6, 234, 196, 147, 187, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 24, 234, 22, 238, 203, 87, 196, 177, 254, 71, 64, 36, 177, 197, 83, 253, + 255, 255, 255, 255, 255, 255, 
255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 240, 36, 229, 76, 247, 109, 171, 241, 242, 207, 130, 106, 235, 184, 69, + 229, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 96, 113, 243, 0, 169, 75, 178, 112, 125, 31, 28, 41, 50, 57, 185, 244, + 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 192, 109, 130, 9, 154, 244, 246, 102, 230, 58, 25, 155, 245, 59, 60, 143, + 245, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 128, 73, 24, 95, 4, 142, 165, 5, 0, 77, 252, 14, 152, 87, 90, 152, 151, + 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 223, 242, 182, 43, 140, 119, 56, 0, 2, 219, 149, 240, 107, 135, 243, + 235, 251, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 182, 124, 37, 181, 121, 171, 52, 2, 20, 142, 218, 101, 55, 74, 131, + 55, 215, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 28, 223, 118, 19, 193, 178, 14, 22, 200, 140, 137, 250, 41, 230, 32, + 43, 104, 254, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 24, 183, 164, 194, 138, 251, 146, 220, 208, 127, 95, 201, 163, 253, + 72, 175, 17, 240, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 240, 38, 111, 154, 107, 211, 189, 157, 40, 254, 186, 221, 101, 232, + 217, 216, 176, 96, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 96, 133, 87, 8, 52, 66, 106, 41, 150, 237, 77, 169, 250, 19, 131, 120, + 232, 198, 249, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 192, 53, 107, 83, 8, 150, 38, 158, 221, 71, 11, 157, 202, 199, 30, + 181, 20, 197, 193, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 128, 25, 48, 
66, 83, 220, 129, 45, 168, 206, 112, 34, 234, 205, 51, + 19, 207, 178, 145, 253, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 255, 224, 149, 64, 155, 18, 199, 145, 18, 104, 88, 37, 11, 6, 192, + 22, 252, 176, 231, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 246, 201, 218, 133, 16, 186, 198, 177, 185, 16, 116, 117, 111, 60, + 128, 227, 216, 233, 12, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 156, 227, 139, 58, 165, 68, 195, 241, 64, 167, 136, 150, 90, 92, 2, + 227, 120, 34, 129, 246, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 24, 228, 118, 73, 116, 174, 160, 113, 137, 136, 86, 225, 137, 155, + 23, 222, 184, 88, 11, 161, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 240, 232, 164, 222, 138, 208, 70, 112, 94, 85, 97, 205, 98, 19, + 236, 172, 56, 119, 113, 74, 252, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 96, 25, 113, 178, 108, 37, 196, 98, 176, 85, 205, 5, 220, 193, 56, + 193, 54, 168, 110, 232, 218, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 192, 253, 106, 248, 62, 118, 169, 219, 227, 88, 5, 58, 152, 146, + 55, 140, 35, 146, 82, 20, 141, 254, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 128, 233, 45, 180, 117, 158, 158, 148, 230, 120, 53, 68, 242, 185, + 43, 122, 99, 181, 57, 203, 130, 241, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 31, 203, 9, 153, 48, 50, 206, 1, 185, 22, 170, 118, 67, 181, + 197, 226, 21, 65, 240, 27, 111, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 54, 239, 97, 250, 229, 245, 13, 18, 58, 227, 164, 162, 162, 20, + 185, 219, 218, 138, 98, 23, 87, 250, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 28, 88, 211, 199, 251, 154, 139, 180, 68, 224, 112, 90, 90, 206, + 58, 149, 140, 108, 217, 233, 102, 199, 255, + ]), + 
i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 24, 113, 65, 206, 213, 13, 116, 13, 175, 194, 104, 136, 135, 15, + 76, 212, 125, 61, 126, 34, 5, 202, 253, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 240, 106, 142, 14, 90, 138, 136, 134, 214, 154, 23, 84, 75, 155, + 248, 74, 234, 102, 238, 88, 51, 228, 233, + ]), +]; + +/// `MAX_DECIMAL_FOR_EACH_PRECISION[p-1]` holds the maximum `i128` value that can +/// be stored in [arrow_schema::DataType::Decimal128] value of precision `p` +#[allow(dead_code)] // no longer used but is part of our public API +pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ + 9, + 99, + 999, + 9999, + 99999, + 999999, + 9999999, + 99999999, + 999999999, + 9999999999, + 99999999999, + 999999999999, + 9999999999999, + 99999999999999, + 999999999999999, + 9999999999999999, + 99999999999999999, + 999999999999999999, + 9999999999999999999, + 99999999999999999999, + 999999999999999999999, + 9999999999999999999999, + 99999999999999999999999, + 999999999999999999999999, + 9999999999999999999999999, + 99999999999999999999999999, + 999999999999999999999999999, + 9999999999999999999999999999, + 99999999999999999999999999999, + 999999999999999999999999999999, + 9999999999999999999999999999999, + 99999999999999999999999999999999, + 999999999999999999999999999999999, + 9999999999999999999999999999999999, + 99999999999999999999999999999999999, + 999999999999999999999999999999999999, + 9999999999999999999999999999999999999, + 99999999999999999999999999999999999999, +]; + +/// `MIN_DECIMAL_FOR_EACH_PRECISION[p-1]` holds the minimum `i128` value that can +/// be stored in a [arrow_schema::DataType::Decimal128] value of precision `p` +#[allow(dead_code)] // no longer used but is part of our public API +pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ + -9, + -99, + -999, + -9999, + -99999, + -999999, + -9999999, + -99999999, + -999999999, + -9999999999, + -99999999999, + -999999999999, + -9999999999999, + -99999999999999, + 
-999999999999999, + -9999999999999999, + -99999999999999999, + -999999999999999999, + -9999999999999999999, + -99999999999999999999, + -999999999999999999999, + -9999999999999999999999, + -99999999999999999999999, + -999999999999999999999999, + -9999999999999999999999999, + -99999999999999999999999999, + -999999999999999999999999999, + -9999999999999999999999999999, + -99999999999999999999999999999, + -999999999999999999999999999999, + -9999999999999999999999999999999, + -99999999999999999999999999999999, + -999999999999999999999999999999999, + -9999999999999999999999999999999999, + -99999999999999999999999999999999999, + -999999999999999999999999999999999999, + -9999999999999999999999999999999999999, + -99999999999999999999999999999999999999, +]; + +/// `MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[p]` holds the maximum `i128` value that can +/// be stored in [arrow_schema::DataType::Decimal128] value of precision `p`. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. 
+pub(crate) const MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED: [i128; 39] = [ + 0, // unused first element + 9, + 99, + 999, + 9999, + 99999, + 999999, + 9999999, + 99999999, + 999999999, + 9999999999, + 99999999999, + 999999999999, + 9999999999999, + 99999999999999, + 999999999999999, + 9999999999999999, + 99999999999999999, + 999999999999999999, + 9999999999999999999, + 99999999999999999999, + 999999999999999999999, + 9999999999999999999999, + 99999999999999999999999, + 999999999999999999999999, + 9999999999999999999999999, + 99999999999999999999999999, + 999999999999999999999999999, + 9999999999999999999999999999, + 99999999999999999999999999999, + 999999999999999999999999999999, + 9999999999999999999999999999999, + 99999999999999999999999999999999, + 999999999999999999999999999999999, + 9999999999999999999999999999999999, + 99999999999999999999999999999999999, + 999999999999999999999999999999999999, + 9999999999999999999999999999999999999, + 99999999999999999999999999999999999999, +]; + +/// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value that can +/// be stored in a [arrow_schema::DataType::Decimal128] value of precision `p`. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. 
+pub(crate) const MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED: [i128; 39] = [ + 0, // unused first element + -9, + -99, + -999, + -9999, + -99999, + -999999, + -9999999, + -99999999, + -999999999, + -9999999999, + -99999999999, + -999999999999, + -9999999999999, + -99999999999999, + -999999999999999, + -9999999999999999, + -99999999999999999, + -999999999999999999, + -9999999999999999999, + -99999999999999999999, + -999999999999999999999, + -9999999999999999999999, + -99999999999999999999999, + -999999999999999999999999, + -9999999999999999999999999, + -99999999999999999999999999, + -999999999999999999999999999, + -9999999999999999999999999999, + -99999999999999999999999999999, + -999999999999999999999999999999, + -9999999999999999999999999999999, + -99999999999999999999999999999999, + -999999999999999999999999999999999, + -9999999999999999999999999999999999, + -99999999999999999999999999999999999, + -999999999999999999999999999999999999, + -9999999999999999999999999999999999999, + -99999999999999999999999999999999999999, +]; + +/// Validates that the specified `i128` value can be properly +/// interpreted as a Decimal number with precision `precision` +#[inline] +pub fn validate_decimal_precision(value: i128, precision: u8) -> Result<(), ArrowError> { + if precision > DECIMAL128_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal128 is {DECIMAL128_MAX_PRECISION}, but got {precision}", + ))); + } + if value > MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] { + Err(ArrowError::InvalidArgumentError(format!( + "{value} is too large to store in a Decimal128 of precision {precision}. Max is {}", + MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] + ))) + } else if value < MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] { + Err(ArrowError::InvalidArgumentError(format!( + "{value} is too small to store in a Decimal128 of precision {precision}. 
Min is {}", + MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] + ))) + } else { + Ok(()) + } +} + +/// Determines whether the specified `i128` value can be properly +/// interpreted as a Decimal number with precision `precision` +#[inline] +pub fn is_validate_decimal_precision(value: i128, precision: u8) -> bool { + precision <= DECIMAL128_MAX_PRECISION + && value >= MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] + && value <= MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] +} + +/// Validates that the specified `i256` of value can be properly +/// interpreted as a Decimal256 number with precision `precision` +#[inline] +pub fn validate_decimal256_precision(value: i256, precision: u8) -> Result<(), ArrowError> { + if precision > DECIMAL256_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal256 is {DECIMAL256_MAX_PRECISION}, but got {precision}", + ))); + } + if value > MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] { + Err(ArrowError::InvalidArgumentError(format!( + "{value:?} is too large to store in a Decimal256 of precision {precision}. Max is {:?}", + MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] + ))) + } else if value < MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] { + Err(ArrowError::InvalidArgumentError(format!( + "{value:?} is too small to store in a Decimal256 of precision {precision}. 
Min is {:?}", + MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] + ))) + } else { + Ok(()) + } +} + +/// Determines whether the specified `i256` value can be properly +/// interpreted as a Decimal256 number with precision `precision` +#[inline] +pub fn is_validate_decimal256_precision(value: i256, precision: u8) -> bool { + precision <= DECIMAL256_MAX_PRECISION + && value >= MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] + && value <= MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] +} diff --git a/arrow/src/array/equal/boolean.rs b/arrow-data/src/equal/boolean.rs similarity index 65% rename from arrow/src/array/equal/boolean.rs rename to arrow-data/src/equal/boolean.rs index fddf21b963ad..addae936f118 100644 --- a/arrow/src/array/equal/boolean.rs +++ b/arrow-data/src/equal/boolean.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::contains_nulls, ArrayData}; -use crate::util::bit_iterator::BitIndexIterator; -use crate::util::bit_util::get_bit; +use crate::bit_iterator::BitIndexIterator; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::bit_util::get_bit; use super::utils::{equal_bits, equal_len}; @@ -33,7 +33,7 @@ pub(super) fn boolean_equal( let lhs_values = lhs.buffers()[0].as_slice(); let rhs_values = rhs.buffers()[0].as_slice(); - let contains_nulls = contains_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let contains_nulls = contains_nulls(lhs.nulls(), lhs_start, len); if !contains_nulls { // Optimize performance for starting offset at u8 boundary. 
@@ -76,42 +76,12 @@ pub(super) fn boolean_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); - let lhs_start = lhs.offset() + lhs_start; - let rhs_start = rhs.offset() + rhs_start; - - BitIndexIterator::new(lhs_null_bytes, lhs_start, len).all(|i| { - let lhs_pos = lhs_start + i; - let rhs_pos = rhs_start + i; + BitIndexIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len).all(|i| { + let lhs_pos = lhs_start + lhs.offset() + i; + let rhs_pos = rhs_start + rhs.offset() + i; get_bit(lhs_values, lhs_pos) == get_bit(rhs_values, rhs_pos) }) } } - -#[cfg(test)] -mod tests { - use crate::array::{Array, BooleanArray}; - - #[test] - fn test_boolean_slice() { - let array = BooleanArray::from(vec![true; 32]); - let slice = array.slice(4, 12); - assert_eq!(slice.data(), slice.data()); - - let slice = array.slice(8, 12); - assert_eq!(slice.data(), slice.data()); - - let slice = array.slice(8, 24); - assert_eq!(slice.data(), slice.data()); - } - - #[test] - fn test_sliced_nullable_boolean_array() { - let a = BooleanArray::from(vec![None; 32]); - let b = BooleanArray::from(vec![true; 32]); - let slice_a = a.slice(1, 12); - let slice_b = b.slice(1, 12); - assert_ne!(slice_a.data(), slice_b.data()); - } -} diff --git a/arrow-data/src/equal/byte_view.rs b/arrow-data/src/equal/byte_view.rs new file mode 100644 index 000000000000..def395125366 --- /dev/null +++ b/arrow-data/src/equal/byte_view.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ArrayData, ByteView}; + +pub(super) fn byte_view_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_views = &lhs.buffer::(0)[lhs_start..lhs_start + len]; + let lhs_buffers = &lhs.buffers()[1..]; + let rhs_views = &rhs.buffer::(0)[rhs_start..rhs_start + len]; + let rhs_buffers = &rhs.buffers()[1..]; + + for (idx, (l, r)) in lhs_views.iter().zip(rhs_views).enumerate() { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. 
+ if lhs.is_null(idx) { + continue; + } + + let l_len_prefix = *l as u64; + let r_len_prefix = *r as u64; + // short-circuit, check length and prefix + if l_len_prefix != r_len_prefix { + return false; + } + + let len = l_len_prefix as u32; + // for inline storage, only need check view + if len <= 12 { + if l != r { + return false; + } + continue; + } + + // check buffers + let l_view = ByteView::from(*l); + let r_view = ByteView::from(*r); + + let l_buffer = &lhs_buffers[l_view.buffer_index as usize]; + let r_buffer = &rhs_buffers[r_view.buffer_index as usize]; + + // prefixes are already known to be equal; skip checking them + let len = len as usize - 4; + let l_offset = l_view.offset as usize + 4; + let r_offset = r_view.offset as usize + 4; + if l_buffer[l_offset..l_offset + len] != r_buffer[r_offset..r_offset + len] { + return false; + } + } + true +} + +#[cfg(test)] +mod tests {} diff --git a/arrow/src/array/equal/dictionary.rs b/arrow-data/src/equal/dictionary.rs similarity index 75% rename from arrow/src/array/equal/dictionary.rs rename to arrow-data/src/equal/dictionary.rs index 4c9bcf798760..1d9c4b8d964f 100644 --- a/arrow/src/array/equal/dictionary.rs +++ b/arrow-data/src/equal/dictionary.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::array::{data::count_nulls, ArrayData}; -use crate::datatypes::ArrowNativeType; -use crate::util::bit_util::get_bit; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::ArrowNativeType; use super::equal_range; @@ -34,10 +33,9 @@ pub(super) fn dictionary_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 && rhs_null_count == 0 { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; @@ -52,14 +50,14 @@ pub(super) fn dictionary_equal( }) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); lhs_is_null || (lhs_is_null == rhs_is_null) diff --git a/arrow-data/src/equal/fixed_binary.rs b/arrow-data/src/equal/fixed_binary.rs new file mode 100644 index 000000000000..0778d77e2fdd --- /dev/null +++ b/arrow-data/src/equal/fixed_binary.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bit_iterator::BitSliceIterator; +use crate::contains_nulls; +use crate::data::ArrayData; +use crate::equal::primitive::NULL_SLICES_SELECTIVITY_THRESHOLD; +use arrow_schema::DataType; + +use super::utils::equal_len; + +pub(super) fn fixed_binary_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let size = match lhs.data_type() { + DataType::FixedSizeBinary(i) => *i as usize, + _ => unreachable!(), + }; + + let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..]; + let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..]; + + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. 
+ if !contains_nulls(lhs.nulls(), lhs_start, len) { + equal_len( + lhs_values, + rhs_values, + size * lhs_start, + size * rhs_start, + size * len, + ) + } else { + let selectivity_frac = lhs.null_count() as f64 / lhs.len() as f64; + + if selectivity_frac >= NULL_SLICES_SELECTIVITY_THRESHOLD { + // get a ref of the null buffer bytes, to use in testing for nullness + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_len( + lhs_values, + rhs_values, + lhs_pos * size, + rhs_pos * size, + size, // 1 * size since we are comparing a single entry + ) + }) + } else { + let lhs_nulls = lhs.nulls().unwrap(); + let lhs_slices_iter = + BitSliceIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len); + let rhs_nulls = rhs.nulls().unwrap(); + let rhs_slices_iter = + BitSliceIterator::new(rhs_nulls.validity(), rhs_start + rhs_nulls.offset(), len); + + lhs_slices_iter + .zip(rhs_slices_iter) + .all(|((l_start, l_end), (r_start, r_end))| { + l_start == r_start + && l_end == r_end + && equal_len( + lhs_values, + rhs_values, + (lhs_start + l_start) * size, + (rhs_start + r_start) * size, + (l_end - l_start) * size, + ) + }) + } + } +} diff --git a/arrow/src/array/equal/fixed_list.rs b/arrow-data/src/equal/fixed_list.rs similarity index 75% rename from arrow/src/array/equal/fixed_list.rs rename to arrow-data/src/equal/fixed_list.rs index 82a347c86574..4b79e5c33fab 100644 --- a/arrow/src/array/equal/fixed_list.rs +++ b/arrow-data/src/equal/fixed_list.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::array::{data::count_nulls, ArrayData}; -use crate::datatypes::DataType; -use crate::util::bit_util::get_bit; +use crate::data::{contains_nulls, ArrayData}; +use arrow_schema::DataType; use super::equal_range; @@ -36,10 +35,9 @@ pub(super) fn fixed_list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 && rhs_null_count == 0 { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { equal_range( lhs_values, rhs_values, @@ -49,15 +47,15 @@ pub(super) fn fixed_list_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); lhs_is_null || (lhs_is_null == rhs_is_null) diff --git a/arrow/src/array/equal/list.rs b/arrow-data/src/equal/list.rs similarity index 69% rename from arrow/src/array/equal/list.rs rename to arrow-data/src/equal/list.rs index b3bca9a69228..cc4ba3cacf9f 100644 --- a/arrow/src/array/equal/list.rs +++ b/arrow-data/src/equal/list.rs @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{ - array::ArrayData, - array::{data::count_nulls, OffsetSizeTrait}, - util::bit_util::get_bit, -}; +use crate::data::{count_nulls, ArrayData}; +use arrow_buffer::ArrowNativeType; +use num::Integer; use super::equal_range; -fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { +fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { // invariant from `base_equal` debug_assert_eq!(lhs.len(), rhs.len()); @@ -45,7 +43,7 @@ fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { }) } -pub(super) fn list_equal( +pub(super) fn list_equal( lhs: &ArrayData, rhs: &ArrayData, lhs_start: usize, @@ -91,8 +89,8 @@ pub(super) fn list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + let lhs_null_count = count_nulls(lhs.nulls(), lhs_start, len); + let rhs_null_count = count_nulls(rhs.nulls(), rhs_start, len); if lhs_null_count != rhs_null_count { return false; @@ -113,8 +111,8 @@ pub(super) fn list_equal( ) } else { // get a ref of the parent null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); // with nulls, we need to compare item by item whenever it is not null // TODO: Could potentially compare runs of not NULL values @@ -122,8 +120,8 @@ pub(super) fn list_equal( let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); if lhs_is_null != rhs_is_null { return false; @@ -149,52 +147,3 @@ pub(super) fn list_equal( }) } } - -#[cfg(test)] -mod tests { - 
use crate::{ - array::{Array, Int64Builder, ListArray, ListBuilder}, - datatypes::Int32Type, - }; - - #[test] - fn list_array_non_zero_nulls() { - // Tests handling of list arrays with non-empty null ranges - let mut builder = ListBuilder::new(Int64Builder::with_capacity(10)); - builder.values().append_value(1); - builder.values().append_value(2); - builder.values().append_value(3); - builder.append(true); - builder.append(false); - let array1 = builder.finish(); - - let mut builder = ListBuilder::new(Int64Builder::with_capacity(10)); - builder.values().append_value(1); - builder.values().append_value(2); - builder.values().append_value(3); - builder.append(true); - builder.values().append_null(); - builder.values().append_null(); - builder.append(false); - let array2 = builder.finish(); - - assert_eq!(array1, array2); - } - - #[test] - fn test_list_different_offsets() { - let a = ListArray::from_iter_primitive::([ - Some([Some(0), Some(0)]), - Some([Some(1), Some(2)]), - Some([None, None]), - ]); - let b = ListArray::from_iter_primitive::([ - Some([Some(1), Some(2)]), - Some([None, None]), - Some([None, None]), - ]); - let a_slice = a.slice(1, 2); - let b_slice = b.slice(0, 2); - assert_eq!(&a_slice, &b_slice); - } -} diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs new file mode 100644 index 000000000000..f24179b61700 --- /dev/null +++ b/arrow-data/src/equal/mod.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing functionality to compute array equality. +//! This module uses [ArrayData] and does not +//! depend on dynamic casting of `Array`. + +use crate::data::ArrayData; +use arrow_buffer::i256; +use arrow_schema::{DataType, IntervalUnit}; +use half::f16; + +mod boolean; +mod byte_view; +mod dictionary; +mod fixed_binary; +mod fixed_list; +mod list; +mod null; +mod primitive; +mod run; +mod structure; +mod union; +mod utils; +mod variable_size; + +// these methods assume the same type, len and null count. +// For this reason, they are not exposed and are instead used +// to build the generic functions below (`equal_range` and `equal`). +use boolean::boolean_equal; +use byte_view::byte_view_equal; +use dictionary::dictionary_equal; +use fixed_binary::fixed_binary_equal; +use fixed_list::fixed_list_equal; +use list::list_equal; +use null::null_equal; +use primitive::primitive_equal; +use structure::struct_equal; +use union::union_equal; +use variable_size::variable_sized_equal; + +use self::run::run_equal; + +/// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively +/// for `len` slots. 
+#[inline] +fn equal_values( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + match lhs.data_type() { + DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal128(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal256(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Date64 + | DataType::Interval(IntervalUnit::DayTime) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Interval(IntervalUnit::MonthDayNano) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Utf8 | DataType::Binary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::FixedSizeBinary(_) => 
fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::BinaryView | DataType::Utf8View => { + byte_view_equal(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::ListView(_) | DataType::LargeListView(_) => { + unimplemented!("ListView/LargeListView not yet implemented") + } + DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Dictionary(data_type, _) => match data_type.as_ref() { + DataType::Int8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + _ => unreachable!(), + }, + DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Map(_, _) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len), + } +} + +fn equal_range( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len) + && equal_values(lhs, rhs, lhs_start, rhs_start, len) +} + +/// Logically compares two [ArrayData]. 
+/// +/// Two arrays are logically equal if and only if: +/// * their data types are equal +/// * their lengths are equal +/// * their null counts are equal +/// * their null bitmaps are equal +/// * each of their items are equal +/// +/// Two items are equal when their in-memory representation is physically equal +/// (i.e. has the same bit content). +/// +/// The physical comparison depend on the data type. +/// +/// # Panics +/// +/// This function may panic whenever any of the [ArrayData] does not follow the +/// Arrow specification. (e.g. wrong number of buffers, buffer `len` does not +/// correspond to the declared `len`) +pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { + utils::base_equal(lhs, rhs) + && lhs.null_count() == rhs.null_count() + && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len()) + && equal_values(lhs, rhs, 0, 0, lhs.len()) +} + +// See arrow/tests/array_equal.rs for tests diff --git a/arrow/src/array/equal/null.rs b/arrow-data/src/equal/null.rs similarity index 97% rename from arrow/src/array/equal/null.rs rename to arrow-data/src/equal/null.rs index f287a382507a..1478e448cec2 100644 --- a/arrow/src/array/equal/null.rs +++ b/arrow-data/src/equal/null.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::array::ArrayData; +use crate::data::ArrayData; #[inline] pub(super) fn null_equal( diff --git a/arrow-data/src/equal/primitive.rs b/arrow-data/src/equal/primitive.rs new file mode 100644 index 000000000000..e92fdd2ba23b --- /dev/null +++ b/arrow-data/src/equal/primitive.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bit_iterator::BitSliceIterator; +use crate::contains_nulls; +use std::mem::size_of; + +use crate::data::ArrayData; + +use super::utils::equal_len; + +pub(crate) const NULL_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.4; + +pub(super) fn primitive_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let byte_width = size_of::(); + let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * byte_width..]; + let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * byte_width..]; + + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. 
+ if !contains_nulls(lhs.nulls(), lhs_start, len) { + // without nulls, we just need to compare slices + equal_len( + lhs_values, + rhs_values, + lhs_start * byte_width, + rhs_start * byte_width, + len * byte_width, + ) + } else { + let selectivity_frac = lhs.null_count() as f64 / lhs.len() as f64; + + if selectivity_frac >= NULL_SLICES_SELECTIVITY_THRESHOLD { + // get a ref of the null buffer bytes, to use in testing for nullness + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_len( + lhs_values, + rhs_values, + lhs_pos * byte_width, + rhs_pos * byte_width, + byte_width, // 1 * byte_width since we are comparing a single entry + ) + }) + } else { + let lhs_nulls = lhs.nulls().unwrap(); + let lhs_slices_iter = + BitSliceIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len); + let rhs_nulls = rhs.nulls().unwrap(); + let rhs_slices_iter = + BitSliceIterator::new(rhs_nulls.validity(), rhs_start + rhs_nulls.offset(), len); + + lhs_slices_iter + .zip(rhs_slices_iter) + .all(|((l_start, l_end), (r_start, r_end))| { + l_start == r_start + && l_end == r_end + && equal_len( + lhs_values, + rhs_values, + (lhs_start + l_start) * byte_width, + (rhs_start + r_start) * byte_width, + (l_end - l_start) * byte_width, + ) + }) + } + } +} diff --git a/arrow-data/src/equal/run.rs b/arrow-data/src/equal/run.rs new file mode 100644 index 000000000000..6c9393ecd8d3 --- /dev/null +++ b/arrow-data/src/equal/run.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::data::ArrayData; + +use super::equal_range; + +/// The current implementation of comparison of run array support physical comparison. +/// Comparing run encoded array based on logical indices (`lhs_start`, `rhs_start`) will +/// be time consuming as converting from logical index to physical index cannot be done +/// in constant time. The current comparison compares the underlying physical arrays. +pub(super) fn run_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + if lhs_start != 0 + || rhs_start != 0 + || (lhs.len() != len && rhs.len() != len) + || lhs.offset() > 0 + || rhs.offset() > 0 + { + unimplemented!("Logical comparison for run array not supported.") + } + + if lhs.len() != rhs.len() { + return false; + } + + let lhs_child_data = lhs.child_data(); + let lhs_run_ends_array = &lhs_child_data[0]; + let lhs_values_array = &lhs_child_data[1]; + + let rhs_child_data = rhs.child_data(); + let rhs_run_ends_array = &rhs_child_data[0]; + let rhs_values_array = &rhs_child_data[1]; + + if lhs_run_ends_array.len() != rhs_run_ends_array.len() { + return false; + } + + if lhs_values_array.len() != rhs_values_array.len() { + return false; + } + + // check run ends array are equal. 
The length of the physical array + // is used to validate the child arrays. + let run_ends_equal = equal_range( + lhs_run_ends_array, + rhs_run_ends_array, + lhs_start, + rhs_start, + lhs_run_ends_array.len(), + ); + + // if run ends array are not the same return early without validating + // values array. + if !run_ends_equal { + return false; + } + + // check values array are equal + equal_range( + lhs_values_array, + rhs_values_array, + lhs_start, + rhs_start, + rhs_values_array.len(), + ) +} diff --git a/arrow/src/array/equal/structure.rs b/arrow-data/src/equal/structure.rs similarity index 74% rename from arrow/src/array/equal/structure.rs rename to arrow-data/src/equal/structure.rs index 0f943e40cac6..e4751c26f489 100644 --- a/arrow/src/array/equal/structure.rs +++ b/arrow-data/src/equal/structure.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{array::data::count_nulls, array::ArrayData, util::bit_util::get_bit}; +use crate::data::{contains_nulls, ArrayData}; use super::equal_range; @@ -43,23 +43,21 @@ pub(super) fn struct_equal( rhs_start: usize, len: usize, ) -> bool { - // we have to recalculate null counts from the null buffers - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 && rhs_null_count == 0 { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. 
+ if !contains_nulls(lhs.nulls(), lhs_start, len) { equal_child_values(lhs, rhs, lhs_start, rhs_start, len) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; // if both struct and child had no null buffers, - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); if lhs_is_null != rhs_is_null { return false; diff --git a/arrow/src/array/equal/union.rs b/arrow-data/src/equal/union.rs similarity index 79% rename from arrow/src/array/equal/union.rs rename to arrow-data/src/equal/union.rs index e8b9d27b6f0f..62de276e507f 100644 --- a/arrow/src/array/equal/union.rs +++ b/arrow-data/src/equal/union.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{array::ArrayData, datatypes::DataType, datatypes::UnionMode}; +use crate::data::ArrayData; +use arrow_schema::{DataType, UnionFields, UnionMode}; use super::equal_range; @@ -27,8 +28,8 @@ fn equal_dense( rhs_type_ids: &[i8], lhs_offsets: &[i32], rhs_offsets: &[i32], - lhs_field_type_ids: &[i8], - rhs_field_type_ids: &[i8], + lhs_fields: &UnionFields, + rhs_fields: &UnionFields, ) -> bool { let offsets = lhs_offsets.iter().zip(rhs_offsets.iter()); @@ -37,13 +38,13 @@ fn equal_dense( .zip(rhs_type_ids.iter()) .zip(offsets) .all(|((l_type_id, r_type_id), (l_offset, r_offset))| { - let lhs_child_index = lhs_field_type_ids + let lhs_child_index = lhs_fields .iter() - .position(|r| r == l_type_id) + .position(|(r, _)| r == *l_type_id) .unwrap(); - let rhs_child_index = rhs_field_type_ids + let rhs_child_index = rhs_fields .iter() - .position(|r| r == r_type_id) + .position(|(r, _)| r == *r_type_id) .unwrap(); let lhs_values = &lhs.child_data()[lhs_child_index]; let rhs_values = &rhs.child_data()[rhs_child_index]; @@ -69,7 +70,13 @@ fn equal_sparse( .iter() .zip(rhs.child_data()) .all(|(lhs_values, rhs_values)| { - equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) + equal_range( + lhs_values, + rhs_values, + lhs_start + lhs.offset(), + rhs_start + rhs.offset(), + len, + ) }) } @@ -88,8 +95,8 @@ pub(super) fn union_equal( match (lhs.data_type(), rhs.data_type()) { ( - DataType::Union(_, lhs_type_ids, UnionMode::Dense), - DataType::Union(_, rhs_type_ids, UnionMode::Dense), + DataType::Union(lhs_fields, UnionMode::Dense), + DataType::Union(rhs_fields, UnionMode::Dense), ) => { let lhs_offsets = lhs.buffer::(1); let rhs_offsets = rhs.buffer::(1); @@ -105,14 +112,11 @@ pub(super) fn union_equal( rhs_type_id_range, lhs_offsets_range, rhs_offsets_range, - lhs_type_ids, - rhs_type_ids, + lhs_fields, + rhs_fields, ) } - ( - DataType::Union(_, _, UnionMode::Sparse), - DataType::Union(_, _, UnionMode::Sparse), - ) => { + (DataType::Union(_, 
UnionMode::Sparse), DataType::Union(_, UnionMode::Sparse)) => { lhs_type_id_range == rhs_type_id_range && equal_sparse(lhs, rhs, lhs_start, rhs_start, len) } diff --git a/arrow/src/array/equal/utils.rs b/arrow-data/src/equal/utils.rs similarity index 69% rename from arrow/src/array/equal/utils.rs rename to arrow-data/src/equal/utils.rs index 449055d366ec..f1f4be44730e 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow-data/src/equal/utils.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::array::data::contains_nulls; -use crate::array::ArrayData; -use crate::datatypes::DataType; -use crate::util::bit_chunk_iterator::BitChunks; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::bit_chunk_iterator::BitChunks; +use arrow_schema::DataType; // whether bits along the positions are equal // `lhs_start`, `rhs_start` and `len` are _measured in bits_. @@ -30,16 +29,9 @@ pub(super) fn equal_bits( rhs_start: usize, len: usize, ) -> bool { - let lhs = BitChunks::new(lhs_values, lhs_start, len); - let rhs = BitChunks::new(rhs_values, rhs_start, len); - - for (a, b) in lhs.iter().zip(rhs.iter()) { - if a != b { - return false; - } - } - - lhs.remainder_bits() == rhs.remainder_bits() + let lhs = BitChunks::new(lhs_values, lhs_start, len).iter_padded(); + let rhs = BitChunks::new(rhs_values, rhs_start, len).iter_padded(); + lhs.zip(rhs).all(|(a, b)| a == b) } #[inline] @@ -50,15 +42,16 @@ pub(super) fn equal_nulls( rhs_start: usize, len: usize, ) -> bool { - let lhs_offset = lhs_start + lhs.offset(); - let rhs_offset = rhs_start + rhs.offset(); - - match (lhs.null_buffer(), rhs.null_buffer()) { - (Some(lhs), Some(rhs)) => { - equal_bits(lhs.as_slice(), rhs.as_slice(), lhs_offset, rhs_offset, len) - } - (Some(lhs), None) => !contains_nulls(Some(lhs), lhs_offset, len), - (None, Some(rhs)) => !contains_nulls(Some(rhs), rhs_offset, len), + match (lhs.nulls(), rhs.nulls()) { + (Some(lhs), Some(rhs)) => 
equal_bits( + lhs.validity(), + rhs.validity(), + lhs.offset() + lhs_start, + rhs.offset() + rhs_start, + len, + ), + (Some(lhs), None) => !contains_nulls(Some(lhs), lhs_start, len), + (None, Some(rhs)) => !contains_nulls(Some(rhs), rhs_start, len), (None, None) => true, } } @@ -66,7 +59,7 @@ pub(super) fn equal_nulls( #[inline] pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { let equal_type = match (lhs.data_type(), rhs.data_type()) { - (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _, r_mode)) => { + (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) => { l_fields == r_fields && l_mode == r_mode } (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => { @@ -74,17 +67,15 @@ pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { (DataType::Struct(l_fields), DataType::Struct(r_fields)) if l_fields.len() == 2 && r_fields.len() == 2 => { - let l_key_field = l_fields.get(0).unwrap(); - let r_key_field = r_fields.get(0).unwrap(); - let l_value_field = l_fields.get(1).unwrap(); - let r_value_field = r_fields.get(1).unwrap(); + let l_key_field = &l_fields[0]; + let r_key_field = &r_fields[0]; + let l_value_field = &l_fields[1]; + let r_value_field = &r_fields[1]; // We don't enforce the equality of field names - let data_type_equal = l_key_field.data_type() - == r_key_field.data_type() + let data_type_equal = l_key_field.data_type() == r_key_field.data_type() && l_value_field.data_type() == r_value_field.data_type(); - let nullability_equal = l_key_field.is_nullable() - == r_key_field.is_nullable() + let nullability_equal = l_key_field.is_nullable() == r_key_field.is_nullable() && l_value_field.is_nullable() == r_value_field.is_nullable(); let metadata_equal = l_key_field.metadata() == r_key_field.metadata() && l_value_field.metadata() == r_value_field.metadata(); diff --git a/arrow/src/array/equal/variable_size.rs b/arrow-data/src/equal/variable_size.rs similarity index 62% rename 
from arrow/src/array/equal/variable_size.rs rename to arrow-data/src/equal/variable_size.rs index f40f79e404ac..d6e8e6a95481 100644 --- a/arrow/src/array/equal/variable_size.rs +++ b/arrow-data/src/equal/variable_size.rs @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::util::bit_util::get_bit; -use crate::{ - array::data::count_nulls, - array::{ArrayData, OffsetSizeTrait}, -}; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::ArrowNativeType; +use num::Integer; use super::utils::equal_len; -fn offset_value_equal( +fn offset_value_equal( lhs_values: &[u8], rhs_values: &[u8], lhs_offsets: &[T], @@ -32,22 +30,23 @@ fn offset_value_equal( rhs_pos: usize, len: usize, ) -> bool { - let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); - let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); - let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; - let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; + let lhs_start = lhs_offsets[lhs_pos].as_usize(); + let rhs_start = rhs_offsets[rhs_pos].as_usize(); + let lhs_len = (lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]) + .to_usize() + .unwrap(); + let rhs_len = (rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]) + .to_usize() + .unwrap(); - lhs_len == rhs_len - && equal_len( - lhs_values, - rhs_values, - lhs_start, - rhs_start, - lhs_len.to_usize().unwrap(), - ) + if lhs_len == 0 && rhs_len == 0 { + return true; + } + + lhs_len == rhs_len && equal_len(lhs_values, rhs_values, lhs_start, rhs_start, lhs_len) } -pub(super) fn variable_sized_equal( +pub(super) fn variable_sized_equal( lhs: &ArrayData, rhs: &ArrayData, lhs_start: usize, @@ -61,14 +60,9 @@ pub(super) fn variable_sized_equal( let lhs_values = lhs.buffers()[1].as_slice(); let rhs_values = rhs.buffers()[1].as_slice(); - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), 
rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 - && rhs_null_count == 0 - && !lhs_values.is_empty() - && !rhs_values.is_empty() - { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { offset_value_equal( lhs_values, rhs_values, @@ -84,15 +78,8 @@ pub(super) fn variable_sized_equal( let rhs_pos = rhs_start + i; // the null bits can still be `None`, indicating that the value is valid. - let lhs_is_null = !lhs - .null_buffer() - .map(|v| get_bit(v.as_slice(), lhs.offset() + lhs_pos)) - .unwrap_or(true); - - let rhs_is_null = !rhs - .null_buffer() - .map(|v| get_bit(v.as_slice(), rhs.offset() + rhs_pos)) - .unwrap_or(true); + let lhs_is_null = lhs.nulls().map(|v| v.is_null(lhs_pos)).unwrap_or_default(); + let rhs_is_null = rhs.nulls().map(|v| v.is_null(rhs_pos)).unwrap_or_default(); lhs_is_null || (lhs_is_null == rhs_is_null) diff --git a/arrow-data/src/ffi.rs b/arrow-data/src/ffi.rs new file mode 100644 index 000000000000..cd283d32662f --- /dev/null +++ b/arrow-data/src/ffi.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). + +use crate::bit_mask::set_bits; +use crate::{layout, ArrayData}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{Buffer, MutableBuffer, ScalarBuffer}; +use arrow_schema::DataType; +use std::ffi::c_void; + +/// ABI-compatible struct for ArrowArray from C Data Interface +/// See +/// +/// ``` +/// # use arrow_data::ArrayData; +/// # use arrow_data::ffi::FFI_ArrowArray; +/// fn export_array(array: &ArrayData) -> FFI_ArrowArray { +/// FFI_ArrowArray::new(array) +/// } +/// ``` +#[repr(C)] +#[derive(Debug)] +pub struct FFI_ArrowArray { + length: i64, + null_count: i64, + offset: i64, + n_buffers: i64, + n_children: i64, + buffers: *mut *const c_void, + children: *mut *mut FFI_ArrowArray, + dictionary: *mut FFI_ArrowArray, + release: Option, + // When exported, this MUST contain everything that is owned by this array. + // for example, any buffer pointed to in `buffers` must be here, as well + // as the `buffers` pointer itself. + // In other words, everything in [FFI_ArrowArray] must be owned by + // `private_data` and can assume that they do not outlive `private_data`. 
+ private_data: *mut c_void, +} + +impl Drop for FFI_ArrowArray { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +unsafe impl Send for FFI_ArrowArray {} +unsafe impl Sync for FFI_ArrowArray {} + +// callback used to drop [FFI_ArrowArray] when it is exported +unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it` + let private = Box::from_raw(array.private_data as *mut ArrayPrivateData); + for child in private.children.iter() { + let _ = Box::from_raw(*child); + } + if !private.dictionary.is_null() { + let _ = Box::from_raw(private.dictionary); + } + + array.release = None; +} + +/// Aligns the provided `nulls` to the provided `data_offset` +/// +/// This is a temporary measure until offset is removed from ArrayData (#1799) +fn align_nulls(data_offset: usize, nulls: Option<&NullBuffer>) -> Option { + let nulls = nulls?; + if data_offset == nulls.offset() { + // Underlying buffer is already aligned + return Some(nulls.buffer().clone()); + } + if data_offset == 0 { + return Some(nulls.inner().sliced()); + } + let mut builder = MutableBuffer::new_null(data_offset + nulls.len()); + set_bits( + builder.as_slice_mut(), + nulls.validity(), + data_offset, + nulls.offset(), + nulls.len(), + ); + Some(builder.into()) +} + +struct ArrayPrivateData { + #[allow(dead_code)] + buffers: Vec>, + buffers_ptr: Box<[*const c_void]>, + children: Box<[*mut FFI_ArrowArray]>, + dictionary: *mut FFI_ArrowArray, +} + +impl FFI_ArrowArray { + /// creates a new `FFI_ArrowArray` from existing data. + pub fn new(data: &ArrayData) -> Self { + let data_layout = layout(data.data_type()); + + let mut buffers = if data_layout.can_contain_null_mask { + // * insert the null buffer at the start + // * make all others `Option`. 
+ std::iter::once(align_nulls(data.offset(), data.nulls())) + .chain(data.buffers().iter().map(|b| Some(b.clone()))) + .collect::>() + } else { + data.buffers().iter().map(|b| Some(b.clone())).collect() + }; + + // `n_buffers` is the number of buffers by the spec. + let mut n_buffers = { + data_layout.buffers.len() + { + // If the layout has a null buffer by Arrow spec. + // Note that even the array doesn't have a null buffer because it has + // no null value, we still need to count 1 here to follow the spec. + usize::from(data_layout.can_contain_null_mask) + } + } as i64; + + if data_layout.variadic { + // Save the lengths of all variadic buffers into a new buffer. + // The first buffer is `views`, and the rest are variadic. + let mut data_buffers_lengths = Vec::new(); + for buffer in data.buffers().iter().skip(1) { + data_buffers_lengths.push(buffer.len() as i64); + n_buffers += 1; + } + + buffers.push(Some(ScalarBuffer::from(data_buffers_lengths).into_inner())); + n_buffers += 1; + } + + let buffers_ptr = buffers + .iter() + .flat_map(|maybe_buffer| match maybe_buffer { + Some(b) => Some(b.as_ptr() as *const c_void), + // This is for null buffer. We only put a null pointer for + // null buffer if by spec it can contain null mask. 
+ None if data_layout.can_contain_null_mask => Some(std::ptr::null()), + None => None, + }) + .collect::>(); + + let empty = vec![]; + let (child_data, dictionary) = match data.data_type() { + DataType::Dictionary(_, _) => ( + empty.as_slice(), + Box::into_raw(Box::new(FFI_ArrowArray::new(&data.child_data()[0]))), + ), + _ => (data.child_data(), std::ptr::null_mut()), + }; + + let children = child_data + .iter() + .map(|child| Box::into_raw(Box::new(FFI_ArrowArray::new(child)))) + .collect::>(); + let n_children = children.len() as i64; + + // As in the IPC format, emit null_count = length for Null type + let null_count = match data.data_type() { + DataType::Null => data.len(), + _ => data.null_count(), + }; + + // create the private data owning everything. + // any other data must be added here, e.g. via a struct, to track lifetime. + let mut private_data = Box::new(ArrayPrivateData { + buffers, + buffers_ptr, + children, + dictionary, + }); + + Self { + length: data.len() as i64, + null_count: null_count as i64, + offset: data.offset() as i64, + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children.as_mut_ptr(), + dictionary, + release: Some(release_array), + private_data: Box::into_raw(private_data) as *mut c_void, + } + } + + /// Takes ownership of the pointed to [`FFI_ArrowArray`] + /// + /// This acts to [move] the data out of `array`, setting the release callback to NULL + /// + /// # Safety + /// + /// * `array` must be [valid] for reads and writes + /// * `array` must be properly aligned + /// * `array` must point to a properly initialized value of [`FFI_ArrowArray`] + /// + /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array + /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety + pub unsafe fn from_raw(array: *mut FFI_ArrowArray) -> Self { + std::ptr::replace(array, Self::empty()) + } + + /// create an empty `FFI_ArrowArray`, which can be used to import 
data into + pub fn empty() -> Self { + Self { + length: 0, + null_count: 0, + offset: 0, + n_buffers: 0, + n_children: 0, + buffers: std::ptr::null_mut(), + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// the length of the array + #[inline] + pub fn len(&self) -> usize { + self.length as usize + } + + /// whether the array is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Whether the array has been released + #[inline] + pub fn is_released(&self) -> bool { + self.release.is_none() + } + + /// the offset of the array + #[inline] + pub fn offset(&self) -> usize { + self.offset as usize + } + + /// the null count of the array + #[inline] + pub fn null_count(&self) -> usize { + self.null_count as usize + } + + /// Returns the buffer at the provided index + /// + /// # Panic + /// Panics if index exceeds the number of buffers or the buffer is not correctly aligned + #[inline] + pub fn buffer(&self, index: usize) -> *const u8 { + assert!(!self.buffers.is_null()); + assert!(index < self.num_buffers()); + // SAFETY: + // If buffers is not null must be valid for reads up to num_buffers + unsafe { std::ptr::read_unaligned((self.buffers as *mut *const u8).add(index)) } + } + + /// Returns the number of buffers + #[inline] + pub fn num_buffers(&self) -> usize { + self.n_buffers as _ + } + + /// Returns the child at the provided index + #[inline] + pub fn child(&self, index: usize) -> &FFI_ArrowArray { + assert!(!self.children.is_null()); + assert!(index < self.num_children()); + // Safety: + // If children is not null must be valid for reads up to num_children + unsafe { + let child = std::ptr::read_unaligned(self.children.add(index)); + child.as_ref().unwrap() + } + } + + /// Returns the number of children + #[inline] + pub fn num_children(&self) -> usize { + self.n_children as _ + } + + /// Returns the dictionary if any + #[inline] + pub fn 
dictionary(&self) -> Option<&Self> { + // Safety: + // If dictionary is not null should be valid for reads of `Self` + unsafe { self.dictionary.as_ref() } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // More tests located in top-level arrow crate + + #[test] + fn null_array_n_buffers() { + let data = ArrayData::new_null(&DataType::Null, 10); + + let ffi_array = FFI_ArrowArray::new(&data); + assert_eq!(0, ffi_array.n_buffers); + + let private_data = + unsafe { Box::from_raw(ffi_array.private_data as *mut ArrayPrivateData) }; + + assert_eq!(0, private_data.buffers_ptr.len()); + + let _ = Box::into_raw(private_data); + } +} diff --git a/arrow-data/src/lib.rs b/arrow-data/src/lib.rs new file mode 100644 index 000000000000..a7feca6cd976 --- /dev/null +++ b/arrow-data/src/lib.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Low-level array data abstractions for [Apache Arrow Rust](https://docs.rs/arrow) +//! +//! 
For a higher-level, strongly-typed interface see [arrow_array](https://docs.rs/arrow_array) + +#![warn(missing_docs)] +mod data; +pub use data::*; + +mod equal; +pub mod transform; + +pub use arrow_buffer::{bit_iterator, bit_mask}; +pub mod decimal; + +#[cfg(feature = "ffi")] +pub mod ffi; + +mod byte_view; +pub use byte_view::*; diff --git a/arrow/src/array/transform/boolean.rs b/arrow-data/src/transform/boolean.rs similarity index 95% rename from arrow/src/array/transform/boolean.rs rename to arrow-data/src/transform/boolean.rs index e0b6231a226e..d93fa15a4e0f 100644 --- a/arrow/src/array/transform/boolean.rs +++ b/arrow-data/src/transform/boolean.rs @@ -16,8 +16,8 @@ // under the License. use super::{Extend, _MutableArrayData, utils::resize_for_bits}; -use crate::array::ArrayData; -use crate::util::bit_mask::set_bits; +use crate::bit_mask::set_bits; +use crate::ArrayData; pub(super) fn build_extend(array: &ArrayData) -> Extend { let values = array.buffers()[0].as_slice(); diff --git a/arrow/src/array/transform/fixed_binary.rs b/arrow-data/src/transform/fixed_binary.rs similarity index 54% rename from arrow/src/array/transform/fixed_binary.rs rename to arrow-data/src/transform/fixed_binary.rs index 6d6262ca3c4e..44c6f46ebf7e 100644 --- a/arrow/src/array/transform/fixed_binary.rs +++ b/arrow-data/src/transform/fixed_binary.rs @@ -15,50 +15,28 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{array::ArrayData, datatypes::DataType}; - use super::{Extend, _MutableArrayData}; +use crate::ArrayData; +use arrow_schema::DataType; pub(super) fn build_extend(array: &ArrayData) -> Extend { let size = match array.data_type() { DataType::FixedSizeBinary(i) => *i as usize, - DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; let values = &array.buffers()[0].as_slice()[array.offset() * size..]; - if array.null_count() == 0 { - // fast case where we can copy regions without null issues - Box::new( - move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - let buffer = &mut mutable.buffer1; - buffer.extend_from_slice(&values[start * size..(start + len) * size]); - }, - ) - } else { - Box::new( - move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - // nulls present: append item by item, ignoring null entries - let values_buffer = &mut mutable.buffer1; - - (start..start + len).for_each(|i| { - if array.is_valid(i) { - // append value - let bytes = &values[i * size..(i + 1) * size]; - values_buffer.extend_from_slice(bytes); - } else { - values_buffer.extend_zeros(size); - } - }) - }, - ) - } + Box::new( + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { + let buffer = &mut mutable.buffer1; + buffer.extend_from_slice(&values[start * size..(start + len) * size]); + }, + ) } pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { let size = match mutable.data_type { DataType::FixedSizeBinary(i) => i as usize, - DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; diff --git a/arrow/src/array/transform/fixed_size_list.rs b/arrow-data/src/transform/fixed_size_list.rs similarity index 53% rename from arrow/src/array/transform/fixed_size_list.rs rename to arrow-data/src/transform/fixed_size_list.rs index 77912a7026fd..8eef7bce9bb3 100644 --- a/arrow/src/array/transform/fixed_size_list.rs +++ b/arrow-data/src/transform/fixed_size_list.rs @@ -15,8 +15,8 @@ // specific language 
governing permissions and limitations // under the License. -use crate::array::ArrayData; -use crate::datatypes::DataType; +use crate::ArrayData; +use arrow_schema::DataType; use super::{Extend, _MutableArrayData}; @@ -26,38 +26,14 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { _ => unreachable!(), }; - if array.null_count() == 0 { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - mutable.child_data.iter_mut().for_each(|child| { - child.extend(index, start * size, (start + len) * size) - }) - }, - ) - } else { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - (start..start + len).for_each(|i| { - if array.is_valid(i) { - mutable.child_data.iter_mut().for_each(|child| { - child.extend(index, i * size, (i + 1) * size) - }) - } else { - mutable - .child_data - .iter_mut() - .for_each(|child| child.extend_nulls(size)) - } - }) - }, - ) - } + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, start * size, (start + len) * size)) + }, + ) } pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { diff --git a/arrow-data/src/transform/list.rs b/arrow-data/src/transform/list.rs new file mode 100644 index 000000000000..d9a1c62a8e8e --- /dev/null +++ b/arrow-data/src/transform/list.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + Extend, _MutableArrayData, + utils::{extend_offsets, get_last_offset}, +}; +use crate::ArrayData; +use arrow_buffer::ArrowNativeType; +use num::{CheckedAdd, Integer}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let offsets = array.buffer::(0); + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; + + // offsets + extend_offsets::(offset_buffer, last_offset, &offsets[start..start + len + 1]); + + mutable.child_data[0].extend( + index, + offsets[start].as_usize(), + offsets[start + len].as_usize(), + ) + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; + + (0..len).for_each(|_| offset_buffer.push(last_offset)) +} diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs new file mode 100644 index 000000000000..c74b0c43481a --- /dev/null +++ b/arrow-data/src/transform/mod.rs @@ -0,0 +1,839 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Low-level array data abstractions. +//! +//! Provides utilities for creating, manipulating, and converting Arrow arrays +//! made of primitive types, strings, and nested types. + +use super::{data::new_buffers, ArrayData, ArrayDataBuilder, ByteView}; +use crate::bit_mask::set_bits; +use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; +use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer}; +use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode}; +use half::f16; +use num::Integer; +use std::mem; + +mod boolean; +mod fixed_binary; +mod fixed_size_list; +mod list; +mod null; +mod primitive; +mod structure; +mod union; +mod utils; +mod variable_size; + +type ExtendNullBits<'a> = Box; +// function that extends `[start..start+len]` to the mutable array. +// this is dynamic because different data_types influence how buffers and children are extended. +type Extend<'a> = Box; + +type ExtendNulls = Box; + +/// A mutable [ArrayData] that knows how to freeze itself into an [ArrayData]. +/// This is just a data container. +#[derive(Debug)] +struct _MutableArrayData<'a> { + pub data_type: DataType, + pub null_count: usize, + + pub len: usize, + pub null_buffer: Option, + + // arrow specification only allows up to 3 buffers (2 ignoring the nulls above). + // Thus, we place them in the stack to avoid bound checks and greater data locality. 
+ pub buffer1: MutableBuffer, + pub buffer2: MutableBuffer, + pub child_data: Vec>, +} + +impl<'a> _MutableArrayData<'a> { + fn null_buffer(&mut self) -> &mut MutableBuffer { + self.null_buffer + .as_mut() + .expect("MutableArrayData not nullable") + } +} + +fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits { + if let Some(nulls) = array.nulls() { + let bytes = nulls.validity(); + Box::new(move |mutable, start, len| { + let mutable_len = mutable.len; + let out = mutable.null_buffer(); + utils::resize_for_bits(out, mutable_len + len); + mutable.null_count += set_bits( + out.as_slice_mut(), + bytes, + mutable_len, + nulls.offset() + start, + len, + ); + }) + } else if use_nulls { + Box::new(|mutable, _, len| { + let mutable_len = mutable.len; + let out = mutable.null_buffer(); + utils::resize_for_bits(out, mutable_len + len); + let write_data = out.as_slice_mut(); + (0..len).for_each(|i| { + bit_util::set_bit(write_data, mutable_len + i); + }); + }) + } else { + Box::new(|_, _, _| {}) + } +} + +/// Efficiently create an [ArrayData] from one or more existing [ArrayData]s by +/// copying chunks. +/// +/// The main use case of this struct is to perform unary operations to arrays of +/// arbitrary types, such as `filter` and `take`. 
+/// +/// # Example +/// ``` +/// use arrow_buffer::Buffer; +/// use arrow_data::ArrayData; +/// use arrow_data::transform::MutableArrayData; +/// use arrow_schema::DataType; +/// fn i32_array(values: &[i32]) -> ArrayData { +/// ArrayData::try_new(DataType::Int32, 5, None, 0, vec![Buffer::from_slice_ref(values)], vec![]).unwrap() +/// } +/// let arr1 = i32_array(&[1, 2, 3, 4, 5]); +/// let arr2 = i32_array(&[6, 7, 8, 9, 10]); +/// // Create a mutable array for copying values from arr1 and arr2, with a capacity for 6 elements +/// let capacity = 3 * std::mem::size_of::(); +/// let mut mutable = MutableArrayData::new(vec![&arr1, &arr2], false, 10); +/// // Copy the first 3 elements from arr1 +/// mutable.extend(0, 0, 3); +/// // Copy the last 3 elements from arr2 +/// mutable.extend(1, 2, 4); +/// // Complete the MutableArrayData into a new ArrayData +/// let frozen = mutable.freeze(); +/// assert_eq!(frozen, i32_array(&[1, 2, 3, 8, 9, 10])); +/// ``` +pub struct MutableArrayData<'a> { + /// Input arrays: the data being read FROM. + /// + /// Note this is "dead code" because all actual references to the arrays are + /// stored in closures for extending values and nulls. + #[allow(dead_code)] + arrays: Vec<&'a ArrayData>, + + /// In progress output array: The data being written TO + /// + /// Note these fields are in a separate struct, [_MutableArrayData], as they + /// cannot be in [MutableArrayData] itself due to mutability invariants (interior + /// mutability): [MutableArrayData] contains a function that can only mutate + /// [_MutableArrayData], not [MutableArrayData] itself + data: _MutableArrayData<'a>, + + /// The child data of the `Array` in Dictionary arrays. + /// + /// This is not stored in `_MutableArrayData` because these values are + /// constant and only needed at the end, when freezing [_MutableArrayData]. + dictionary: Option, + + /// Variadic data buffers referenced by views. 
+ /// + /// Note this this is not stored in `_MutableArrayData` because these values + /// are constant and only needed at the end, when freezing + /// [_MutableArrayData] + variadic_data_buffers: Vec, + + /// function used to extend output array with values from input arrays. + /// + /// This function's lifetime is bound to the input arrays because it reads + /// values from them. + extend_values: Vec>, + + /// function used to extend the output array with nulls from input arrays. + /// + /// This function's lifetime is bound to the input arrays because it reads + /// nulls from it. + extend_null_bits: Vec>, + + /// function used to extend the output array with null elements. + /// + /// This function is independent of the arrays and therefore has no lifetime. + extend_nulls: ExtendNulls, +} + +impl<'a> std::fmt::Debug for MutableArrayData<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // ignores the closures. + f.debug_struct("MutableArrayData") + .field("data", &self.data) + .finish() + } +} + +/// Builds an extend that adds `offset` to the source primitive +/// Additionally validates that `max` fits into the +/// the underlying primitive returning None if not +fn build_extend_dictionary(array: &ArrayData, offset: usize, max: usize) -> Option { + macro_rules! 
validate_and_build { + ($dt: ty) => {{ + let _: $dt = max.try_into().ok()?; + let offset: $dt = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + }}; + } + match array.data_type() { + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => validate_and_build!(u8), + DataType::UInt16 => validate_and_build!(u16), + DataType::UInt32 => validate_and_build!(u32), + DataType::UInt64 => validate_and_build!(u64), + DataType::Int8 => validate_and_build!(i8), + DataType::Int16 => validate_and_build!(i16), + DataType::Int32 => validate_and_build!(i32), + DataType::Int64 => validate_and_build!(i64), + _ => unreachable!(), + }, + _ => None, + } +} + +/// Builds an extend that adds `buffer_offset` to any buffer indices encountered +fn build_extend_view(array: &ArrayData, buffer_offset: u32) -> Extend { + let views = array.buffer::(0); + Box::new( + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { + mutable + .buffer1 + .extend(views[start..start + len].iter().map(|v| { + let len = *v as u32; + if len <= 12 { + return *v; // Stored inline + } + let mut view = ByteView::from(*v); + view.buffer_index += buffer_offset; + view.into() + })) + }, + ) +} + +fn build_extend(array: &ArrayData) -> Extend { + match array.data_type() { + DataType::Null => null::build_extend(array), + DataType::Boolean => boolean::build_extend(array), + DataType::UInt8 => primitive::build_extend::(array), + DataType::UInt16 => primitive::build_extend::(array), + DataType::UInt32 => primitive::build_extend::(array), + DataType::UInt64 => primitive::build_extend::(array), + DataType::Int8 => primitive::build_extend::(array), + DataType::Int16 => primitive::build_extend::(array), + DataType::Int32 => primitive::build_extend::(array), + DataType::Int64 => primitive::build_extend::(array), + DataType::Float32 => primitive::build_extend::(array), + DataType::Float64 => primitive::build_extend::(array), + 
DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { + primitive::build_extend::(array) + } + DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => primitive::build_extend::(array), + DataType::Interval(IntervalUnit::MonthDayNano) => primitive::build_extend::(array), + DataType::Decimal128(_, _) => primitive::build_extend::(array), + DataType::Decimal256(_, _) => primitive::build_extend::(array), + DataType::Utf8 | DataType::Binary => variable_size::build_extend::(array), + DataType::LargeUtf8 | DataType::LargeBinary => variable_size::build_extend::(array), + DataType::BinaryView | DataType::Utf8View => unreachable!("should use build_extend_view"), + DataType::Map(_, _) | DataType::List(_) => list::build_extend::(array), + DataType::ListView(_) | DataType::LargeListView(_) => { + unimplemented!("ListView/LargeListView not implemented") + } + DataType::LargeList(_) => list::build_extend::(array), + DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), + DataType::Struct(_) => structure::build_extend(array), + DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), + DataType::Float16 => primitive::build_extend::(array), + DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array), + DataType::Union(_, mode) => match mode { + UnionMode::Sparse => union::build_extend_sparse(array), + UnionMode::Dense => union::build_extend_dense(array), + }, + DataType::RunEndEncoded(_, _) => todo!(), + } +} + +fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { + Box::new(match data_type { + DataType::Null => null::extend_nulls, + DataType::Boolean => boolean::extend_nulls, + DataType::UInt8 => primitive::extend_nulls::, + DataType::UInt16 => primitive::extend_nulls::, + DataType::UInt32 => primitive::extend_nulls::, + DataType::UInt64 => primitive::extend_nulls::, + DataType::Int8 => 
primitive::extend_nulls::, + DataType::Int16 => primitive::extend_nulls::, + DataType::Int32 => primitive::extend_nulls::, + DataType::Int64 => primitive::extend_nulls::, + DataType::Float32 => primitive::extend_nulls::, + DataType::Float64 => primitive::extend_nulls::, + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { + primitive::extend_nulls:: + } + DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => primitive::extend_nulls::, + DataType::Interval(IntervalUnit::MonthDayNano) => primitive::extend_nulls::, + DataType::Decimal128(_, _) => primitive::extend_nulls::, + DataType::Decimal256(_, _) => primitive::extend_nulls::, + DataType::Utf8 | DataType::Binary => variable_size::extend_nulls::, + DataType::LargeUtf8 | DataType::LargeBinary => variable_size::extend_nulls::, + DataType::BinaryView | DataType::Utf8View => primitive::extend_nulls::, + DataType::Map(_, _) | DataType::List(_) => list::extend_nulls::, + DataType::ListView(_) | DataType::LargeListView(_) => { + unimplemented!("ListView/LargeListView not implemented") + } + DataType::LargeList(_) => list::extend_nulls::, + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => primitive::extend_nulls::, + DataType::UInt16 => primitive::extend_nulls::, + DataType::UInt32 => primitive::extend_nulls::, + DataType::UInt64 => primitive::extend_nulls::, + DataType::Int8 => primitive::extend_nulls::, + DataType::Int16 => primitive::extend_nulls::, + DataType::Int32 => primitive::extend_nulls::, + DataType::Int64 => primitive::extend_nulls::, + _ => unreachable!(), + }, + DataType::Struct(_) => structure::extend_nulls, + DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls, + DataType::Float16 => primitive::extend_nulls::, + DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls, + DataType::Union(_, mode) => match mode { + 
UnionMode::Sparse => union::extend_nulls_sparse, + UnionMode::Dense => union::extend_nulls_dense, + }, + DataType::RunEndEncoded(_, _) => todo!(), + }) +} + +fn preallocate_offset_and_binary_buffer( + capacity: usize, + binary_size: usize, +) -> [MutableBuffer; 2] { + // offsets + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element + buffer.push(Offset::zero()); + + [ + buffer, + MutableBuffer::new(binary_size * mem::size_of::()), + ] +} + +/// Define capacities to pre-allocate for child data or data buffers. +#[derive(Debug, Clone)] +pub enum Capacities { + /// Binary, Utf8 and LargeUtf8 data types + /// + /// Defines + /// * the capacity of the array offsets + /// * the capacity of the binary/ str buffer + Binary(usize, Option), + /// List and LargeList data types + /// + /// Defines + /// * the capacity of the array offsets + /// * the capacity of the child data + List(usize, Option>), + /// Struct type + /// + /// Defines + /// * the capacity of the array + /// * the capacities of the fields + Struct(usize, Option>), + /// Dictionary type + /// + /// Defines + /// * the capacity of the array/keys + /// * the capacity of the values + Dictionary(usize, Option>), + /// Don't preallocate inner buffers and rely on array growth strategy + Array(usize), +} + +impl<'a> MutableArrayData<'a> { + /// Returns a new [MutableArrayData] with capacity to `capacity` slots and + /// specialized to create an [ArrayData] from multiple `arrays`. + /// + /// # Arguments + /// * `arrays` - the source arrays to copy from + /// * `use_nulls` - a flag used to optimize insertions + /// - `false` if the only source of nulls are the arrays themselves + /// - `true` if the user plans to call [MutableArrayData::extend_nulls]. 
+ /// * capacity - the preallocated capacity of the output array, in bytes + /// + /// Thus, if `use_nulls` is `false`, calling + /// [MutableArrayData::extend_nulls] should not be used. + pub fn new(arrays: Vec<&'a ArrayData>, use_nulls: bool, capacity: usize) -> Self { + Self::with_capacities(arrays, use_nulls, Capacities::Array(capacity)) + } + + /// Similar to [MutableArrayData::new], but lets users define the + /// preallocated capacities of the array with more granularity. + /// + /// See [MutableArrayData::new] for more information on the arguments. + /// + /// # Panics + /// + /// This function panics if the given `capacities` don't match the data type + /// of `arrays`. Or when a [Capacities] variant is not yet supported. + pub fn with_capacities( + arrays: Vec<&'a ArrayData>, + use_nulls: bool, + capacities: Capacities, + ) -> Self { + let data_type = arrays[0].data_type(); + + for a in arrays.iter().skip(1) { + assert_eq!( + data_type, + a.data_type(), + "Arrays with inconsistent types passed to MutableArrayData" + ) + } + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ let use_nulls = use_nulls | arrays.iter().any(|array| array.null_count() > 0); + + let mut array_capacity; + + let [buffer1, buffer2] = match (data_type, &capacities) { + ( + DataType::LargeUtf8 | DataType::LargeBinary, + Capacities::Binary(capacity, Some(value_cap)), + ) => { + array_capacity = *capacity; + preallocate_offset_and_binary_buffer::(*capacity, *value_cap) + } + (DataType::Utf8 | DataType::Binary, Capacities::Binary(capacity, Some(value_cap))) => { + array_capacity = *capacity; + preallocate_offset_and_binary_buffer::(*capacity, *value_cap) + } + (_, Capacities::Array(capacity)) => { + array_capacity = *capacity; + new_buffers(data_type, *capacity) + } + ( + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _), + Capacities::List(capacity, _), + ) => { + array_capacity = *capacity; + new_buffers(data_type, *capacity) + } + _ => panic!("Capacities: {capacities:?} not yet supported"), + }; + + let child_data = match &data_type { + DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Null + | DataType::Boolean + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary + | DataType::BinaryView + | DataType::Utf8View + | DataType::Interval(_) + | DataType::FixedSizeBinary(_) => vec![], + DataType::ListView(_) | DataType::LargeListView(_) => { + unimplemented!("ListView/LargeListView not implemented") + } + DataType::Map(_, _) | DataType::List(_) | DataType::LargeList(_) => { + let children = arrays + .iter() + .map(|array| &array.child_data()[0]) + .collect::>(); + + let capacities = + if let 
Capacities::List(capacity, ref child_capacities) = capacities { + child_capacities + .clone() + .map(|c| *c) + .unwrap_or(Capacities::Array(capacity)) + } else { + Capacities::Array(array_capacity) + }; + + vec![MutableArrayData::with_capacities( + children, use_nulls, capacities, + )] + } + // the dictionary type just appends keys and clones the values. + DataType::Dictionary(_, _) => vec![], + DataType::Struct(fields) => match capacities { + Capacities::Struct(capacity, Some(ref child_capacities)) => { + array_capacity = capacity; + (0..fields.len()) + .zip(child_capacities) + .map(|(i, child_cap)| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::with_capacities( + child_arrays, + use_nulls, + child_cap.clone(), + ) + }) + .collect::>() + } + Capacities::Struct(capacity, None) => { + array_capacity = capacity; + (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, capacity) + }) + .collect::>() + } + _ => (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, array_capacity) + }) + .collect::>(), + }, + DataType::RunEndEncoded(_, _) => { + let run_ends_child = arrays + .iter() + .map(|array| &array.child_data()[0]) + .collect::>(); + let value_child = arrays + .iter() + .map(|array| &array.child_data()[1]) + .collect::>(); + vec![ + MutableArrayData::new(run_ends_child, false, array_capacity), + MutableArrayData::new(value_child, use_nulls, array_capacity), + ] + } + DataType::FixedSizeList(_, size) => { + let children = arrays + .iter() + .map(|array| &array.child_data()[0]) + .collect::>(); + let capacities = + if let Capacities::List(capacity, ref child_capacities) = capacities { + child_capacities + .clone() + .map(|c| *c) + 
.unwrap_or(Capacities::Array(capacity * *size as usize)) + } else { + Capacities::Array(array_capacity * *size as usize) + }; + vec![MutableArrayData::with_capacities( + children, use_nulls, capacities, + )] + } + DataType::Union(fields, _) => (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, array_capacity) + }) + .collect::>(), + }; + + // Get the dictionary if any, and if it is a concatenation of multiple + let (dictionary, dict_concat) = match &data_type { + DataType::Dictionary(_, _) => { + // If more than one dictionary, concatenate dictionaries together + let dict_concat = !arrays + .windows(2) + .all(|a| a[0].child_data()[0].ptr_eq(&a[1].child_data()[0])); + + match dict_concat { + false => (Some(arrays[0].child_data()[0].clone()), false), + true => { + if let Capacities::Dictionary(_, _) = capacities { + panic!("dictionary capacity not yet supported") + } + let dictionaries: Vec<_> = + arrays.iter().map(|array| &array.child_data()[0]).collect(); + let lengths: Vec<_> = dictionaries + .iter() + .map(|dictionary| dictionary.len()) + .collect(); + let capacity = lengths.iter().sum(); + + let mut mutable = MutableArrayData::new(dictionaries, false, capacity); + + for (i, len) in lengths.iter().enumerate() { + mutable.extend(i, 0, *len) + } + + (Some(mutable.freeze()), true) + } + } + } + _ => (None, false), + }; + + let variadic_data_buffers = match &data_type { + DataType::BinaryView | DataType::Utf8View => arrays + .iter() + .flat_map(|x| x.buffers().iter().skip(1)) + .map(Buffer::clone) + .collect(), + _ => vec![], + }; + + let extend_nulls = build_extend_nulls(data_type); + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(array, use_nulls)) + .collect(); + + let null_buffer = use_nulls.then(|| { + let null_bytes = bit_util::ceil(array_capacity, 8); + MutableBuffer::from_len_zeroed(null_bytes) + }); + + 
let extend_values = match &data_type { + DataType::Dictionary(_, _) => { + let mut next_offset = 0; + let extend_values: Result, _> = arrays + .iter() + .map(|array| { + let offset = next_offset; + let dict_len = array.child_data()[0].len(); + + if dict_concat { + next_offset += dict_len; + } + + build_extend_dictionary(array, offset, offset + dict_len) + .ok_or(ArrowError::DictionaryKeyOverflowError) + }) + .collect(); + + extend_values.expect("MutableArrayData::new is infallible") + } + DataType::BinaryView | DataType::Utf8View => { + let mut next_offset = 0u32; + arrays + .iter() + .map(|arr| { + let num_data_buffers = (arr.buffers().len() - 1) as u32; + let offset = next_offset; + next_offset = next_offset + .checked_add(num_data_buffers) + .expect("view buffer index overflow"); + build_extend_view(arr, offset) + }) + .collect() + } + _ => arrays.iter().map(|array| build_extend(array)).collect(), + }; + + let data = _MutableArrayData { + data_type: data_type.clone(), + len: 0, + null_count: 0, + null_buffer, + buffer1, + buffer2, + child_data, + }; + Self { + arrays, + data, + dictionary, + variadic_data_buffers, + extend_values, + extend_null_bits, + extend_nulls, + } + } + + /// Extends the in progress array with a region of the input arrays + /// + /// # Arguments + /// * `index` - the index of array that you what to copy values from + /// * `start` - the start index of the chunk (inclusive) + /// * `end` - the end index of the chunk (exclusive) + /// + /// # Panic + /// This function panics if there is an invalid index, + /// i.e. `index` >= the number of source arrays + /// or `end` > the length of the `index`th array + pub fn extend(&mut self, index: usize, start: usize, end: usize) { + let len = end - start; + (self.extend_null_bits[index])(&mut self.data, start, len); + (self.extend_values[index])(&mut self.data, index, start, len); + self.data.len += len; + } + + /// Extends the in progress array with null elements, ignoring the input arrays. 
+ /// + /// # Panics + /// + /// Panics if [`MutableArrayData`] not created with `use_nulls` or nullable source arrays + pub fn extend_nulls(&mut self, len: usize) { + self.data.len += len; + let bit_len = bit_util::ceil(self.data.len, 8); + let nulls = self.data.null_buffer(); + nulls.resize(bit_len, 0); + self.data.null_count += len; + (self.extend_nulls)(&mut self.data, len); + } + + /// Returns the current length + #[inline] + pub fn len(&self) -> usize { + self.data.len + } + + /// Returns true if len is 0 + #[inline] + pub fn is_empty(&self) -> bool { + self.data.len == 0 + } + + /// Returns the current null count + #[inline] + pub fn null_count(&self) -> usize { + self.data.null_count + } + + /// Creates a [ArrayData] from the in progress array, consuming `self`. + pub fn freeze(self) -> ArrayData { + unsafe { self.into_builder().build_unchecked() } + } + + /// Consume self and returns the in progress array as [`ArrayDataBuilder`]. + /// + /// This is useful for extending the default behavior of MutableArrayData. 
+ pub fn into_builder(self) -> ArrayDataBuilder { + let data = self.data; + + let buffers = match data.data_type { + DataType::Null | DataType::Struct(_) | DataType::FixedSizeList(_, _) => { + vec![] + } + DataType::BinaryView | DataType::Utf8View => { + let mut b = self.variadic_data_buffers; + b.insert(0, data.buffer1.into()); + b + } + DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { + vec![data.buffer1.into(), data.buffer2.into()] + } + DataType::Union(_, mode) => { + match mode { + // Based on Union's DataTypeLayout + UnionMode::Sparse => vec![data.buffer1.into()], + UnionMode::Dense => vec![data.buffer1.into(), data.buffer2.into()], + } + } + _ => vec![data.buffer1.into()], + }; + + let child_data = match data.data_type { + DataType::Dictionary(_, _) => vec![self.dictionary.unwrap()], + _ => data.child_data.into_iter().map(|x| x.freeze()).collect(), + }; + + let nulls = data + .null_buffer + .map(|nulls| { + let bools = BooleanBuffer::new(nulls.into(), 0, data.len); + unsafe { NullBuffer::new_unchecked(bools, data.null_count) } + }) + .filter(|n| n.null_count() > 0); + + ArrayDataBuilder::new(data.data_type) + .offset(0) + .len(data.len) + .nulls(nulls) + .buffers(buffers) + .child_data(child_data) + } +} + +// See arrow/tests/array_transform.rs for tests of transform functionality + +#[cfg(test)] +mod test { + use super::*; + use arrow_schema::Field; + use std::sync::Arc; + + #[test] + fn test_list_append_with_capacities() { + let array = ArrayData::new_empty(&DataType::List(Arc::new(Field::new( + "element", + DataType::Int64, + false, + )))); + + let mutable = MutableArrayData::with_capacities( + vec![&array], + false, + Capacities::List(6, Some(Box::new(Capacities::Array(17)))), + ); + + // capacities are rounded up to multiples of 64 by MutableBuffer + assert_eq!(mutable.data.buffer1.capacity(), 64); + assert_eq!(mutable.data.child_data[0].data.buffer1.capacity(), 192); + } +} diff --git 
a/arrow/src/array/transform/null.rs b/arrow-data/src/transform/null.rs similarity index 97% rename from arrow/src/array/transform/null.rs rename to arrow-data/src/transform/null.rs index e1335e179713..5d1535564d9e 100644 --- a/arrow/src/array/transform/null.rs +++ b/arrow-data/src/transform/null.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::array::ArrayData; - use super::{Extend, _MutableArrayData}; +use crate::ArrayData; pub(super) fn build_extend(_: &ArrayData) -> Extend { Box::new(move |_, _, _, _| {}) diff --git a/arrow/src/array/transform/primitive.rs b/arrow-data/src/transform/primitive.rs similarity index 91% rename from arrow/src/array/transform/primitive.rs rename to arrow-data/src/transform/primitive.rs index 4c765c0c0d95..627dc00de1df 100644 --- a/arrow/src/array/transform/primitive.rs +++ b/arrow-data/src/transform/primitive.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. +use crate::ArrayData; +use arrow_buffer::ArrowNativeType; use std::mem::size_of; use std::ops::Add; -use crate::{array::ArrayData, datatypes::ArrowNativeType}; - use super::{Extend, _MutableArrayData}; pub(super) fn build_extend(array: &ArrayData) -> Extend { @@ -47,9 +47,6 @@ where ) } -pub(super) fn extend_nulls( - mutable: &mut _MutableArrayData, - len: usize, -) { +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { mutable.buffer1.extend_zeros(len * size_of::()); } diff --git a/arrow-data/src/transform/structure.rs b/arrow-data/src/transform/structure.rs new file mode 100644 index 000000000000..7330dcaa3705 --- /dev/null +++ b/arrow-data/src/transform/structure.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{Extend, _MutableArrayData}; +use crate::ArrayData; + +pub(super) fn build_extend(_: &ArrayData) -> Extend { + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, start, start + len)) + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend_nulls(len)) +} diff --git a/arrow/src/array/transform/union.rs b/arrow-data/src/transform/union.rs similarity index 82% rename from arrow/src/array/transform/union.rs rename to arrow-data/src/transform/union.rs index bbea508219d0..d7083588d782 100644 --- a/arrow/src/array/transform/union.rs +++ b/arrow-data/src/transform/union.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::array::ArrayData; - use super::{Extend, _MutableArrayData}; +use crate::ArrayData; pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { let type_ids = array.buffer::(0); @@ -40,6 +39,9 @@ pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { let type_ids = array.buffer::(0); let offsets = array.buffer::(1); + let arrow_schema::DataType::Union(src_fields, _) = array.data_type() else { + unreachable!(); + }; Box::new( move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { @@ -49,14 +51,18 @@ pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { .extend_from_slice(&type_ids[start..start + len]); (start..start + len).for_each(|i| { - let type_id = type_ids[i] as usize; + let type_id = type_ids[i]; + let child_index = src_fields + .iter() + .position(|(r, _)| r == type_id) + .expect("invalid union type ID"); let src_offset = offsets[i] as usize; - let child_data = &mut mutable.child_data[type_id]; + let child_data = &mut mutable.child_data[child_index]; let dst_offset = child_data.len(); // Extend offsets mutable.buffer2.push(dst_offset as i32); - mutable.child_data[type_id].extend(index, src_offset, src_offset + 1) + mutable.child_data[child_index].extend(index, src_offset, src_offset + 1) }) }, ) diff --git a/arrow/src/array/transform/utils.rs b/arrow-data/src/transform/utils.rs similarity index 66% rename from arrow/src/array/transform/utils.rs rename to arrow-data/src/transform/utils.rs index 68aee79c41bb..5407f68e0d0c 100644 --- a/arrow/src/array/transform/utils.rs +++ b/arrow-data/src/transform/utils.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{array::OffsetSizeTrait, buffer::MutableBuffer, util::bit_util}; +use arrow_buffer::{bit_util, ArrowNativeType, MutableBuffer}; +use num::{CheckedAdd, Integer}; /// extends the `buffer` to be able to hold `len` bits, setting all bits of the new size to zero. #[inline] @@ -26,24 +27,25 @@ pub(super) fn resize_for_bits(buffer: &mut MutableBuffer, len: usize) { } } -pub(super) fn extend_offsets( +pub(super) fn extend_offsets( buffer: &mut MutableBuffer, mut last_offset: T, offsets: &[T], ) { - buffer.reserve(offsets.len() * std::mem::size_of::()); + buffer.reserve(std::mem::size_of_val(offsets)); offsets.windows(2).for_each(|offsets| { // compute the new offset let length = offsets[1] - offsets[0]; - last_offset += length; + // if you hit this appending to a StringArray / BinaryArray it is because you + // are trying to add more data than can fit into that type. Try breaking your data into + // smaller batches or using LargeStringArray / LargeBinaryArray + last_offset = last_offset.checked_add(&length).expect("offset overflow"); buffer.push(last_offset); }); } #[inline] -pub(super) unsafe fn get_last_offset( - offset_buffer: &MutableBuffer, -) -> T { +pub(super) unsafe fn get_last_offset(offset_buffer: &MutableBuffer) -> T { // JUSTIFICATION // Benefit // 20% performance improvement extend of variable sized arrays (see bench `mutable_array`) @@ -54,3 +56,16 @@ pub(super) unsafe fn get_last_offset( debug_assert!(prefix.is_empty() && suffix.is_empty()); *offsets.get_unchecked(offsets.len() - 1) } + +#[cfg(test)] +mod tests { + use crate::transform::utils::extend_offsets; + use arrow_buffer::MutableBuffer; + + #[test] + #[should_panic(expected = "offset overflow")] + fn test_overflow() { + let mut buffer = MutableBuffer::new(10); + extend_offsets(&mut buffer, i32::MAX - 4, &[0, 5]); + } +} diff --git a/arrow-data/src/transform/variable_size.rs b/arrow-data/src/transform/variable_size.rs new file mode 100644 index 000000000000..fa1592d973ed --- /dev/null +++ 
b/arrow-data/src/transform/variable_size.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use arrow_buffer::{ArrowNativeType, MutableBuffer}; +use num::traits::AsPrimitive; +use num::{CheckedAdd, Integer}; + +use super::{ + Extend, _MutableArrayData, + utils::{extend_offsets, get_last_offset}, +}; + +#[inline] +fn extend_offset_values>( + buffer: &mut MutableBuffer, + offsets: &[T], + values: &[u8], + start: usize, + len: usize, +) { + let start_values = offsets[start].as_(); + let end_values = offsets[start + len].as_(); + let new_values = &values[start_values..end_values]; + buffer.extend_from_slice(new_values); +} + +pub(super) fn build_extend>( + array: &ArrayData, +) -> Extend { + let offsets = array.buffer::(0); + let values = array.buffers()[1].as_slice(); + Box::new( + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { + let offset_buffer = &mut mutable.buffer1; + let values_buffer = &mut mutable.buffer2; + + // this is safe due to how offset is built. 
See details on `get_last_offset` + let last_offset = unsafe { get_last_offset(offset_buffer) }; + + extend_offsets::(offset_buffer, last_offset, &offsets[start..start + len + 1]); + // values + extend_offset_values::(values_buffer, offsets, values, start, len); + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; + + (0..len).for_each(|_| offset_buffer.push(last_offset)) +} diff --git a/arrow-flight/CONTRIBUTING.md b/arrow-flight/CONTRIBUTING.md new file mode 100644 index 000000000000..156a0b9caaed --- /dev/null +++ b/arrow-flight/CONTRIBUTING.md @@ -0,0 +1,41 @@ + + +# Flight + +## Generated Code + +The prost/tonic code can be generated by running, which in turn invokes the Rust binary located in [gen](./gen) + +This is necessary after modifying the protobuf definitions or altering the dependencies of [gen](./gen), and requires a +valid installation of [protoc](https://github.com/protocolbuffers/protobuf#protocol-compiler-installation). + +```bash +./regen.sh +``` + +### Why Vendor + +The standard approach to integrating `prost-build` / `tonic-build` is to use a `build.rs` script that automatically generates the code as part of the standard build process. 
+ +Unfortunately this caused a lot of friction for users: + +- Requires all users to have a protoc install in order to compile the crate - [#2616](https://github.com/apache/arrow-rs/issues/2616) +- Some distributions have very old versions of protoc that don't support required functionality - [#1574](https://github.com/apache/arrow-rs/issues/1574) +- Inconsistent support within IDEs for code completion of automatically generated code diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index ecf02625c9d3..64bf2041ef61 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -18,37 +18,81 @@ [package] name = "arrow-flight" description = "Apache Arrow Flight" -version = "22.0.0" -edition = "2021" -rust-version = "1.62" -authors = ["Apache Arrow "] -homepage = "https://github.com/apache/arrow-rs" -repository = "https://github.com/apache/arrow-rs" -license = "Apache-2.0" +version = { workspace = true } +edition = { workspace = true } +rust-version = "1.71.1" +authors = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } [dependencies] -arrow = { path = "../arrow", version = "22.0.0", default-features = false, features = ["ipc"] } -base64 = { version = "0.13", default-features = false } -tonic = { version = "0.8", default-features = false, features = ["transport", "codegen", "prost"] } +arrow-arith = { workspace = true, optional = true } +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +# Cast is needed to work around https://github.com/apache/arrow-rs/issues/3389 +arrow-cast = { workspace = true } +arrow-data = { workspace = true, optional = true } +arrow-ipc = { workspace = true } +arrow-ord = { workspace = true, optional = true } +arrow-row = { workspace = true, optional = true } +arrow-select = { workspace = true, optional = true } +arrow-schema = { workspace = true } +arrow-string = { workspace = true, optional = true } +base64 = { version = "0.22", 
default-features = false, features = ["std"] } bytes = { version = "1", default-features = false } -prost = { version = "0.11", default-features = false } -prost-types = { version = "0.11.0", default-features = false, optional = true } -prost-derive = { version = "0.11", default-features = false } +futures = { version = "0.3", default-features = false, features = ["alloc"] } +once_cell = { version = "1", optional = true } +paste = { version = "1.0" } +prost = { version = "0.13.1", default-features = false, features = ["prost-derive"] } +# For Timestamp type +prost-types = { version = "0.13.1", default-features = false } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } -futures = { version = "0.3", default-features = false, features = ["alloc"]} +tonic = { version = "0.12.1", default-features = false, features = ["transport", "codegen", "prost"] } + +# CLI-related dependencies +anyhow = { version = "1.0", optional = true } +clap = { version = "4.4.6", default-features = false, features = ["std", "derive", "env", "help", "error-context", "usage", "wrap_help", "color", "suggestions"], optional = true } +tracing-log = { version = "0.2", optional = true } +tracing-subscriber = { version = "0.3.1", default-features = false, features = ["ansi", "env-filter", "fmt"], optional = true } + +[package.metadata.docs.rs] +all-features = true [features] default = [] -flight-sql-experimental = ["prost-types"] +flight-sql-experimental = ["arrow-arith", "arrow-data", "arrow-ord", "arrow-row", "arrow-select", "arrow-string", "once_cell"] +tls = ["tonic/tls"] -[dev-dependencies] +# Enable CLI tools +cli = ["anyhow", "arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", "tonic/tls-webpki-roots"] -[build-dependencies] -tonic-build = { version = "0.8", default-features = false, features = ["transport", "prost"] } -# Pin specific version of the tonic-build dependencies to avoid auto-generated -# (and checked in) 
arrow.flight.protocol.rs from changing -proc-macro2 = { version = ">1.0.30", default-features = false } +[dev-dependencies] +arrow-cast = { workspace = true, features = ["prettyprint"] } +assert_cmd = "2.0.8" +http = "1.1.0" +http-body = "1.0.0" +hyper-util = "0.1" +pin-project-lite = "0.2" +tempfile = "3.3" +tokio-stream = { version = "0.1", features = ["net"] } +tower = { version = "0.5.0", features = ["util"] } +uuid = { version = "1.10.0", features = ["v4"] } [[example]] name = "flight_sql_server" -required-features = ["flight-sql-experimental"] +required-features = ["flight-sql-experimental", "tls"] + +[[bin]] +name = "flight_sql_client" +required-features = ["cli", "flight-sql-experimental", "tls"] + +[[test]] +name = "flight_sql_client" +path = "tests/flight_sql_client.rs" +required-features = ["flight-sql-experimental", "tls"] + +[[test]] +name = "flight_sql_client_cli" +path = "tests/flight_sql_client_cli.rs" +required-features = ["cli", "flight-sql-experimental", "tls"] diff --git a/arrow-flight/README.md b/arrow-flight/README.md index 9e9a18ad4789..2266b81e44ca 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -21,15 +21,56 @@ [![Crates.io](https://img.shields.io/crates/v/arrow-flight.svg)](https://crates.io/crates/arrow-flight) +See the [API documentation](https://docs.rs/arrow_flight/latest) for examples and the full API. + +The API documentation for most recent, unreleased code is available [here](https://arrow.apache.org/rust/arrow_flight/index.html). + ## Usage Add this to your Cargo.toml: ```toml [dependencies] -arrow-flight = "22.0.0" +arrow-flight = "51.0.0" ``` Apache Arrow Flight is a gRPC based protocol for exchanging Arrow data between processes. See the blog post [Introducing Apache Arrow Flight: A Framework for Fast Data Transport](https://arrow.apache.org/blog/2019/10/13/introducing-arrow-flight/) for more information. 
-This crate provides a Rust implementation of the [Flight.proto](../../format/Flight.proto) gRPC protocol and provides an example that demonstrates how to build a Flight server implemented with Tonic. +This crate provides a Rust implementation of the +[Flight.proto](../format/Flight.proto) gRPC protocol and +[examples](https://github.com/apache/arrow-rs/tree/master/arrow-flight/examples) +that demonstrate how to build a Flight server implemented with [tonic](https://docs.rs/crate/tonic/latest). + +## Feature Flags + +- `flight-sql-experimental`: Enables experimental support for + [Apache Arrow FlightSQL], a protocol for interacting with SQL databases. + +## CLI + +This crates offers a basic [Apache Arrow FlightSQL] command line interface. + +The client can be installed from the repository: + +```console +$ cargo install --features=cli,flight-sql-experimental,tls --bin=flight_sql_client --path=. --locked +``` + +The client comes with extensive help text: + +```console +$ flight_sql_client help +``` + +A query can be executed using: + +```console +$ flight_sql_client --host example.com statement-query "SELECT 1;" ++----------+ +| Int64(1) | ++----------+ +| 1 | ++----------+ +``` + +[apache arrow flightsql]: https://arrow.apache.org/docs/format/FlightSql.html diff --git a/arrow-flight/build.rs b/arrow-flight/build.rs deleted file mode 100644 index 25f034ac191b..000000000000 --- a/arrow-flight/build.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{ - env, - fs::OpenOptions, - io::{Read, Write}, - path::Path, -}; - -fn main() -> Result<(), Box> { - // override the build location, in order to check in the changes to proto files - env::set_var("OUT_DIR", "src"); - - // The current working directory can vary depending on how the project is being - // built or released so we build an absolute path to the proto file - let path = Path::new("../format/Flight.proto"); - if path.exists() { - // avoid rerunning build if the file has not changed - println!("cargo:rerun-if-changed=../format/Flight.proto"); - - let proto_dir = Path::new("../format"); - let proto_path = Path::new("../format/Flight.proto"); - - tonic_build::configure() - // protoc in unbuntu builder needs this option - .protoc_arg("--experimental_allow_proto3_optional") - .compile(&[proto_path], &[proto_dir])?; - - // read file contents to string - let mut file = OpenOptions::new() - .read(true) - .open("src/arrow.flight.protocol.rs")?; - let mut buffer = String::new(); - file.read_to_string(&mut buffer)?; - // append warning that file was auto-generate - let mut file = OpenOptions::new() - .write(true) - .truncate(true) - .open("src/arrow.flight.protocol.rs")?; - file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; - file.write_all(buffer.as_bytes())?; - } - - // override the build location, in order to check in the changes to proto files - env::set_var("OUT_DIR", "src/sql"); - // The current working directory can vary depending on how the project is 
being - // built or released so we build an absolute path to the proto file - let path = Path::new("../format/FlightSql.proto"); - if path.exists() { - // avoid rerunning build if the file has not changed - println!("cargo:rerun-if-changed=../format/FlightSql.proto"); - - let proto_dir = Path::new("../format"); - let proto_path = Path::new("../format/FlightSql.proto"); - - tonic_build::configure() - // protoc in unbuntu builder needs this option - .protoc_arg("--experimental_allow_proto3_optional") - .compile(&[proto_path], &[proto_dir])?; - - // read file contents to string - let mut file = OpenOptions::new() - .read(true) - .open("src/sql/arrow.flight.protocol.sql.rs")?; - let mut buffer = String::new(); - file.read_to_string(&mut buffer)?; - // append warning that file was auto-generate - let mut file = OpenOptions::new() - .write(true) - .truncate(true) - .open("src/sql/arrow.flight.protocol.sql.rs")?; - file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; - file.write_all(buffer.as_bytes())?; - } - - // Prost currently generates an empty file, this was fixed but then reverted - // https://github.com/tokio-rs/prost/pull/639 - let google_protobuf_rs = Path::new("src/sql/google.protobuf.rs"); - if google_protobuf_rs.exists() && google_protobuf_rs.metadata().unwrap().len() == 0 { - std::fs::remove_file(google_protobuf_rs).unwrap(); - } - - // As the proto file is checked in, the build should not fail if the file is not found - Ok(()) -} diff --git a/arrow-flight/examples/data/ca.pem b/arrow-flight/examples/data/ca.pem new file mode 100644 index 000000000000..d81956096677 --- /dev/null +++ b/arrow-flight/examples/data/ca.pem @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIE3DCCA0SgAwIBAgIRAObeYbJFiVQSGR8yk44dsOYwDQYJKoZIhvcNAQELBQAw +gYUxHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTEtMCsGA1UECwwkbHVj +aW9ATHVjaW9zLVdvcmstTUJQIChMdWNpbyBGcmFuY28pMTQwMgYDVQQDDCtta2Nl 
+cnQgbHVjaW9ATHVjaW9zLVdvcmstTUJQIChMdWNpbyBGcmFuY28pMB4XDTE5MDky +OTIzMzUzM1oXDTI5MDkyOTIzMzUzM1owgYUxHjAcBgNVBAoTFW1rY2VydCBkZXZl +bG9wbWVudCBDQTEtMCsGA1UECwwkbHVjaW9ATHVjaW9zLVdvcmstTUJQIChMdWNp +byBGcmFuY28pMTQwMgYDVQQDDCtta2NlcnQgbHVjaW9ATHVjaW9zLVdvcmstTUJQ +IChMdWNpbyBGcmFuY28pMIIBojANBgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEA +y/vE61ItbN/1qMYt13LMf+le1svwfkCCOPsygk7nWeRXmomgUpymqn1LnWiuB0+e +4IdVH2f5E9DknWEpPhKIDMRTCbz4jTwQfHrxCb8EGj3I8oO73pJO5S/xCedM9OrZ +qWcYWwN0GQ8cO/ogazaoZf1uTrRNHyzRyQsKyb412kDBTNEeldJZ2ljKgXXvh4HO +2ZIk9K/ZAaAf6VN8K/89rlJ9/KPgRVNsyAapE+Pb8XXKtpzeFiEcUfuXVYWtkoW+ +xyn/Zu8A1L2CXMQ1sARh7P/42BTMKr5pfraYgcBGxKXLrxoySpxCO9KqeVveKy1q +fPm5FCwFsXDr0koFLrCiR58mcIO/04Q9DKKTV4Z2a+LoqDJRY37KfBSc8sDMPhw5 +k7g3WPoa6QwXRjZTCA5fHWVgLOtcwLsnju5tBE4LDxwF6s+1wPF8NI5yUfufcEjJ +Z6JBwgoWYosVj27Lx7KBNLU/57PX9ryee691zmtswt0tP0WVBAgalhYWg99RXoa3 +AgMBAAGjRTBDMA4GA1UdDwEB/wQEAwICBDASBgNVHRMBAf8ECDAGAQH/AgEAMB0G +A1UdDgQWBBQdvlE4Bdcsjc9oaxjDCRu5FiuZkzANBgkqhkiG9w0BAQsFAAOCAYEA +BP/6o1kPINksMJZSSXgNCPZskDLyGw7auUZBnQ0ocDT3W6gXQvT/27LM1Hxoj9Eh +qU1TYdEt7ppecLQSGvzQ02MExG7H75art75oLiB+A5agDira937YbK4MCjqW481d +bDhw6ixJnY1jIvwjEZxyH6g94YyL927aSPch51fys0kSnjkFzC2RmuzDADScc4XH +5P1+/3dnIm3M5yfpeUzoaOrTXNmhn8p0RDIGrZ5kA5eISIGGD3Mm8FDssUNKndtO +g4ojHUsxb14icnAYGeye1NOhGiqN6TEFcgr6MPd0XdFNZ5c0HUaBCfN6bc+JxDV5 +MKZVJdNeJsYYwilgJNHAyZgCi30JC20xeYVtTF7CEEsMrFDGJ70Kz7o/FnRiFsA1 +ZSwVVWhhkHG2VkT4vlo0O3fYeZpenYicvy+wZNTbGK83gzHWqxxNC1z3Etg5+HRJ +F9qeMWPyfA3IHYXygiMcviyLcyNGG/SJ0EhUpYBN/Gg7wI5yFkcsxUDPPzd23O0M +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/data/client1.key b/arrow-flight/examples/data/client1.key new file mode 100644 index 000000000000..f4d8da2758ac --- /dev/null +++ b/arrow-flight/examples/data/client1.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCiiWrmzpENsI+c +Cz4aBpG+Pl8WOsrByfZx/ZnJdCZHO3MTYE6sCLhYssf0ygAEEGxvmkd4cxmfCfgf +xuT8u+D7Y5zQSoymkbWdU6/9jbNY6Ovtc+a96I1LGXOKROQw6KR3PuqLpUqEOJiB 
+l03qK+HMU0g56G1n31Od7HkJsDRvtePqy3I3LgpdcRps23sk46tCzZzhyfqIQ7Qf +J5qZx93tA+pfy+Xtb9XIUTIWKIp1/uyfh8Fp8HA0c9zJCSZzJOX2j3GH1TYqkVgP +egI2lhmdXhP5Q8vdhwy0UJaL28RJXA6UAg0tPZeWJe6pux9JiA81sI6My+Krrw8D +yibkGTTbAgMBAAECggEANCQhRym9HsclSsnQgkjZOE6J8nep08EWbjsMurOoE/He +WLjshAPIH6w6uSyUFLmwD51OkDVcYsiv8IG9s9YRtpOeGrPPqx/TQ0U1kAGFJ2CR +Tvt/aizQJudjSVgQXCBFontsgp/j58bAJdKEDDtHlGSjJvCJKGlcSa0ypwj/yVXt +frjROJNYzw9gMM7fN/IKF/cysdXSeLl/Q9RnHVIfC3jOFJutsILCK8+PC51dM8Fl +IOjmPmiZ080yV8RBcMRECwl53vLOE3OOpR3ZijfNCY1KU8zWi1oELJ1o6f4+cBye +7WPgFEoBew5XHXZ+ke8rh8cc0wth7ZTcC+xC/456AQKBgQDQr2EzBwXxYLF8qsN1 +R4zlzXILLdZN8a4bKfrS507/Gi1gDBHzfvbE7HfljeqrAkbKMdKNkbz3iS85SguH +jsM047xUGJg0PAcwBLHUedlSn1xDDcDHW6X8ginpA2Zz1+WAlhNz6XurA1wnjZmS +VcPxopH7QsuFCclqtt14MbBQ6QKBgQDHY3jcAVfQF+yhQ0YyM6GPLN342aTplgyJ +yz4uWVMeXacU4QzqGbf2L2hc9M2L28Xb37RWC3Q/by0vUefiC6qxRt+GJdRsOuQj +2F1uUibeWtAWp249fcfvxjLib276J+Eit18LI0s0mNR3ekK4GcjSe4NwSq5IrU8e +pBreet3dIwKBgQCxVuil4WkGd+I8jC0v5A7zVsR8hYZhlGkdgm45fgHevdMjlP5I +S3PPYxh8hj6O9o9L0k0Yq2nHfdgYujjUCNkQgBuR55iogv6kqsioRKgPE4fnH6/c +eqCy1bZh4tbUyPqqbF65mQfUCzXsEuQXvDSYiku+F0Q2mVuGCUJpmug3yQKBgEd3 +LeCdUp4xlQ0QEd74hpXM3RrO178pmwDgqj7uoU4m/zYKnBhkc3137I406F+SvE5c +1kRpApeh/64QS27IA7xazM9GS+cnDJKUgJiENY5JOoCELo03wiv8/EwQ6NQc6yMI +WrahRdlqVe0lEzjtdP+MacYb3nAKPmubIk5P96nFAoGAFAyrKpFTyXbNYBTw9Rab +TG6q7qkn+YTHN3+k4mo9NGGwZ3pXvmrKMYCIRhLMbqzsmTbFqCPPIxKsrmf8QYLh +xHYQjrCkbZ0wZdcdeV6yFSDsF218nF/12ZPE7CBOQMfZTCKFNWGL97uIVcmR6K5G +ojTkOvaUnwQtSFhNuzyr23I= +-----END PRIVATE KEY----- diff --git a/arrow-flight/examples/data/client1.pem b/arrow-flight/examples/data/client1.pem new file mode 100644 index 000000000000..bb3b82c40c5a --- /dev/null +++ b/arrow-flight/examples/data/client1.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIQYbE9d1Rft5h4ku7FSAvWdzANBgkqhkiG9w0BAQsFADAn +MSUwIwYDVQQDExxUb25pYyBFeGFtcGxlIENsaWVudCBSb290IENBMB4XDTE5MTAx +NDEyMzkzNloXDTI0MTAxMjEyMzkzNlowEjEQMA4GA1UEAxMHY2xpZW50MTCCASIw 
+DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKKJaubOkQ2wj5wLPhoGkb4+XxY6 +ysHJ9nH9mcl0Jkc7cxNgTqwIuFiyx/TKAAQQbG+aR3hzGZ8J+B/G5Py74PtjnNBK +jKaRtZ1Tr/2Ns1jo6+1z5r3ojUsZc4pE5DDopHc+6oulSoQ4mIGXTeor4cxTSDno +bWffU53seQmwNG+14+rLcjcuCl1xGmzbeyTjq0LNnOHJ+ohDtB8nmpnH3e0D6l/L +5e1v1chRMhYoinX+7J+HwWnwcDRz3MkJJnMk5faPcYfVNiqRWA96AjaWGZ1eE/lD +y92HDLRQlovbxElcDpQCDS09l5Yl7qm7H0mIDzWwjozL4quvDwPKJuQZNNsCAwEA +AaNGMEQwEwYDVR0lBAwwCgYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAfBgNVHSME +GDAWgBQV1YOR+Jpl1fbujvWLSBEoRvsDhTANBgkqhkiG9w0BAQsFAAOCAQEAfTPu +KeHXmyVTSCUrYQ1X5Mu7VzfZlRbhoytHOw7bYGgwaFwQj+ZhlPt8nFC22/bEk4IV +AoCOli0WyPIB7Lx52dZ+v9JmYOK6ca2Aa/Dkw8Q+M3XA024FQWq3nZ6qANKC32/9 +Nk+xOcb1Qd/11stpTkRf2Oj7F7K4GnlFbY6iMyNW+RFXGKEbL5QAJDTDPIT8vw1x +oYeNPwmC042uEboCZPNXmuctiK9Wt1TAxjZT/cwdIBGGJ+xrW72abfJGs7bUcJfc +O4r9V0xVv+X0iKWTW0fwd9qjNfiEP1tFCcZb2XsNQPe/DlQZ+h98P073tZEsWI/G +KJrFspGX8vOuSdIeqw== +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/data/client_ca.pem b/arrow-flight/examples/data/client_ca.pem new file mode 100644 index 000000000000..aa483b931056 --- /dev/null +++ b/arrow-flight/examples/data/client_ca.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDGzCCAgOgAwIBAgIRAMNWpWRu6Q1txEYUyrkyXKEwDQYJKoZIhvcNAQELBQAw +JzElMCMGA1UEAxMcVG9uaWMgRXhhbXBsZSBDbGllbnQgUm9vdCBDQTAeFw0xOTEw +MTQxMjM5MzZaFw0yOTEwMTExMjM5MzZaMCcxJTAjBgNVBAMTHFRvbmljIEV4YW1w +bGUgQ2xpZW50IFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQCv8Nj4XJbMI0wWUvLbmCf7IEvJFnomodGnDurh8Y5AGMPJ8cGdZC1yo2Lgah+D +IhXdsd72Wp7MhdntJAyPrMCDBfDrFiuj6YHDgt3OhPQSYl7EWG7QjFK3B2sp1K5D +h16G5zfwUKDj9Jp3xuPGuqNFQHL02nwbhtDilqHvaTfOJKVjsFCoU8Z77mfwXSwn +sPXpPB7oOO4mWfAtcwU11rTMiHFSGFlFhgbHULU/y90DcpfRQEpEiBoiK13gkyoP +zHT9WAg3Pelwb6K7c7kJ7mp4axhbf7MkwFhDQIjbBWqus2Eu3b0mf86ALfDbAaNC +wBi8xbNH2vWaDjiwLDY5uMZDAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwICBDAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBQV1YOR+Jpl1fbujvWLSBEoRvsDhTANBgkq +hkiG9w0BAQsFAAOCAQEAaXmM29TYkFUzZUsV7TSonAK560BjxDmbg0GJSUgLEFUJ 
+wpKqa9UKOSapG45LEeR2wwAmVWDJomJplkuvTD/KOabAbZKyPEfp+VMCaBUnILQF +Cxv5m7kQ3wmPS/rEL8FD809UGowW9cYqnZzUy5i/r263rx0k3OPjkkZN66Mh6+3H +ibNdaxf7ITO0JVb/Ohq9vLC9qf7ujiB1atMdJwkOWsZrLJXLygpx/D0/UhBT4fFH +OlyVOmuR27qaMbPgOs2l8DznkJY/QUfnET8iOQhFgb0Dt/Os4PYFhSDRIrgl5dJ7 +L/zZVQfZYpdxlBHJlDC1/NzVQl/1MgDnSgPGStZKPQ== +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/data/server.key b/arrow-flight/examples/data/server.key new file mode 100644 index 000000000000..80984ef9000d --- /dev/null +++ b/arrow-flight/examples/data/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDyptbMyYWztgta +t1MXLMzIkaQdeeVbs1Y/qCpAdwZe/Y5ZpbzjGIjCxbB6vNRSnEbYKpytKHPzYfM7 +8d8K8bPvpnqXIiTXFT0JQlw1OHLC1fr4e598GJumAmpMYFrtqv0fbmUFTuQGbHxe +OH2vji0bvr3NKZubMfkEZP3X4sNXXoXIuW2LaS8OMGKoJaeCBvdbszEiSGj/v9Bj +pM0yLTH89NNMX1T+FtTKnuXag5g7pr6lzJj83+MzAGy4nOjseSuUimuiyG90/C5t +A5wC0Qh5RbDnkFYhC44Kxof/i6+jnfateIPNiIIwQV+2f6G/aK1hgjekT10m/eoR +YDTf+e5ZAgMBAAECggEACODt7yRYjhDVLYaTtb9f5t7dYG67Y7WWLFIc6arxQryI +XuNfm/ej2WyeXn9WTYeGWBaHERbv1zH4UnMxNBdP/C7dQXZwXqZaS2JwOUpNeK+X +tUvgtAu6dkKUXSMRcKzXAjVp4N3YHhwOGOx8PNY49FDwZPdmyDD16aFAYIvdle6/ +PSMrj38rB1sbQQdmRob2FjJBSDZ44nsr+/nilrcOFNfNnWv7tQIWYVXNcLfdK/WJ +ZCDFhA8lr/Yon6MEq6ApTj2ZYRRGXPd6UeASJkmTZEUIUbeDcje/MO8cHkREpuRH +wm3pCjR7OdO4vc+/d/QmEvu5ns6wbTauelYnL616YQKBgQD414gJtpCHauNEUlFB +v/R3DzPI5NGp9PAqovOD8nCbI49Mw61gP/ExTIPKiR5uUX/5EL04uspaNkuohXk+ +ys0G5At0NfV7W39lzhvALEaSfleybvYxppbBrc20/q8Gvi/i30NY+1LM3RdtMiEw +hKHjU0SnFhJq0InFg3AO/iCeTQKBgQD5obkbzpOidSsa55aNsUlO2qjiUY9leq9b +irAohIZ8YnuuixYvkOeSeSz1eIrA4tECeAFSgTZxYe1Iz+USru2Xg/0xNte11dJD +rBoH/yMn2gDvBK7xQ6uFMPTeYtKG0vfvpXZYSWZzGntyrHTwFk6UV+xdrt9MBdd1 +XdSn7bwOPQKBgC9VQAko8uDvUf+C8PXiv2uONrl13PPJJY3WpR9qFEVOREnDxszS +HNzVwxPZdTJiykbkCjoqPadfQJDzopZxGQLAifU29lTamKcSx3CMe3gOFDxaovXa +zD5XAxP0hfJwZsdu1G6uj5dsTrJ0oJ+L+wc0pZBqwGIU/L/XOo9/g1DZAoGAUebL +kuH98ik7EUK2VJq8EJERI9/ailLsQb6I+WIxtZGiPqwHhWencpkrNQZtj8dbB9JT 
+rLwUHrMgZOlAoRafgTyez4zMzS3wJJ/Mkp8U67hM4h7JPwMSvUpIrMYDiJSjIA9L +er/qSw1/Pypx22uWMHmAZWRAgvLPtAQrB0Wqk4kCgYEAr2H1PvfbwZwkSvlMt5o8 +WLnBbxcM3AKglLRbkShxxgiZYdEP71/uOtRMiL26du5XX8evItITN0DsvmXL/kcd +h29LK7LM5uLw7efz0Qxs03G6kEyIHVkacowHi5I5Ul1qI61SoV3yMB1TjIU+bXZt +0ZjC07totO0fqPOLQxonjQg= +-----END PRIVATE KEY----- diff --git a/arrow-flight/examples/data/server.pem b/arrow-flight/examples/data/server.pem new file mode 100644 index 000000000000..4cc97bcf4b6d --- /dev/null +++ b/arrow-flight/examples/data/server.pem @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEmDCCAwCgAwIBAgIQVEJFCgU/CZk9JEwTucWPpzANBgkqhkiG9w0BAQsFADCB +hTEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMS0wKwYDVQQLDCRsdWNp +b0BMdWNpb3MtV29yay1NQlAgKEx1Y2lvIEZyYW5jbykxNDAyBgNVBAMMK21rY2Vy +dCBsdWNpb0BMdWNpb3MtV29yay1NQlAgKEx1Y2lvIEZyYW5jbykwHhcNMTkwNjAx +MDAwMDAwWhcNMjkwOTI5MjMzNTM0WjBYMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxv +cG1lbnQgY2VydGlmaWNhdGUxLTArBgNVBAsMJGx1Y2lvQEx1Y2lvcy1Xb3JrLU1C +UCAoTHVjaW8gRnJhbmNvKTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +APKm1szJhbO2C1q3UxcszMiRpB155VuzVj+oKkB3Bl79jlmlvOMYiMLFsHq81FKc +RtgqnK0oc/Nh8zvx3wrxs++mepciJNcVPQlCXDU4csLV+vh7n3wYm6YCakxgWu2q +/R9uZQVO5AZsfF44fa+OLRu+vc0pm5sx+QRk/dfiw1dehci5bYtpLw4wYqglp4IG +91uzMSJIaP+/0GOkzTItMfz000xfVP4W1Mqe5dqDmDumvqXMmPzf4zMAbLic6Ox5 +K5SKa6LIb3T8Lm0DnALRCHlFsOeQViELjgrGh/+Lr6Od9q14g82IgjBBX7Z/ob9o +rWGCN6RPXSb96hFgNN/57lkCAwEAAaOBrzCBrDAOBgNVHQ8BAf8EBAMCBaAwEwYD +VR0lBAwwCgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAfBgNVHSMEGDAWgBQdvlE4 +Bdcsjc9oaxjDCRu5FiuZkzBWBgNVHREETzBNggtleGFtcGxlLmNvbYINKi5leGFt +cGxlLmNvbYIMZXhhbXBsZS50ZXN0gglsb2NhbGhvc3SHBH8AAAGHEAAAAAAAAAAA +AAAAAAAAAAEwDQYJKoZIhvcNAQELBQADggGBAKb2TJ8l+e1eraNwZWizLw5fccAf +y59J1JAWdLxZyAI/bkiTlVO3DQoPZpw7XwLhefCvILkwKAL4TtIGGVC9yTb5Q5eg +rqGO3FC0yg1fn65Kf1VpVxxUVyoiM5PQ4pFJb4AicAv88rCOLD9FFuE0PKOKU/dm +Tw0WgPStoh9wsJ1RXUuTJYZs1nd1kMBlfv9NbLilnL+cR2sLktS54X5XagsBYVlf +oapRb0JtABOoQhX3U8QMq8UF8yzceRHNTN9yfLOUrW26s9nKtlWVniNhw1uPxZw9 
+RHM7w9/4+a9LXtEDYg4IP/1mm0ywBoUqy1O6hA73uId+Yi/kFBks/GyYaGjKgYcO +23B75tkPGYEdGuGZYLzZNHbXg4V0UxFQG3KA1pUiSnD3bN2Rxs+CMpzORnOeK3xi +EooKgAPYsehItoQOMPpccI2xHdSAMWtwUgOKrefUQujkx2Op+KFlspF0+WJ6AZEe +2D4hyWaEZsvvILXapwqHDCuN3/jSUlTIqUoE1w== +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index aa0d407113d7..81afecf85625 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -15,31 +15,132 @@ // specific language governing permissions and limitations // under the License. -use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; -use arrow_flight::{Action, FlightData, HandshakeRequest, HandshakeResponse, Ticket}; -use futures::Stream; +use arrow_flight::sql::server::PeekableFlightDataStream; +use arrow_flight::sql::DoPutPreparedStatementResult; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use futures::{stream, Stream, TryStreamExt}; +use once_cell::sync::Lazy; +use prost::Message; +use std::collections::HashSet; use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use tonic::metadata::MetadataValue; use tonic::transport::Server; +use tonic::transport::{Certificate, Identity, ServerTlsConfig}; use tonic::{Request, Response, Status, Streaming}; +use arrow_array::builder::StringBuilder; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_flight::encode::FlightDataEncoderBuilder; +use arrow_flight::sql::metadata::{ + SqlInfoData, SqlInfoDataBuilder, XdbcTypeInfo, XdbcTypeInfoData, XdbcTypeInfoDataBuilder, +}; +use arrow_flight::sql::{ + server::FlightSqlService, ActionBeginSavepointRequest, ActionBeginSavepointResult, + ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionCancelQueryRequest, + ActionCancelQueryResult, ActionClosePreparedStatementRequest, + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, + ActionCreatePreparedSubstraitPlanRequest, 
ActionEndSavepointRequest, + ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference, + CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, + CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementIngest, + CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, + ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType, +}; +use arrow_flight::utils::batches_to_flight_data; use arrow_flight::{ - flight_service_server::FlightService, - flight_service_server::FlightServiceServer, - sql::{ - server::FlightSqlService, ActionClosePreparedStatementRequest, - ActionCreatePreparedStatementRequest, CommandGetCatalogs, - CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, - CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery, - CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementUpdate, - TicketStatementQuery, - }, - FlightDescriptor, FlightInfo, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, + IpcMessage, SchemaAsIpc, Ticket, }; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_schema::{ArrowError, DataType, Field, Schema}; + +macro_rules! 
status { + ($desc:expr, $err:expr) => { + Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!())) + }; +} + +const FAKE_TOKEN: &str = "uuid_token"; +const FAKE_HANDLE: &str = "uuid_handle"; +const FAKE_UPDATE_RESULT: i64 = 1; + +static INSTANCE_SQL_DATA: Lazy = Lazy::new(|| { + let mut builder = SqlInfoDataBuilder::new(); + // Server information + builder.append(SqlInfo::FlightSqlServerName, "Example Flight SQL Server"); + builder.append(SqlInfo::FlightSqlServerVersion, "1"); + // 1.3 comes from https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/Schema.fbs#L24 + builder.append(SqlInfo::FlightSqlServerArrowVersion, "1.3"); + builder.build().unwrap() +}); + +static INSTANCE_XBDC_DATA: Lazy = Lazy::new(|| { + let mut builder = XdbcTypeInfoDataBuilder::new(); + builder.append(XdbcTypeInfo { + type_name: "INTEGER".into(), + data_type: XdbcDataType::XdbcInteger, + column_size: Some(32), + literal_prefix: None, + literal_suffix: None, + create_params: None, + nullable: Nullable::NullabilityNullable, + case_sensitive: false, + searchable: Searchable::Full, + unsigned_attribute: Some(false), + fixed_prec_scale: false, + auto_increment: Some(false), + local_type_name: Some("INTEGER".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcInteger, + datetime_subcode: None, + num_prec_radix: Some(2), + interval_precision: None, + }); + builder.build().unwrap() +}); + +static TABLES: Lazy> = Lazy::new(|| vec!["flight_sql.example.table"]); #[derive(Clone)] pub struct FlightSqlServiceImpl {} +impl FlightSqlServiceImpl { + fn check_token(&self, req: &Request) -> Result<(), Status> { + let metadata = req.metadata(); + let auth = metadata.get("authorization").ok_or_else(|| { + Status::internal(format!("No authorization header! 
metadata = {metadata:?}")) + })?; + let str = auth + .to_str() + .map_err(|e| Status::internal(format!("Error parsing header: {e}")))?; + let authorization = str.to_string(); + let bearer = "Bearer "; + if !authorization.starts_with(bearer) { + Err(Status::internal("Invalid auth header!"))?; + } + let token = authorization[bearer.len()..].to_string(); + if token == FAKE_TOKEN { + Ok(()) + } else { + Err(Status::unauthenticated("invalid token ")) + } + } + + fn fake_result() -> Result { + let schema = Schema::new(vec![Field::new("salutation", DataType::Utf8, false)]); + let mut builder = StringBuilder::new(); + builder.append_value("Hello, FlightSQL!"); + let cols = vec![Arc::new(builder.finish()) as ArrayRef]; + RecordBatch::try_new(Arc::new(schema), cols) + } +} + #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; @@ -55,41 +156,67 @@ impl FlightSqlService for FlightSqlServiceImpl { let authorization = request .metadata() .get("authorization") - .ok_or(Status::invalid_argument("authorization field not present"))? + .ok_or_else(|| Status::invalid_argument("authorization field not present"))? 
.to_str() - .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + .map_err(|e| status!("authorization not parsable", e))?; if !authorization.starts_with(basic) { Err(Status::invalid_argument(format!( - "Auth type not implemented: {}", - authorization + "Auth type not implemented: {authorization}" )))?; } let base64 = &authorization[basic.len()..]; - let bytes = base64::decode(base64) - .map_err(|_| Status::invalid_argument("authorization not parsable"))?; - let str = String::from_utf8(bytes) - .map_err(|_| Status::invalid_argument("authorization not parsable"))?; - let parts: Vec<_> = str.split(":").collect(); - if parts.len() != 2 { - Err(Status::invalid_argument(format!( - "Invalid authorization header" - )))?; - } - let user = parts[0]; - let pass = parts[1]; - if user != "admin" || pass != "password" { + let bytes = BASE64_STANDARD + .decode(base64) + .map_err(|e| status!("authorization not decodable", e))?; + let str = String::from_utf8(bytes).map_err(|e| status!("authorization not parsable", e))?; + let parts: Vec<_> = str.split(':').collect(); + let (user, pass) = match parts.as_slice() { + [user, pass] => (user, pass), + _ => Err(Status::invalid_argument( + "Invalid authorization header".to_string(), + ))?, + }; + if user != &"admin" || pass != &"password" { Err(Status::unauthenticated("Invalid credentials!"))? 
} + let result = HandshakeResponse { protocol_version: 0, - payload: "random_uuid_token".as_bytes().to_vec(), + payload: FAKE_TOKEN.into(), }; let result = Ok(result); let output = futures::stream::iter(vec![result]); - return Ok(Response::new(Box::pin(output))); + + let token = format!("Bearer {}", FAKE_TOKEN); + let mut response: Response + Send>>> = + Response::new(Box::pin(output)); + response.metadata_mut().append( + "authorization", + MetadataValue::from_str(token.as_str()).unwrap(), + ); + return Ok(response); + } + + async fn do_get_fallback( + &self, + request: Request, + _message: Any, + ) -> Result::DoGetStream>, Status> { + self.check_token(&request)?; + let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; + let schema = batch.schema_ref(); + let batches = vec![batch.clone()]; + let flight_data = batches_to_flight_data(schema, batches) + .map_err(|e| status!("Could not convert batches", e))? + .into_iter() + .map(Ok); + + let stream: Pin> + Send>> = + Box::pin(stream::iter(flight_data)); + let resp = Response::new(stream); + Ok(resp) } - // get_flight_info async fn get_flight_info_statement( &self, _query: CommandStatementQuery, @@ -100,44 +227,112 @@ impl FlightSqlService for FlightSqlServiceImpl { )) } - async fn get_flight_info_prepared_statement( + async fn get_flight_info_substrait_plan( &self, - _query: CommandPreparedStatementQuery, + _query: CommandStatementSubstraitPlan, _request: Request, ) -> Result, Status> { Err(Status::unimplemented( - "get_flight_info_prepared_statement not implemented", + "get_flight_info_substrait_plan not implemented", )) } + async fn get_flight_info_prepared_statement( + &self, + cmd: CommandPreparedStatementQuery, + request: Request, + ) -> Result, Status> { + self.check_token(&request)?; + let handle = std::str::from_utf8(&cmd.prepared_statement_handle) + .map_err(|e| status!("Unable to parse handle", e))?; + + let batch = Self::fake_result().map_err(|e| status!("Could not fake a 
result", e))?; + let schema = (*batch.schema()).clone(); + let num_rows = batch.num_rows(); + let num_bytes = batch.get_array_memory_size(); + + let fetch = FetchResults { + handle: handle.to_string(), + }; + let buf = fetch.as_any().encode_to_vec().into(); + let ticket = Ticket { ticket: buf }; + let endpoint = FlightEndpoint { + ticket: Some(ticket), + location: vec![], + expiration_time: None, + app_metadata: vec![].into(), + }; + let info = FlightInfo::new() + .try_with_schema(&schema) + .map_err(|e| status!("Unable to serialize schema", e))? + .with_descriptor(FlightDescriptor::new_cmd(vec![])) + .with_endpoint(endpoint) + .with_total_records(num_rows as i64) + .with_total_bytes(num_bytes as i64) + .with_ordered(false); + + let resp = Response::new(info); + Ok(resp) + } + async fn get_flight_info_catalogs( &self, - _query: CommandGetCatalogs, - _request: Request, + query: CommandGetCatalogs, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_catalogs not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .map_err(|e| status!("Unable to encode schema", e))? 
+ .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_schemas( &self, - _query: CommandGetDbSchemas, - _request: Request, + query: CommandGetDbSchemas, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_schemas not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_tables( &self, - _query: CommandGetTables, - _request: Request, + query: CommandGetTables, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_tables not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .map_err(|e| status!("Unable to encode schema", e))? 
+ .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_table_types( @@ -152,12 +347,20 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn get_flight_info_sql_info( &self, - _query: CommandGetSqlInfo, - _request: Request, + query: CommandGetSqlInfo, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_sql_info not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket::new(query.as_any().encode_to_vec()); + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(query.into_builder(&INSTANCE_SQL_DATA).schema().as_ref()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_primary_keys( @@ -200,6 +403,24 @@ impl FlightSqlService for FlightSqlServiceImpl { )) } + async fn get_flight_info_xdbc_type_info( + &self, + query: CommandGetXdbcTypeInfo, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket::new(query.as_any().encode_to_vec()); + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(query.into_builder(&INSTANCE_XBDC_DATA).schema().as_ref()) + .map_err(|e| status!("Unable to encode schema", e))? 
+ .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) + } + // do_get async fn do_get_statement( &self, @@ -221,26 +442,91 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_get_catalogs( &self, - _query: CommandGetCatalogs, + query: CommandGetCatalogs, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_catalogs not implemented")) + let catalog_names = TABLES + .iter() + .map(|full_name| full_name.split('.').collect::>()[0].to_string()) + .collect::>(); + let mut builder = query.into_builder(); + for catalog_name in catalog_names { + builder.append(catalog_name); + } + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_schemas( &self, - _query: CommandGetDbSchemas, + query: CommandGetDbSchemas, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_schemas not implemented")) + let schemas = TABLES + .iter() + .map(|full_name| { + let parts = full_name.split('.').collect::>(); + (parts[0].to_string(), parts[1].to_string()) + }) + .collect::>(); + + let mut builder = query.into_builder(); + for (catalog_name, schema_name) in schemas { + builder.append(catalog_name, schema_name); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_tables( &self, - _query: CommandGetTables, + query: CommandGetTables, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_tables not implemented")) + let tables = TABLES + .iter() + .map(|full_name| { + let parts = 
full_name.split('.').collect::>(); + ( + parts[0].to_string(), + parts[1].to_string(), + parts[2].to_string(), + ) + }) + .collect::>(); + + let dummy_schema = Schema::empty(); + let mut builder = query.into_builder(); + for (catalog_name, schema_name, table_name) in tables { + builder + .append( + catalog_name, + schema_name, + table_name, + "TABLE", + &dummy_schema, + ) + .map_err(Status::from)?; + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_table_types( @@ -253,10 +539,17 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_get_sql_info( &self, - _query: CommandGetSqlInfo, + query: CommandGetSqlInfo, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_sql_info not implemented")) + let builder = query.into_builder(&INSTANCE_SQL_DATA); + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_primary_keys( @@ -297,22 +590,54 @@ impl FlightSqlService for FlightSqlServiceImpl { )) } + async fn do_get_xdbc_type_info( + &self, + query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + // create a builder with pre-defined Xdbc data: + let builder = query.into_builder(&INSTANCE_XBDC_DATA); + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + // do_put async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: 
Request, + ) -> Result { + Ok(FAKE_UPDATE_RESULT) + } + + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + _request: Request, + ) -> Result { + Ok(FAKE_UPDATE_RESULT) + } + + async fn do_put_substrait_plan( + &self, + _ticket: CommandStatementSubstraitPlan, + _request: Request, ) -> Result { Err(Status::unimplemented( - "do_put_statement_update not implemented", + "do_put_substrait_plan not implemented", )) } async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, - ) -> Result::DoPutStream>, Status> { + _request: Request, + ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_query not implemented", )) @@ -321,27 +646,92 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented", )) } - // do_action async fn do_action_create_prepared_statement( &self, _query: ActionCreatePreparedStatementRequest, - _request: Request, + request: Request, ) -> Result { - Err(Status::unimplemented("Not yet implemented")) + self.check_token(&request)?; + let record_batch = + Self::fake_result().map_err(|e| status!("Error getting result schema", e))?; + let schema = record_batch.schema_ref(); + let message = SchemaAsIpc::new(schema, &IpcWriteOptions::default()) + .try_into() + .map_err(|e| status!("Unable to serialize schema", e))?; + let IpcMessage(schema_bytes) = message; + let res = ActionCreatePreparedStatementResult { + prepared_statement_handle: FAKE_HANDLE.into(), + dataset_schema: schema_bytes, + parameter_schema: Default::default(), // TODO: parameters + }; + Ok(res) } + async fn do_action_close_prepared_statement( &self, _query: ActionClosePreparedStatementRequest, _request: Request, - ) { - unimplemented!("Not yet implemented") + ) -> Result<(), Status> { + 
Ok(()) + } + + async fn do_action_create_prepared_substrait_plan( + &self, + _query: ActionCreatePreparedSubstraitPlanRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "Implement do_action_create_prepared_substrait_plan", + )) + } + + async fn do_action_begin_transaction( + &self, + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "Implement do_action_begin_transaction", + )) + } + + async fn do_action_end_transaction( + &self, + _query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented("Implement do_action_end_transaction")) + } + + async fn do_action_begin_savepoint( + &self, + _query: ActionBeginSavepointRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented("Implement do_action_begin_savepoint")) + } + + async fn do_action_end_savepoint( + &self, + _query: ActionEndSavepointRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented("Implement do_action_end_savepoint")) + } + + async fn do_action_cancel_query( + &self, + _query: ActionCancelQueryRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented("Implement do_action_cancel_query")) } async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} @@ -350,13 +740,300 @@ impl FlightSqlService for FlightSqlServiceImpl { /// This example shows how to run a FlightSql server #[tokio::main] async fn main() -> Result<(), Box> { - let addr = "0.0.0.0:50051".parse()?; - - let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + let addr_str = "0.0.0.0:50051"; + let addr = addr_str.parse()?; println!("Listening on {:?}", addr); - Server::builder().add_service(svc).serve(addr).await?; + if std::env::var("USE_TLS").ok().is_some() { + let cert = std::fs::read_to_string("arrow-flight/examples/data/server.pem")?; + let key = std::fs::read_to_string("arrow-flight/examples/data/server.key")?; + let client_ca = 
std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?; + + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + let tls_config = ServerTlsConfig::new() + .identity(Identity::from_pem(&cert, &key)) + .client_ca_root(Certificate::from_pem(&client_ca)); + + Server::builder() + .tls_config(tls_config)? + .add_service(svc) + .serve(addr) + .await?; + } else { + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + + Server::builder().add_service(svc).serve(addr).await?; + } Ok(()) } + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FetchResults { + #[prost(string, tag = "1")] + pub handle: ::prost::alloc::string::String, +} + +impl ProstMessageExt for FetchResults { + fn type_url() -> &'static str { + "type.googleapis.com/arrow.flight.protocol.sql.FetchResults" + } + + fn as_any(&self) -> Any { + Any { + type_url: FetchResults::type_url().to_string(), + value: ::prost::Message::encode_to_vec(self).into(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::{TryFutureExt, TryStreamExt}; + use hyper_util::rt::TokioIo; + use std::fs; + use std::future::Future; + use std::net::SocketAddr; + use std::time::Duration; + use tempfile::NamedTempFile; + use tokio::net::{TcpListener, UnixListener, UnixStream}; + use tokio_stream::wrappers::UnixListenerStream; + use tonic::transport::{Channel, ClientTlsConfig}; + + use arrow_cast::pretty::pretty_format_batches; + use arrow_flight::sql::client::FlightSqlServiceClient; + use tonic::transport::server::TcpIncoming; + use tonic::transport::{Certificate, Endpoint}; + use tower::service_fn; + + async fn bind_tcp() -> (TcpIncoming, SocketAddr) { + let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + (incoming, addr) + } + + fn endpoint(uri: String) -> Result { + let endpoint = Endpoint::new(uri) + .map_err(|_| ArrowError::IpcError("Cannot create 
endpoint".to_string()))? + .connect_timeout(Duration::from_secs(20)) + .timeout(Duration::from_secs(20)) + .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait + .tcp_keepalive(Option::Some(Duration::from_secs(3600))) + .http2_keep_alive_interval(Duration::from_secs(300)) + .keep_alive_timeout(Duration::from_secs(20)) + .keep_alive_while_idle(true); + + Ok(endpoint) + } + + async fn auth_client(client: &mut FlightSqlServiceClient) { + let token = client.handshake("admin", "password").await.unwrap(); + client.set_token(String::from_utf8(token.to_vec()).unwrap()); + } + + async fn test_uds_client(f: F) + where + F: FnOnce(FlightSqlServiceClient) -> C, + C: Future, + { + let file = NamedTempFile::new().unwrap(); + let path = file.into_temp_path().to_str().unwrap().to_string(); + let _ = fs::remove_file(path.clone()); + + let uds = UnixListener::bind(path.clone()).unwrap(); + let stream = UnixListenerStream::new(uds); + + let service = FlightSqlServiceImpl {}; + let serve_future = Server::builder() + .add_service(FlightServiceServer::new(service)) + .serve_with_incoming(stream); + + let request_future = async { + let connector = + service_fn(move |_| UnixStream::connect(path.clone()).map_ok(TokioIo::new)); + let channel = Endpoint::try_from("http://example.com") + .unwrap() + .connect_with_connector(connector) + .await + .unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! 
{ + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_http_client(f: F) + where + F: FnOnce(FlightSqlServiceClient) -> C, + C: Future, + { + let (incoming, addr) = bind_tcp().await; + let uri = format!("http://{}:{}", addr.ip(), addr.port()); + + let service = FlightSqlServiceImpl {}; + let serve_future = Server::builder() + .add_service(FlightServiceServer::new(service)) + .serve_with_incoming(incoming); + + let request_future = async { + let endpoint = endpoint(uri).unwrap(); + let channel = endpoint.connect().await.unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! { + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_https_client(f: F) + where + F: FnOnce(FlightSqlServiceClient) -> C, + C: Future, + { + let cert = std::fs::read_to_string("examples/data/server.pem").unwrap(); + let key = std::fs::read_to_string("examples/data/server.key").unwrap(); + let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap(); + + let tls_config = ServerTlsConfig::new() + .identity(Identity::from_pem(&cert, &key)) + .client_ca_root(Certificate::from_pem(&client_ca)); + + let (incoming, addr) = bind_tcp().await; + let uri = format!("https://{}:{}", addr.ip(), addr.port()); + + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + + let serve_future = Server::builder() + .tls_config(tls_config) + .unwrap() + .add_service(svc) + .serve_with_incoming(incoming); + + let request_future = async { + let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap(); + let key = std::fs::read_to_string("examples/data/client1.key").unwrap(); + let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap(); + + let tls_config = ClientTlsConfig::new() + .domain_name("localhost") + .ca_certificate(Certificate::from_pem(&server_ca)) 
+ .identity(Identity::from_pem(cert, key)); + + let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap(); + let channel = endpoint.connect().await.unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! { + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_all_clients(task: F) + where + F: FnOnce(FlightSqlServiceClient) -> C + Copy, + C: Future, + { + println!("testing uds client"); + test_uds_client(task).await; + println!("======="); + + println!("testing http client"); + test_http_client(task).await; + println!("======="); + + println!("testing https client"); + test_https_client(task).await; + println!("======="); + } + + #[tokio::test] + async fn test_select() { + test_all_clients(|mut client| async move { + auth_client(&mut client).await; + + let mut stmt = client.prepare("select 1;".to_string(), None).await.unwrap(); + + let flight_info = stmt.execute().await.unwrap(); + + let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); + let flight_data = client.do_get(ticket).await.unwrap(); + let batches: Vec<_> = flight_data.try_collect().await.unwrap(); + + let res = pretty_format_batches(batches.as_slice()).unwrap(); + let expected = r#" ++-------------------+ +| salutation | ++-------------------+ +| Hello, FlightSQL! 
| ++-------------------+"# + .trim() + .to_string(); + assert_eq!(res.to_string(), expected); + }) + .await + } + + #[tokio::test] + async fn test_execute_update() { + test_all_clients(|mut client| async move { + auth_client(&mut client).await; + let res = client + .execute_update("creat table test(a int);".to_string(), None) + .await + .unwrap(); + assert_eq!(res, FAKE_UPDATE_RESULT); + }) + .await + } + + #[tokio::test] + async fn test_auth() { + test_all_clients(|mut client| async move { + // no handshake + assert!(client + .prepare("select 1;".to_string(), None) + .await + .unwrap_err() + .to_string() + .contains("No authorization header")); + + // Invalid credentials + assert!(client + .handshake("admin", "password2") + .await + .unwrap_err() + .to_string() + .contains("Invalid credentials")); + + // Invalid Tokens + client.handshake("admin", "password").await.unwrap(); + client.set_token("wrong token".to_string()); + assert!(client + .prepare("select 1;".to_string(), None) + .await + .unwrap_err() + .to_string() + .contains("invalid token")); + + client.clear_token(); + + // Successful call (token is automatically set by handshake) + client.handshake("admin", "password").await.unwrap(); + client.prepare("select 1;".to_string(), None).await.unwrap(); + }) + .await + } +} diff --git a/arrow-flight/examples/server.rs b/arrow-flight/examples/server.rs index 75d05378710f..8c766b075957 100644 --- a/arrow-flight/examples/server.rs +++ b/arrow-flight/examples/server.rs @@ -15,16 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use std::pin::Pin; - -use futures::Stream; +use futures::stream::BoxStream; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, + HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, }; #[derive(Clone)] @@ -32,89 +30,82 @@ pub struct FlightServiceImpl {} #[tonic::async_trait] impl FlightService for FlightServiceImpl { - type HandshakeStream = Pin< - Box> + Send + Sync + 'static>, - >; - type ListFlightsStream = - Pin> + Send + Sync + 'static>>; - type DoGetStream = - Pin> + Send + Sync + 'static>>; - type DoPutStream = - Pin> + Send + Sync + 'static>>; - type DoActionStream = Pin< - Box< - dyn Stream> - + Send - + Sync - + 'static, - >, - >; - type ListActionsStream = - Pin> + Send + Sync + 'static>>; - type DoExchangeStream = - Pin> + Send + Sync + 'static>>; + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; async fn handshake( &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement handshake")) } async fn list_flights( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement list_flights")) } async fn get_flight_info( 
&self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement get_flight_info")) + } + + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Implement poll_flight_info")) } async fn get_schema( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement get_schema")) } async fn do_get( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_get")) } async fn do_put( &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_put")) } async fn do_action( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_action")) } async fn list_actions( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement list_actions")) } async fn do_exchange( &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_exchange")) } } diff --git a/arrow-flight/gen/Cargo.toml b/arrow-flight/gen/Cargo.toml new file mode 100644 index 000000000000..08b53b729738 --- /dev/null +++ b/arrow-flight/gen/Cargo.toml @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "gen" +description = "Code generation for arrow-flight" +version = "0.1.0" +edition = { workspace = true } +rust-version = { workspace = true } +authors = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +# Pin specific version of the tonic-build dependencies to avoid auto-generated +# (and checked in) arrow.flight.protocol.rs from changing +proc-macro2 = { version = "=1.0.86", default-features = false } +prost-build = { version = "=0.13.3", default-features = false } +tonic-build = { version = "=0.12.2", default-features = false, features = ["transport", "prost"] } diff --git a/arrow-flight/gen/src/main.rs b/arrow-flight/gen/src/main.rs new file mode 100644 index 000000000000..a3541c63b173 --- /dev/null +++ b/arrow-flight/gen/src/main.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + fs::OpenOptions, + io::{Read, Write}, + path::Path, +}; + +fn main() -> Result<(), Box> { + let proto_dir = Path::new("../format"); + let proto_path = Path::new("../format/Flight.proto"); + + tonic_build::configure() + // protoc in unbuntu builder needs this option + .protoc_arg("--experimental_allow_proto3_optional") + .out_dir("src") + .compile_with_config(prost_config(), &[proto_path], &[proto_dir])?; + + // read file contents to string + let mut file = OpenOptions::new() + .read(true) + .open("src/arrow.flight.protocol.rs")?; + let mut buffer = String::new(); + file.read_to_string(&mut buffer)?; + // append warning that file was auto-generate + let mut file = OpenOptions::new() + .write(true) + .truncate(true) + .open("src/arrow.flight.protocol.rs")?; + file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; + file.write_all(buffer.as_bytes())?; + + let proto_dir = Path::new("../format"); + let proto_path = Path::new("../format/FlightSql.proto"); + + tonic_build::configure() + // protoc in ubuntu builder needs this option + .protoc_arg("--experimental_allow_proto3_optional") + .out_dir("src/sql") + .compile_with_config(prost_config(), &[proto_path], &[proto_dir])?; + + // read file contents to string + let mut file = OpenOptions::new() + .read(true) + .open("src/sql/arrow.flight.protocol.sql.rs")?; + let mut buffer = String::new(); + file.read_to_string(&mut buffer)?; + // append warning that file was auto-generate + let mut file = OpenOptions::new() + 
.write(true) + .truncate(true) + .open("src/sql/arrow.flight.protocol.sql.rs")?; + file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; + file.write_all(buffer.as_bytes())?; + + // Prost currently generates an empty file, this was fixed but then reverted + // https://github.com/tokio-rs/prost/pull/639 + let google_protobuf_rs = Path::new("src/sql/google.protobuf.rs"); + if google_protobuf_rs.exists() && google_protobuf_rs.metadata().unwrap().len() == 0 { + std::fs::remove_file(google_protobuf_rs).unwrap(); + } + + // As the proto file is checked in, the build should not fail if the file is not found + Ok(()) +} + +fn prost_config() -> prost_build::Config { + let mut config = prost_build::Config::new(); + config.bytes([".arrow"]); + config +} diff --git a/conbench/benchmarks.py b/arrow-flight/regen.sh old mode 100644 new mode 100755 similarity index 60% rename from conbench/benchmarks.py rename to arrow-flight/regen.sh index bc4c1796b85f..d83f9d580e8d --- a/conbench/benchmarks.py +++ b/arrow-flight/regen.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,27 +17,5 @@ # specific language governing permissions and limitations # under the License. -import conbench.runner - -import _criterion - - -@conbench.runner.register_benchmark -class TestBenchmark(conbench.runner.Benchmark): - name = "test" - - def run(self, **kwargs): - yield self.conbench.benchmark( - self._f(), - self.name, - options=kwargs, - ) - - def _f(self): - return lambda: 1 + 1 - - -@conbench.runner.register_benchmark -class CargoBenchmarks(_criterion.CriterionBenchmark): - name = "arrow-rs" - description = "Run Arrow Rust micro benchmarks." 
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR && cargo run --manifest-path gen/Cargo.toml diff --git a/arrow-flight/src/arrow.flight.protocol.rs b/arrow-flight/src/arrow.flight.protocol.rs index 2b085d6d1f6b..f1eb549d54aa 100644 --- a/arrow-flight/src/arrow.flight.protocol.rs +++ b/arrow-flight/src/arrow.flight.protocol.rs @@ -1,117 +1,157 @@ // This file was automatically generated through the build.rs script, and should not be edited. +// This file is @generated by prost-build. /// -/// The request that a client provides to a server on handshake. +/// The request that a client provides to a server on handshake. #[derive(Clone, PartialEq, ::prost::Message)] pub struct HandshakeRequest { /// - /// A defined protocol version - #[prost(uint64, tag="1")] + /// A defined protocol version + #[prost(uint64, tag = "1")] pub protocol_version: u64, /// - /// Arbitrary auth/handshake info. - #[prost(bytes="vec", tag="2")] - pub payload: ::prost::alloc::vec::Vec, + /// Arbitrary auth/handshake info. + #[prost(bytes = "bytes", tag = "2")] + pub payload: ::prost::bytes::Bytes, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct HandshakeResponse { /// - /// A defined protocol version - #[prost(uint64, tag="1")] + /// A defined protocol version + #[prost(uint64, tag = "1")] pub protocol_version: u64, /// - /// Arbitrary auth/handshake info. - #[prost(bytes="vec", tag="2")] - pub payload: ::prost::alloc::vec::Vec, + /// Arbitrary auth/handshake info. + #[prost(bytes = "bytes", tag = "2")] + pub payload: ::prost::bytes::Bytes, } /// -/// A message for doing simple auth. +/// A message for doing simple auth. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct BasicAuth { - #[prost(string, tag="2")] + #[prost(string, tag = "2")] pub username: ::prost::alloc::string::String, - #[prost(string, tag="3")] + #[prost(string, tag = "3")] pub password: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Empty { -} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Empty {} /// -/// Describes an available action, including both the name used for execution -/// along with a short description of the purpose of the action. +/// Describes an available action, including both the name used for execution +/// along with a short description of the purpose of the action. #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionType { - #[prost(string, tag="1")] + #[prost(string, tag = "1")] pub r#type: ::prost::alloc::string::String, - #[prost(string, tag="2")] + #[prost(string, tag = "2")] pub description: ::prost::alloc::string::String, } /// -/// A service specific expression that can be used to return a limited set -/// of available Arrow Flight streams. +/// A service specific expression that can be used to return a limited set +/// of available Arrow Flight streams. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Criteria { - #[prost(bytes="vec", tag="1")] - pub expression: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub expression: ::prost::bytes::Bytes, } /// -/// An opaque action specific for the service. +/// An opaque action specific for the service. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Action { - #[prost(string, tag="1")] + #[prost(string, tag = "1")] pub r#type: ::prost::alloc::string::String, - #[prost(bytes="vec", tag="2")] - pub body: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "2")] + pub body: ::prost::bytes::Bytes, +} +/// +/// The request of the CancelFlightInfo action. +/// +/// The request should be stored in Action.body. 
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CancelFlightInfoRequest { + #[prost(message, optional, tag = "1")] + pub info: ::core::option::Option, +} +/// +/// The request of the RenewFlightEndpoint action. +/// +/// The request should be stored in Action.body. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RenewFlightEndpointRequest { + #[prost(message, optional, tag = "1")] + pub endpoint: ::core::option::Option, } /// -/// An opaque result returned after executing an action. +/// An opaque result returned after executing an action. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Result { - #[prost(bytes="vec", tag="1")] - pub body: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub body: ::prost::bytes::Bytes, } /// -/// Wrap the result of a getSchema call +/// The result of the CancelFlightInfo action. +/// +/// The result should be stored in Result.body. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CancelFlightInfoResult { + #[prost(enumeration = "CancelStatus", tag = "1")] + pub status: i32, +} +/// +/// Wrap the result of a getSchema call #[derive(Clone, PartialEq, ::prost::Message)] pub struct SchemaResult { - /// schema of the dataset as described in Schema.fbs::Schema. - #[prost(bytes="vec", tag="1")] - pub schema: ::prost::alloc::vec::Vec, + /// The schema of the dataset in its IPC form: + /// 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + /// 4 bytes - the byte length of the payload + /// a flatbuffer Message whose header is the Schema + #[prost(bytes = "bytes", tag = "1")] + pub schema: ::prost::bytes::Bytes, } /// -/// The name or tag for a Flight. May be used as a way to retrieve or generate -/// a flight or be used to expose a set of previously defined flights. +/// The name or tag for a Flight. May be used as a way to retrieve or generate +/// a flight or be used to expose a set of previously defined flights. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightDescriptor { - #[prost(enumeration="flight_descriptor::DescriptorType", tag="1")] + #[prost(enumeration = "flight_descriptor::DescriptorType", tag = "1")] pub r#type: i32, /// - /// Opaque value used to express a command. Should only be defined when - /// type = CMD. - #[prost(bytes="vec", tag="2")] - pub cmd: ::prost::alloc::vec::Vec, + /// Opaque value used to express a command. Should only be defined when + /// type = CMD. + #[prost(bytes = "bytes", tag = "2")] + pub cmd: ::prost::bytes::Bytes, /// - /// List of strings identifying a particular dataset. Should only be defined - /// when type = PATH. - #[prost(string, repeated, tag="3")] + /// List of strings identifying a particular dataset. Should only be defined + /// when type = PATH. + #[prost(string, repeated, tag = "3")] pub path: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } /// Nested message and enum types in `FlightDescriptor`. pub mod flight_descriptor { /// - /// Describes what type of descriptor is defined. - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + /// Describes what type of descriptor is defined. + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] #[repr(i32)] pub enum DescriptorType { - /// Protobuf pattern, not used. + /// Protobuf pattern, not used. Unknown = 0, /// - /// A named path that identifies a dataset. A path is composed of a string - /// or list of strings describing a particular dataset. This is conceptually + /// A named path that identifies a dataset. A path is composed of a string + /// or list of strings describing a particular dataset. This is conceptually /// similar to a path inside a filesystem. Path = 1, /// - /// An opaque command to generate a dataset. + /// An opaque command to generate a dataset. 
Cmd = 2, } impl DescriptorType { @@ -121,98 +161,259 @@ pub mod flight_descriptor { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - DescriptorType::Unknown => "UNKNOWN", - DescriptorType::Path => "PATH", - DescriptorType::Cmd => "CMD", + Self::Unknown => "UNKNOWN", + Self::Path => "PATH", + Self::Cmd => "CMD", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UNKNOWN" => Some(Self::Unknown), + "PATH" => Some(Self::Path), + "CMD" => Some(Self::Cmd), + _ => None, } } } } /// -/// The access coordinates for retrieval of a dataset. With a FlightInfo, a -/// consumer is able to determine how to retrieve a dataset. +/// The access coordinates for retrieval of a dataset. With a FlightInfo, a +/// consumer is able to determine how to retrieve a dataset. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightInfo { - /// schema of the dataset as described in Schema.fbs::Schema. - #[prost(bytes="vec", tag="1")] - pub schema: ::prost::alloc::vec::Vec, + /// The schema of the dataset in its IPC form: + /// 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + /// 4 bytes - the byte length of the payload + /// a flatbuffer Message whose header is the Schema + #[prost(bytes = "bytes", tag = "1")] + pub schema: ::prost::bytes::Bytes, /// - /// The descriptor associated with this info. - #[prost(message, optional, tag="2")] + /// The descriptor associated with this info. + #[prost(message, optional, tag = "2")] pub flight_descriptor: ::core::option::Option, /// - /// A list of endpoints associated with the flight. To consume the whole - /// flight, all endpoints must be consumed. - #[prost(message, repeated, tag="3")] + /// A list of endpoints associated with the flight. To consume the + /// whole flight, all endpoints (and hence all Tickets) must be + /// consumed. 
Endpoints can be consumed in any order. + /// + /// In other words, an application can use multiple endpoints to + /// represent partitioned data. + /// + /// If the returned data has an ordering, an application can use + /// "FlightInfo.ordered = true" or should return the all data in a + /// single endpoint. Otherwise, there is no ordering defined on + /// endpoints or the data within. + /// + /// A client can read ordered data by reading data from returned + /// endpoints, in order, from front to back. + /// + /// Note that a client may ignore "FlightInfo.ordered = true". If an + /// ordering is important for an application, an application must + /// choose one of them: + /// + /// * An application requires that all clients must read data in + /// returned endpoints order. + /// * An application must return the all data in a single endpoint. + #[prost(message, repeated, tag = "3")] pub endpoint: ::prost::alloc::vec::Vec, - /// Set these to -1 if unknown. - #[prost(int64, tag="4")] + /// Set these to -1 if unknown. + #[prost(int64, tag = "4")] pub total_records: i64, - #[prost(int64, tag="5")] + #[prost(int64, tag = "5")] pub total_bytes: i64, + /// + /// FlightEndpoints are in the same order as the data. + #[prost(bool, tag = "6")] + pub ordered: bool, + /// + /// Application-defined metadata. + /// + /// There is no inherent or required relationship between this + /// and the app_metadata fields in the FlightEndpoints or resulting + /// FlightData messages. Since this metadata is application-defined, + /// a given application could define there to be a relationship, + /// but there is none required by the spec. + #[prost(bytes = "bytes", tag = "7")] + pub app_metadata: ::prost::bytes::Bytes, +} +/// +/// The information to process a long-running query. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PollInfo { + /// + /// The currently available results. 
+ /// + /// If "flight_descriptor" is not specified, the query is complete + /// and "info" specifies all results. Otherwise, "info" contains + /// partial query results. + /// + /// Note that each PollInfo response contains a complete + /// FlightInfo (not just the delta between the previous and current + /// FlightInfo). + /// + /// Subsequent PollInfo responses may only append new endpoints to + /// info. + /// + /// Clients can begin fetching results via DoGet(Ticket) with the + /// ticket in the info before the query is + /// completed. FlightInfo.ordered is also valid. + #[prost(message, optional, tag = "1")] + pub info: ::core::option::Option, + /// + /// The descriptor the client should use on the next try. + /// If unset, the query is complete. + #[prost(message, optional, tag = "2")] + pub flight_descriptor: ::core::option::Option, + /// + /// Query progress. If known, must be in \[0.0, 1.0\] but need not be + /// monotonic or nondecreasing. If unknown, do not set. + #[prost(double, optional, tag = "3")] + pub progress: ::core::option::Option, + /// + /// Expiration time for this request. After this passes, the server + /// might not accept the retry descriptor anymore (and the query may + /// be cancelled). This may be updated on a call to PollFlightInfo. + #[prost(message, optional, tag = "4")] + pub expiration_time: ::core::option::Option<::prost_types::Timestamp>, } /// -/// A particular stream or split associated with a flight. +/// A particular stream or split associated with a flight. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightEndpoint { /// - /// Token used to retrieve this stream. - #[prost(message, optional, tag="1")] + /// Token used to retrieve this stream. + #[prost(message, optional, tag = "1")] pub ticket: ::core::option::Option, /// - /// A list of URIs where this ticket can be redeemed. 
If the list is - /// empty, the expectation is that the ticket can only be redeemed on the - /// current service where the ticket was generated. - #[prost(message, repeated, tag="2")] + /// A list of URIs where this ticket can be redeemed via DoGet(). + /// + /// If the list is empty, the expectation is that the ticket can only + /// be redeemed on the current service where the ticket was + /// generated. + /// + /// If the list is not empty, the expectation is that the ticket can + /// be redeemed at any of the locations, and that the data returned + /// will be equivalent. In this case, the ticket may only be redeemed + /// at one of the given locations, and not (necessarily) on the + /// current service. + /// + /// In other words, an application can use multiple locations to + /// represent redundant and/or load balanced services. + #[prost(message, repeated, tag = "2")] pub location: ::prost::alloc::vec::Vec, + /// + /// Expiration time of this stream. If present, clients may assume + /// they can retry DoGet requests. Otherwise, it is + /// application-defined whether DoGet requests may be retried. + #[prost(message, optional, tag = "3")] + pub expiration_time: ::core::option::Option<::prost_types::Timestamp>, + /// + /// Application-defined metadata. + /// + /// There is no inherent or required relationship between this + /// and the app_metadata fields in the FlightInfo or resulting + /// FlightData messages. Since this metadata is application-defined, + /// a given application could define there to be a relationship, + /// but there is none required by the spec. + #[prost(bytes = "bytes", tag = "4")] + pub app_metadata: ::prost::bytes::Bytes, } /// -/// A location where a Flight service will accept retrieval of a particular -/// stream given a ticket. +/// A location where a Flight service will accept retrieval of a particular +/// stream given a ticket. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct Location { - #[prost(string, tag="1")] + #[prost(string, tag = "1")] pub uri: ::prost::alloc::string::String, } /// -/// An opaque identifier that the service can use to retrieve a particular -/// portion of a stream. +/// An opaque identifier that the service can use to retrieve a particular +/// portion of a stream. +/// +/// Tickets are meant to be single use. It is an error/application-defined +/// behavior to reuse a ticket. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Ticket { - #[prost(bytes="vec", tag="1")] - pub ticket: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub ticket: ::prost::bytes::Bytes, } /// -/// A batch of Arrow data as part of a stream of batches. +/// A batch of Arrow data as part of a stream of batches. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightData { /// - /// The descriptor of the data. This is only relevant when a client is - /// starting a new DoPut stream. - #[prost(message, optional, tag="1")] + /// The descriptor of the data. This is only relevant when a client is + /// starting a new DoPut stream. + #[prost(message, optional, tag = "1")] pub flight_descriptor: ::core::option::Option, /// - /// Header for message data as described in Message.fbs::Message. - #[prost(bytes="vec", tag="2")] - pub data_header: ::prost::alloc::vec::Vec, + /// Header for message data as described in Message.fbs::Message. + #[prost(bytes = "bytes", tag = "2")] + pub data_header: ::prost::bytes::Bytes, /// - /// Application-defined metadata. - #[prost(bytes="vec", tag="3")] - pub app_metadata: ::prost::alloc::vec::Vec, + /// Application-defined metadata. + #[prost(bytes = "bytes", tag = "3")] + pub app_metadata: ::prost::bytes::Bytes, /// - /// The actual batch of Arrow data. 
Preferably handled with minimal-copies - /// coming last in the definition to help with sidecar patterns (it is - /// expected that some implementations will fetch this field off the wire - /// with specialized code to avoid extra memory copies). - #[prost(bytes="vec", tag="1000")] - pub data_body: ::prost::alloc::vec::Vec, + /// The actual batch of Arrow data. Preferably handled with minimal-copies + /// coming last in the definition to help with sidecar patterns (it is + /// expected that some implementations will fetch this field off the wire + /// with specialized code to avoid extra memory copies). + #[prost(bytes = "bytes", tag = "1000")] + pub data_body: ::prost::bytes::Bytes, } /// * -/// The response message associated with the submission of a DoPut. +/// The response message associated with the submission of a DoPut. #[derive(Clone, PartialEq, ::prost::Message)] pub struct PutResult { - #[prost(bytes="vec", tag="1")] - pub app_metadata: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub app_metadata: ::prost::bytes::Bytes, +} +/// +/// The result of a cancel operation. +/// +/// This is used by CancelFlightInfoResult.status. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum CancelStatus { + /// The cancellation status is unknown. Servers should avoid using + /// this value (send a NOT_FOUND error if the requested query is + /// not known). Clients can retry the request. + Unspecified = 0, + /// The cancellation request is complete. Subsequent requests with + /// the same payload may return CANCELLED or a NOT_FOUND error. + Cancelled = 1, + /// The cancellation request is in progress. The client may retry + /// the cancellation request. + Cancelling = 2, + /// The query is not cancellable. The client should not retry the + /// cancellation request. + NotCancellable = 3, +} +impl CancelStatus { + /// String value of the enum field names used in the ProtoBuf definition. 
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "CANCEL_STATUS_UNSPECIFIED", + Self::Cancelled => "CANCEL_STATUS_CANCELLED", + Self::Cancelling => "CANCEL_STATUS_CANCELLING", + Self::NotCancellable => "CANCEL_STATUS_NOT_CANCELLABLE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "CANCEL_STATUS_UNSPECIFIED" => Some(Self::Unspecified), + "CANCEL_STATUS_CANCELLED" => Some(Self::Cancelled), + "CANCEL_STATUS_CANCELLING" => Some(Self::Cancelling), + "CANCEL_STATUS_NOT_CANCELLABLE" => Some(Self::NotCancellable), + _ => None, + } + } } /// Generated client implementations. pub mod flight_service_client { @@ -232,7 +433,7 @@ pub mod flight_service_client { /// Attempt to create a new client by connecting to a given endpoint. pub async fn connect(dst: D) -> Result where - D: std::convert::TryInto, + D: TryInto, D::Error: Into, { let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; @@ -243,8 +444,8 @@ pub mod flight_service_client { where T: tonic::client::GrpcService, T::Error: Into, - T::ResponseBody: Body + Send + 'static, - ::Error: Into + Send, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, { pub fn new(inner: T) -> Self { let inner = tonic::client::Grpc::new(inner); @@ -269,7 +470,7 @@ pub mod flight_service_client { >, , - >>::Error: Into + Send + Sync, + >>::Error: Into + std::marker::Send + std::marker::Sync, { FlightServiceClient::new(InterceptedService::new(inner, interceptor)) } @@ -288,6 +489,22 @@ pub mod flight_service_client { self.inner = self.inner.accept_compressed(encoding); self } + /// Limits the maximum size of a decoded message. 
+ /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } /// /// Handshake between client and server. Depending on the server, the /// handshake may be required to determine the token that should be used for @@ -296,7 +513,7 @@ pub mod flight_service_client { pub async fn handshake( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -313,7 +530,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/Handshake", ); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "Handshake"), + ); + self.inner.streaming(req, path, codec).await } /// /// Get a list of available streams given a particular criteria. 
Most flight @@ -325,7 +547,7 @@ pub mod flight_service_client { pub async fn list_flights( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -342,7 +564,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListFlights", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "ListFlights"), + ); + self.inner.server_streaming(req, path, codec).await } /// /// For a given FlightDescriptor, get information about how the flight can be @@ -358,7 +585,7 @@ pub mod flight_service_client { pub async fn get_flight_info( &mut self, request: impl tonic::IntoRequest, - ) -> Result, tonic::Status> { + ) -> std::result::Result, tonic::Status> { self.inner .ready() .await @@ -372,7 +599,65 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetFlightInfo", ); - self.inner.unary(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "arrow.flight.protocol.FlightService", + "GetFlightInfo", + ), + ); + self.inner.unary(req, path, codec).await + } + /// + /// For a given FlightDescriptor, start a query and get information + /// to poll its execution status. This is a useful interface if the + /// query may be a long-running query. The first PollFlightInfo call + /// should return as quickly as possible. (GetFlightInfo doesn't + /// return until the query is complete.) + /// + /// A client can consume any available results before + /// the query is completed. See PollInfo.info for details. + /// + /// A client can poll the updated query status by calling + /// PollFlightInfo() with PollInfo.flight_descriptor. 
A server + /// should not respond until the result would be different from last + /// time. That way, the client can "long poll" for updates + /// without constantly making requests. Clients can set a short timeout + /// to avoid blocking calls if desired. + /// + /// A client can't use PollInfo.flight_descriptor after + /// PollInfo.expiration_time passes. A server might not accept the + /// retry descriptor anymore and the query may be cancelled. + /// + /// A client may use the CancelFlightInfo action with + /// PollInfo.info to cancel the running query. + pub async fn poll_flight_info( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/arrow.flight.protocol.FlightService/PollFlightInfo", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "arrow.flight.protocol.FlightService", + "PollFlightInfo", + ), + ); + self.inner.unary(req, path, codec).await } /// /// For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema @@ -382,7 +667,7 @@ pub mod flight_service_client { pub async fn get_schema( &mut self, request: impl tonic::IntoRequest, - ) -> Result, tonic::Status> { + ) -> std::result::Result, tonic::Status> { self.inner .ready() .await @@ -396,7 +681,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetSchema", ); - self.inner.unary(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "GetSchema"), + ); + self.inner.unary(req, path, codec).await } /// /// Retrieve a single stream associated with a 
particular descriptor @@ -406,7 +696,7 @@ pub mod flight_service_client { pub async fn do_get( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -423,7 +713,10 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoGet", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("arrow.flight.protocol.FlightService", "DoGet")); + self.inner.server_streaming(req, path, codec).await } /// /// Push a stream to the flight service associated with a particular @@ -435,7 +728,7 @@ pub mod flight_service_client { pub async fn do_put( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -452,7 +745,10 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoPut", ); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert(GrpcMethod::new("arrow.flight.protocol.FlightService", "DoPut")); + self.inner.streaming(req, path, codec).await } /// /// Open a bidirectional data channel for a given descriptor. 
This @@ -463,7 +759,7 @@ pub mod flight_service_client { pub async fn do_exchange( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -480,7 +776,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoExchange", ); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "DoExchange"), + ); + self.inner.streaming(req, path, codec).await } /// /// Flight services can support an arbitrary number of simple actions in @@ -492,7 +793,7 @@ pub mod flight_service_client { pub async fn do_action( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -509,7 +810,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoAction", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "DoAction"), + ); + self.inner.server_streaming(req, path, codec).await } /// /// A flight service exposes all of the available action types that it has @@ -518,7 +824,7 @@ pub mod flight_service_client { pub async fn list_actions( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -535,7 +841,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListActions", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + 
GrpcMethod::new("arrow.flight.protocol.FlightService", "ListActions"), + ); + self.inner.server_streaming(req, path, codec).await } } } @@ -543,14 +854,14 @@ pub mod flight_service_client { pub mod flight_service_server { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] use tonic::codegen::*; - ///Generated trait containing gRPC methods that should be implemented for use with FlightServiceServer. + /// Generated trait containing gRPC methods that should be implemented for use with FlightServiceServer. #[async_trait] - pub trait FlightService: Send + Sync + 'static { - ///Server streaming response type for the Handshake method. - type HandshakeStream: futures_core::Stream< - Item = Result, + pub trait FlightService: std::marker::Send + std::marker::Sync + 'static { + /// Server streaming response type for the Handshake method. + type HandshakeStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > - + Send + + std::marker::Send + 'static; /// /// Handshake between client and server. Depending on the server, the @@ -560,12 +871,12 @@ pub mod flight_service_server { async fn handshake( &self, request: tonic::Request>, - ) -> Result, tonic::Status>; - ///Server streaming response type for the ListFlights method. - type ListFlightsStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the ListFlights method. + type ListFlightsStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > - + Send + + std::marker::Send + 'static; /// /// Get a list of available streams given a particular criteria. Most flight @@ -577,7 +888,10 @@ pub mod flight_service_server { async fn list_flights( &self, request: tonic::Request, - ) -> Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; /// /// For a given FlightDescriptor, get information about how the flight can be /// consumed. 
This is a useful interface if the consumer of the interface @@ -592,7 +906,34 @@ pub mod flight_service_server { async fn get_flight_info( &self, request: tonic::Request, - ) -> Result, tonic::Status>; + ) -> std::result::Result, tonic::Status>; + /// + /// For a given FlightDescriptor, start a query and get information + /// to poll its execution status. This is a useful interface if the + /// query may be a long-running query. The first PollFlightInfo call + /// should return as quickly as possible. (GetFlightInfo doesn't + /// return until the query is complete.) + /// + /// A client can consume any available results before + /// the query is completed. See PollInfo.info for details. + /// + /// A client can poll the updated query status by calling + /// PollFlightInfo() with PollInfo.flight_descriptor. A server + /// should not respond until the result would be different from last + /// time. That way, the client can "long poll" for updates + /// without constantly making requests. Clients can set a short timeout + /// to avoid blocking calls if desired. + /// + /// A client can't use PollInfo.flight_descriptor after + /// PollInfo.expiration_time passes. A server might not accept the + /// retry descriptor anymore and the query may be cancelled. + /// + /// A client may use the CancelFlightInfo action with + /// PollInfo.info to cancel the running query. + async fn poll_flight_info( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; /// /// For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema /// This is used when a consumer needs the Schema of flight stream. Similar to @@ -601,12 +942,12 @@ pub mod flight_service_server { async fn get_schema( &self, request: tonic::Request, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoGet method. 
- type DoGetStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoGet method. + type DoGetStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > - + Send + + std::marker::Send + 'static; /// /// Retrieve a single stream associated with a particular descriptor @@ -616,12 +957,12 @@ pub mod flight_service_server { async fn do_get( &self, request: tonic::Request, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoPut method. - type DoPutStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoPut method. + type DoPutStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > - + Send + + std::marker::Send + 'static; /// /// Push a stream to the flight service associated with a particular @@ -633,12 +974,12 @@ pub mod flight_service_server { async fn do_put( &self, request: tonic::Request>, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoExchange method. - type DoExchangeStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoExchange method. + type DoExchangeStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > - + Send + + std::marker::Send + 'static; /// /// Open a bidirectional data channel for a given descriptor. This @@ -649,12 +990,12 @@ pub mod flight_service_server { async fn do_exchange( &self, request: tonic::Request>, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoAction method. - type DoActionStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoAction method. 
+ type DoActionStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > - + Send + + std::marker::Send + 'static; /// /// Flight services can support an arbitrary number of simple actions in @@ -666,12 +1007,12 @@ pub mod flight_service_server { async fn do_action( &self, request: tonic::Request, - ) -> Result, tonic::Status>; - ///Server streaming response type for the ListActions method. - type ListActionsStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the ListActions method. + type ListActionsStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > - + Send + + std::marker::Send + 'static; /// /// A flight service exposes all of the available action types that it has @@ -680,7 +1021,10 @@ pub mod flight_service_server { async fn list_actions( &self, request: tonic::Request, - ) -> Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; } /// /// A flight service is an endpoint for retrieving or storing Arrow data. A @@ -688,22 +1032,24 @@ pub mod flight_service_server { /// accessed using the Arrow Flight Protocol. Additionally, a flight service /// can expose a set of actions that are available. 
#[derive(Debug)] - pub struct FlightServiceServer { - inner: _Inner, + pub struct FlightServiceServer { + inner: Arc, accept_compression_encodings: EnabledCompressionEncodings, send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, } - struct _Inner(Arc); - impl FlightServiceServer { + impl FlightServiceServer { pub fn new(inner: T) -> Self { Self::from_arc(Arc::new(inner)) } pub fn from_arc(inner: Arc) -> Self { - let inner = _Inner(inner); Self { inner, accept_compression_encodings: Default::default(), send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, } } pub fn with_interceptor( @@ -727,12 +1073,28 @@ pub mod flight_service_server { self.send_compression_encodings.enable(encoding); self } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. 
+ /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } } impl tonic::codegen::Service> for FlightServiceServer where T: FlightService, - B: Body + Send + 'static, - B::Error: Into + Send + 'static, + B: Body + std::marker::Send + 'static, + B::Error: Into + std::marker::Send + 'static, { type Response = http::Response; type Error = std::convert::Infallible; @@ -740,11 +1102,10 @@ pub mod flight_service_server { fn poll_ready( &mut self, _cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: http::Request) -> Self::Future { - let inner = self.inner.clone(); match req.uri().path() { "/arrow.flight.protocol.FlightService/Handshake" => { #[allow(non_camel_case_types)] @@ -765,22 +1126,29 @@ pub mod flight_service_server { tonic::Streaming, >, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).handshake(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::handshake(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = HandshakeSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.streaming(method, req).await; Ok(res) @@ -804,24 +1172,29 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); 
+ let inner = Arc::clone(&self.0); let fut = async move { - (*inner).list_flights(request).await + ::list_flights(&inner, request).await }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = ListFlightsSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -844,24 +1217,75 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); + let inner = Arc::clone(&self.0); let fut = async move { - (*inner).get_flight_info(request).await + ::get_flight_info(&inner, request).await }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = GetFlightInfoSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/PollFlightInfo" => { + 
#[allow(non_camel_case_types)] + struct PollFlightInfoSvc(pub Arc); + impl< + T: FlightService, + > tonic::server::UnaryService + for PollFlightInfoSvc { + type Response = super::PollInfo; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::poll_flight_info(&inner, request) + .await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = PollFlightInfoSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.unary(method, req).await; Ok(res) @@ -884,22 +1308,29 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).get_schema(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_schema(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = GetSchemaSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) 
.apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.unary(method, req).await; Ok(res) @@ -923,22 +1354,29 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_get(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_get(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = DoGetSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -962,22 +1400,29 @@ pub mod flight_service_server { &mut self, request: tonic::Request>, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_put(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_put(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = DoPutSvc(inner); let codec = 
tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.streaming(method, req).await; Ok(res) @@ -1001,22 +1446,29 @@ pub mod flight_service_server { &mut self, request: tonic::Request>, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_exchange(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_exchange(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = DoExchangeSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.streaming(method, req).await; Ok(res) @@ -1040,22 +1492,29 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_action(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_action(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = 
async move { - let inner = inner.0; let method = DoActionSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -1079,24 +1538,29 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); + let inner = Arc::clone(&self.0); let fut = async move { - (*inner).list_actions(request).await + ::list_actions(&inner, request).await }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { - let inner = inner.0; let method = ListActionsSvc(inner); let codec = tonic::codec::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -1108,8 +1572,11 @@ pub mod flight_service_server { Ok( http::Response::builder() .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") + .header("grpc-status", tonic::Code::Unimplemented as i32) + .header( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ) .body(empty_body()) .unwrap(), ) @@ -1118,27 +1585,21 @@ pub mod flight_service_server { } } } - impl Clone for FlightServiceServer { + impl Clone for FlightServiceServer { fn clone(&self) -> Self { let inner = self.inner.clone(); Self { inner, 
accept_compression_encodings: self.accept_compression_encodings, send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, } } } - impl Clone for _Inner { - fn clone(&self) -> Self { - Self(self.0.clone()) - } - } - impl std::fmt::Debug for _Inner { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } - } - impl tonic::server::NamedService for FlightServiceServer { - const NAME: &'static str = "arrow.flight.protocol.FlightService"; + /// Generated gRPC service name + pub const SERVICE_NAME: &str = "arrow.flight.protocol.FlightService"; + impl tonic::server::NamedService for FlightServiceServer { + const NAME: &'static str = SERVICE_NAME; } } diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs new file mode 100644 index 000000000000..c334b95a9a96 --- /dev/null +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -0,0 +1,429 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{sync::Arc, time::Duration}; + +use anyhow::{bail, Context, Result}; +use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; +use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions}; +use arrow_flight::{ + sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables}, + FlightInfo, +}; +use arrow_schema::Schema; +use clap::{Parser, Subcommand}; +use futures::TryStreamExt; +use tonic::{ + metadata::MetadataMap, + transport::{Channel, ClientTlsConfig, Endpoint}, +}; +use tracing_log::log::info; + +/// Logging CLI config. +#[derive(Debug, Parser)] +pub struct LoggingArgs { + /// Log verbosity. + /// + /// Defaults to "warn". + /// + /// Use `-v` for "info", `-vv` for "debug", `-vvv` for "trace". + /// + /// Note you can also set logging level using `RUST_LOG` environment variable: + /// `RUST_LOG=debug`. + #[clap( + short = 'v', + long = "verbose", + action = clap::ArgAction::Count, + )] + log_verbose_count: u8, +} + +#[derive(Debug, Parser)] +struct ClientArgs { + /// Additional headers. + /// + /// Can be given multiple times. Headers and values are separated by '='. + /// + /// Example: `-H foo=bar -H baz=42` + #[clap(long = "header", short = 'H', value_parser = parse_key_val)] + headers: Vec<(String, String)>, + + /// Username. + /// + /// Optional. If given, `password` must also be set. + #[clap(long, requires = "password")] + username: Option, + + /// Password. + /// + /// Optional. If given, `username` must also be set. + #[clap(long, requires = "username")] + password: Option, + + /// Auth token. + #[clap(long)] + token: Option, + + /// Use TLS. + /// + /// If not provided, use cleartext connection. + #[clap(long)] + tls: bool, + + /// Server host. + /// + /// Required. + #[clap(long)] + host: String, + + /// Server port. + /// + /// Defaults to `443` if `tls` is set, otherwise defaults to `80`. + #[clap(long)] + port: Option, +} + +#[derive(Debug, Parser)] +struct Args { + /// Logging args. 
+ #[clap(flatten)] + logging_args: LoggingArgs, + + /// Client args. + #[clap(flatten)] + client_args: ClientArgs, + + #[clap(subcommand)] + cmd: Command, +} + +/// Different available commands. +#[derive(Debug, Subcommand)] +enum Command { + /// Get catalogs. + Catalogs, + /// Get db schemas for a catalog. + DbSchemas { + /// Name of a catalog. + /// + /// Required. + catalog: String, + /// Specifies a filter pattern for schemas to search for. + /// When no schema_filter is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[clap(short, long)] + db_schema_filter: Option, + }, + /// Get tables for a catalog. + Tables { + /// Name of a catalog. + /// + /// Required. + catalog: String, + /// Specifies a filter pattern for schemas to search for. + /// When no schema_filter is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[clap(short, long)] + db_schema_filter: Option, + /// Specifies a filter pattern for tables to search for. + /// When no table_filter is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[clap(short, long)] + table_filter: Option, + /// Specifies a filter of table types which must match. + /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. 
+ #[clap(long)] + table_types: Vec, + }, + /// Get table types. + TableTypes, + + /// Execute given statement. + StatementQuery { + /// SQL query. + /// + /// Required. + query: String, + }, + + /// Prepare given statement and then execute it. + PreparedStatementQuery { + /// SQL query. + /// + /// Required. + /// + /// Can contain placeholders like `$1`. + /// + /// Example: `SELECT * FROM t WHERE x = $1` + query: String, + + /// Additional parameters. + /// + /// Can be given multiple times. Names and values are separated by '='. Values will be + /// converted to the type that the server reported for the prepared statement. + /// + /// Example: `-p $1=42` + #[clap(short, value_parser = parse_key_val)] + params: Vec<(String, String)>, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + setup_logging(args.logging_args)?; + let mut client = setup_client(args.client_args) + .await + .context("setup client")?; + + let flight_info = match args.cmd { + Command::Catalogs => client.get_catalogs().await.context("get catalogs")?, + Command::DbSchemas { + catalog, + db_schema_filter, + } => client + .get_db_schemas(CommandGetDbSchemas { + catalog: Some(catalog), + db_schema_filter_pattern: db_schema_filter, + }) + .await + .context("get db schemas")?, + Command::Tables { + catalog, + db_schema_filter, + table_filter, + table_types, + } => client + .get_tables(CommandGetTables { + catalog: Some(catalog), + db_schema_filter_pattern: db_schema_filter, + table_name_filter_pattern: table_filter, + table_types, + // Schema is returned as ipc encoded bytes. + // We do not support returning the schema as there is no trivial mechanism + // to display the information to the user.
+ include_schema: false, + }) + .await + .context("get tables")?, + Command::TableTypes => client.get_table_types().await.context("get table types")?, + Command::StatementQuery { query } => client + .execute(query, None) + .await + .context("execute statement")?, + Command::PreparedStatementQuery { query, params } => { + let mut prepared_stmt = client + .prepare(query, None) + .await + .context("prepare statement")?; + + if !params.is_empty() { + prepared_stmt + .set_parameters( + construct_record_batch_from_params( + &params, + prepared_stmt + .parameter_schema() + .context("get parameter schema")?, + ) + .context("construct parameters")?, + ) + .context("bind parameters")?; + } + + prepared_stmt + .execute() + .await + .context("execute prepared statement")? + } + }; + + let batches = execute_flight(&mut client, flight_info) + .await + .context("read flight data")?; + + let res = pretty_format_batches(batches.as_slice()).context("format results")?; + println!("{res}"); + + Ok(()) +} + +async fn execute_flight( + client: &mut FlightSqlServiceClient, + info: FlightInfo, +) -> Result> { + let schema = Arc::new(Schema::try_from(info.clone()).context("valid schema")?); + let mut batches = Vec::with_capacity(info.endpoint.len() + 1); + batches.push(RecordBatch::new_empty(schema)); + info!("decoded schema"); + + for endpoint in info.endpoint { + let Some(ticket) = &endpoint.ticket else { + bail!("did not get ticket"); + }; + + let mut flight_data = client.do_get(ticket.clone()).await.context("do get")?; + log_metadata(flight_data.headers(), "header"); + + let mut endpoint_batches: Vec<_> = (&mut flight_data) + .try_collect() + .await + .context("collect data stream")?; + batches.append(&mut endpoint_batches); + + if let Some(trailers) = flight_data.trailers() { + log_metadata(&trailers, "trailer"); + } + } + info!("received data"); + + Ok(batches) +} + +fn construct_record_batch_from_params( + params: &[(String, String)], + parameter_schema: &Schema, +) -> Result { + let
mut items = Vec::<(&String, ArrayRef)>::new(); + + for (name, value) in params { + let field = parameter_schema.field_with_name(name)?; + let value_as_array = StringArray::new_scalar(value); + let casted = cast_with_options( + value_as_array.get().0, + field.data_type(), + &CastOptions::default(), + )?; + items.push((name, casted)) + } + + Ok(RecordBatch::try_from_iter(items)?) +} + +fn setup_logging(args: LoggingArgs) -> Result<()> { + use tracing_subscriber::{util::SubscriberInitExt, EnvFilter, FmtSubscriber}; + + tracing_log::LogTracer::init().context("tracing log init")?; + + let filter = match args.log_verbose_count { + 0 => "warn", + 1 => "info", + 2 => "debug", + _ => "trace", + }; + let filter = EnvFilter::try_new(filter).context("set up log env filter")?; + + let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish(); + subscriber.try_init().context("init logging subscriber")?; + + Ok(()) +} + +async fn setup_client(args: ClientArgs) -> Result> { + let port = args.port.unwrap_or(if args.tls { 443 } else { 80 }); + + let protocol = if args.tls { "https" } else { "http" }; + + let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port)) + .context("create endpoint")? 
+ .connect_timeout(Duration::from_secs(20)) + .timeout(Duration::from_secs(20)) + .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait + .tcp_keepalive(Option::Some(Duration::from_secs(3600))) + .http2_keep_alive_interval(Duration::from_secs(300)) + .keep_alive_timeout(Duration::from_secs(20)) + .keep_alive_while_idle(true); + + if args.tls { + let tls_config = ClientTlsConfig::new(); + endpoint = endpoint + .tls_config(tls_config) + .context("create TLS endpoint")?; + } + + let channel = endpoint.connect().await.context("connect to endpoint")?; + + let mut client = FlightSqlServiceClient::new(channel); + info!("connected"); + + for (k, v) in args.headers { + client.set_header(k, v); + } + + if let Some(token) = args.token { + client.set_token(token); + info!("token set"); + } + + match (args.username, args.password) { + (None, None) => {} + (Some(username), Some(password)) => { + client + .handshake(&username, &password) + .await + .context("handshake")?; + info!("performed handshake"); + } + (Some(_), None) => { + bail!("when username is set, you also need to set a password") + } + (None, Some(_)) => { + bail!("when password is set, you also need to set a username") + } + } + + Ok(client) +} + +/// Parse a single key-value pair +fn parse_key_val(s: &str) -> Result<(String, String), String> { + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].to_owned(), s[pos + 1..].to_owned())) +} + +/// Log headers/trailers. 
+fn log_metadata(map: &MetadataMap, what: &'static str) { + for k_v in map.iter() { + match k_v { + tonic::metadata::KeyAndValueRef::Ascii(k, v) => { + info!( + "{}: {}={}", + what, + k.as_str(), + v.to_str().unwrap_or(""), + ); + } + tonic::metadata::KeyAndValueRef::Binary(k, v) => { + info!( + "{}: {}={}", + what, + k.as_str(), + String::from_utf8_lossy(v.as_ref()), + ); + } + } + } +} diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs new file mode 100644 index 000000000000..97d9899a9fb0 --- /dev/null +++ b/arrow-flight/src/client.rs @@ -0,0 +1,673 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::{ + decode::FlightRecordBatchStream, + flight_service_client::FlightServiceClient, + gen::{CancelFlightInfoRequest, CancelFlightInfoResult, RenewFlightEndpointRequest}, + trailers::extract_lazy_trailers, + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, + HandshakeRequest, PollInfo, PutResult, Ticket, +}; +use arrow_schema::Schema; +use bytes::Bytes; +use futures::{ + future::ready, + stream::{self, BoxStream}, + Stream, StreamExt, TryStreamExt, +}; +use prost::Message; +use tonic::{metadata::MetadataMap, transport::Channel}; + +use crate::error::{FlightError, Result}; +use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream}; + +/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client. +/// +/// [`FlightClient`] is intended as a convenience for interactions +/// with Arrow Flight servers. For more direct control, such as access +/// to the response headers, use [`FlightServiceClient`] directly +/// via methods such as [`Self::inner`] or [`Self::into_inner`]. 
+/// +/// # Example: +/// ```no_run +/// # async fn run() { +/// # use arrow_flight::FlightClient; +/// # use bytes::Bytes; +/// use tonic::transport::Channel; +/// let channel = Channel::from_static("http://localhost:1234") +/// .connect() +/// .await +/// .expect("error connecting"); +/// +/// let mut client = FlightClient::new(channel); +/// +/// // Send 'Hi' bytes as the handshake request to the server +/// let response = client +/// .handshake(Bytes::from("Hi")) +/// .await +/// .expect("error handshaking"); +/// +/// // Expect the server responded with 'Ho' +/// assert_eq!(response, Bytes::from("Ho")); +/// # } +/// ``` +#[derive(Debug)] +pub struct FlightClient { + /// Optional grpc header metadata to include with each request + metadata: MetadataMap, + + /// The inner client + inner: FlightServiceClient, +} + +impl FlightClient { + /// Creates a client client with the provided [`Channel`] + pub fn new(channel: Channel) -> Self { + Self::new_from_inner(FlightServiceClient::new(channel)) + } + + /// Creates a new higher level client with the provided lower level client + pub fn new_from_inner(inner: FlightServiceClient) -> Self { + Self { + metadata: MetadataMap::new(), + inner, + } + } + + /// Return a reference to gRPC metadata included with each request + pub fn metadata(&self) -> &MetadataMap { + &self.metadata + } + + /// Return a reference to gRPC metadata included with each request + /// + /// These headers can be used, for example, to include + /// authorization or other application specific headers. + pub fn metadata_mut(&mut self) -> &mut MetadataMap { + &mut self.metadata + } + + /// Add the specified header with value to all subsequent + /// requests. See [`Self::metadata_mut`] for fine grained control. 
+ pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> { + let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes()) + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + + let value = value + .parse() + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + + // ignore previous value + self.metadata.insert(key, value); + + Ok(()) + } + + /// Return a reference to the underlying tonic + /// [`FlightServiceClient`] + pub fn inner(&self) -> &FlightServiceClient { + &self.inner + } + + /// Return a mutable reference to the underlying tonic + /// [`FlightServiceClient`] + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + &mut self.inner + } + + /// Consume this client and return the underlying tonic + /// [`FlightServiceClient`] + pub fn into_inner(self) -> FlightServiceClient { + self.inner + } + + /// Perform an Arrow Flight handshake with the server, sending + /// `payload` as the [`HandshakeRequest`] payload and returning + /// the [`HandshakeResponse`](crate::HandshakeResponse) + /// bytes returned from the server + /// + /// See [`FlightClient`] docs for an example. + pub async fn handshake(&mut self, payload: impl Into) -> Result { + let request = HandshakeRequest { + protocol_version: 0, + payload: payload.into(), + }; + + // apply headers, etc + let request = self.make_request(stream::once(ready(request))); + + let mut response_stream = self.inner.handshake(request).await?.into_inner(); + + if let Some(response) = response_stream.next().await.transpose()? { + // check if there is another response + if response_stream.next().await.is_some() { + return Err(FlightError::protocol( + "Got unexpected second response from handshake", + )); + } + + Ok(response.payload) + } else { + Err(FlightError::protocol("No response from handshake")) + } + } + + /// Make a `DoGet` call to the server with the provided ticket, + /// returning a [`FlightRecordBatchStream`] for reading + /// [`RecordBatch`](arrow_array::RecordBatch)es. 
+ /// + /// # Note + /// + /// To access the returned [`FlightData`] use + /// [`FlightRecordBatchStream::into_inner()`] + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use bytes::Bytes; + /// # use arrow_flight::FlightClient; + /// # use arrow_flight::Ticket; + /// # use arrow_array::RecordBatch; + /// # use futures::stream::TryStreamExt; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// # let ticket = Ticket { ticket: Bytes::from("foo") }; + /// let mut client = FlightClient::new(channel); + /// + /// // Invoke a do_get request on the server with a previously + /// // received Ticket + /// + /// let response = client + /// .do_get(ticket) + /// .await + /// .expect("error invoking do_get"); + /// + /// // Use try_collect to get the RecordBatches from the server + /// let batches: Vec = response + /// .try_collect() + /// .await + /// .expect("no stream errors"); + /// # } + /// ``` + pub async fn do_get(&mut self, ticket: Ticket) -> Result { + let request = self.make_request(ticket); + + let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts(); + let (response_stream, trailers) = extract_lazy_trailers(response_stream); + + Ok(FlightRecordBatchStream::new_from_flight_data( + response_stream.map_err(FlightError::Tonic), + ) + .with_headers(md) + .with_trailers(trailers)) + } + + /// Make a `GetFlightInfo` call to the server with the provided + /// [`FlightDescriptor`] and return the [`FlightInfo`] from the + /// server. The [`FlightInfo`] can be used with [`Self::do_get`] + /// to retrieve the requested batches. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use arrow_flight::FlightClient; + /// # use arrow_flight::FlightDescriptor; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Send a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// let flight_info = client + /// .get_flight_info(request) + /// .await + /// .expect("error handshaking"); + /// + /// // retrieve the first endpoint from the returned flight info + /// let ticket = flight_info + /// .endpoint[0] + /// // Extract the ticket + /// .ticket + /// .clone() + /// .expect("expected ticket"); + /// + /// // Retrieve the corresponding RecordBatch stream with do_get + /// let data = client + /// .do_get(ticket) + /// .await + /// .expect("error fetching data"); + /// # } + /// ``` + pub async fn get_flight_info(&mut self, descriptor: FlightDescriptor) -> Result { + let request = self.make_request(descriptor); + + let response = self.inner.get_flight_info(request).await?.into_inner(); + Ok(response) + } + + /// Make a `PollFlightInfo` call to the server with the provided + /// [`FlightDescriptor`] and return the [`PollInfo`] from the + /// server. + /// + /// The `info` field of the [`PollInfo`] can be used with + /// [`Self::do_get`] to retrieve the requested batches. + /// + /// If the `flight_descriptor` field of the [`PollInfo`] is + /// `None` then the `info` field represents the complete results. + /// + /// If the `flight_descriptor` field is some [`FlightDescriptor`] + /// then the `info` field has incomplete results, and the client + /// should call this method again with the new `flight_descriptor` + /// to get the updated status. + /// + /// The `expiration_time`, if set, represents the expiration time + /// of the `flight_descriptor`, after which the server may not accept + /// this retry descriptor and may cancel the query. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use arrow_flight::FlightClient; + /// # use arrow_flight::FlightDescriptor; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Send a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// let poll_info = client + /// .poll_flight_info(request) + /// .await + /// .expect("error handshaking"); + /// + /// // retrieve the first endpoint from the returned poll info + /// let ticket = poll_info + /// .info + /// .expect("expected flight info") + /// .endpoint[0] + /// // Extract the ticket + /// .ticket + /// .clone() + /// .expect("expected ticket"); + /// + /// // Retrieve the corresponding RecordBatch stream with do_get + /// let data = client + /// .do_get(ticket) + /// .await + /// .expect("error fetching data"); + /// # } + /// ``` + pub async fn poll_flight_info(&mut self, descriptor: FlightDescriptor) -> Result { + let request = self.make_request(descriptor); + + let response = self.inner.poll_flight_info(request).await?.into_inner(); + Ok(response) + } + + /// Make a `DoPut` call to the server with the provided + /// [`Stream`] of [`FlightData`] and returning a + /// stream of [`PutResult`]. + /// + /// # Note + /// + /// The input stream is [`Result`] so that this can be connected + /// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder), + /// without having to buffer. If the input stream returns an error + /// that error will not be sent to the server, instead it will be + /// placed into the result stream and the server connection + /// terminated. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::{TryStreamExt, StreamExt}; + /// # use std::sync::Arc; + /// # use arrow_array::UInt64Array; + /// # use arrow_array::RecordBatch; + /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult}; + /// # use arrow_flight::encode::FlightDataEncoderBuilder; + /// # let batch = RecordBatch::try_from_iter(vec![ + /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _) + /// # ]).unwrap(); + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // encode the batch as a stream of `FlightData` + /// let flight_data_stream = FlightDataEncoderBuilder::new() + /// .build(futures::stream::iter(vec![Ok(batch)])); + /// + /// // send the stream and get the results as `PutResult` + /// let response: Vec= client + /// .do_put(flight_data_stream) + /// .await + /// .unwrap() + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error calling do_put"); + /// # } + /// ``` + pub async fn do_put> + Send + 'static>( + &mut self, + request: S, + ) -> Result>> { + let (sender, receiver) = futures::channel::oneshot::channel(); + + // Intercepts client errors and sends them to the oneshot channel above + let request = Box::pin(request); // Pin to heap + let request_stream = FallibleRequestStream::new(sender, request); + + let request = self.make_request(request_stream); + let response_stream = self.inner.do_put(request).await?.into_inner(); + + // Forwards errors from the error oneshot with priority over responses from server + let response_stream = Box::pin(response_stream); + let error_stream = FallibleTonicResponseStream::new(receiver, response_stream); + + // combine the response from the server and any error from the client + Ok(error_stream.boxed()) + } + + /// Make a `DoExchange` call to the server with the provided + /// [`Stream`] of [`FlightData`] and returning a + /// 
stream of [`FlightData`]. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::{TryStreamExt, StreamExt}; + /// # use std::sync::Arc; + /// # use arrow_array::UInt64Array; + /// # use arrow_array::RecordBatch; + /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult}; + /// # use arrow_flight::encode::FlightDataEncoderBuilder; + /// # let batch = RecordBatch::try_from_iter(vec![ + /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _) + /// # ]).unwrap(); + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // encode the batch as a stream of `FlightData` + /// let flight_data_stream = FlightDataEncoderBuilder::new() + /// .build(futures::stream::iter(vec![Ok(batch)])); + /// + /// // send the stream and get the results as `RecordBatches` + /// let response: Vec = client + /// .do_exchange(flight_data_stream) + /// .await + /// .unwrap() + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error calling do_exchange"); + /// # } + /// ``` + pub async fn do_exchange> + Send + 'static>( + &mut self, + request: S, + ) -> Result { + let (sender, receiver) = futures::channel::oneshot::channel(); + + let request = Box::pin(request); + // Intercepts client errors and sends them to the oneshot channel above + let request_stream = FallibleRequestStream::new(sender, request); + + let request = self.make_request(request_stream); + let response_stream = self.inner.do_exchange(request).await?.into_inner(); + + let response_stream = Box::pin(response_stream); + let error_stream = FallibleTonicResponseStream::new(receiver, response_stream); + + // combine the response from the server and any error from the client + Ok(FlightRecordBatchStream::new_from_flight_data(error_stream)) + } + + /// Make a `ListFlights` call to the server with the provided + /// criteria and returning a [`Stream`] of [`FlightInfo`]. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::TryStreamExt; + /// # use bytes::Bytes; + /// # use arrow_flight::{FlightInfo, FlightClient}; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Send 'Name=Foo' bytes as the "expression" to the server + /// // and gather the returned FlightInfo + /// let responses: Vec = client + /// .list_flights(Bytes::from("Name=Foo")) + /// .await + /// .expect("error listing flights") + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error gathering flights"); + /// # } + /// ``` + pub async fn list_flights( + &mut self, + expression: impl Into, + ) -> Result>> { + let request = Criteria { + expression: expression.into(), + }; + + let request = self.make_request(request); + + let response = self + .inner + .list_flights(request) + .await? + .into_inner() + .map_err(FlightError::Tonic); + + Ok(response.boxed()) + } + + /// Make a `GetSchema` call to the server with the provided + /// [`FlightDescriptor`] and returning the associated [`Schema`]. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use bytes::Bytes; + /// # use arrow_flight::{FlightDescriptor, FlightClient}; + /// # use arrow_schema::Schema; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Request the schema result of a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// + /// let schema: Schema = client + /// .get_schema(request) + /// .await + /// .expect("error making request"); + /// # } + /// ``` + pub async fn get_schema(&mut self, flight_descriptor: FlightDescriptor) -> Result { + let request = self.make_request(flight_descriptor); + + let schema_result = self.inner.get_schema(request).await?.into_inner(); + + // attempt decode from IPC + let schema: Schema = schema_result.try_into()?; + + Ok(schema) + } + + /// Make a `ListActions` call to the server and returning a + /// [`Stream`] of [`ActionType`]. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::TryStreamExt; + /// # use arrow_flight::{ActionType, FlightClient}; + /// # use arrow_schema::Schema; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // List available actions on the server: + /// let actions: Vec = client + /// .list_actions() + /// .await + /// .expect("error listing actions") + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error gathering actions"); + /// # } + /// ``` + pub async fn list_actions(&mut self) -> Result>> { + let request = self.make_request(Empty {}); + + let action_stream = self + .inner + .list_actions(request) + .await? + .into_inner() + .map_err(FlightError::Tonic); + + Ok(action_stream.boxed()) + } + + /// Make a `DoAction` call to the server and returning a + /// [`Stream`] of opaque [`Bytes`]. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use bytes::Bytes; + /// # use futures::TryStreamExt; + /// # use arrow_flight::{Action, FlightClient}; + /// # use arrow_schema::Schema; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// let request = Action::new("my_action", "the body"); + /// + /// // Make a request to run the action on the server + /// let results: Vec = client + /// .do_action(request) + /// .await + /// .expect("error executing acton") + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error gathering action results"); + /// # } + /// ``` + pub async fn do_action(&mut self, action: Action) -> Result>> { + let request = self.make_request(action); + + let result_stream = self + .inner + .do_action(request) + .await? + .into_inner() + .map_err(FlightError::Tonic) + .map(|r| { + r.map(|r| { + // unwrap inner bytes + let crate::Result { body } = r; + body + }) + }); + + Ok(result_stream.boxed()) + } + + /// Make a `CancelFlightInfo` call to the server and return + /// a [`CancelFlightInfoResult`]. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use arrow_flight::{CancelFlightInfoRequest, FlightClient, FlightDescriptor}; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Send a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// let flight_info = client + /// .get_flight_info(request) + /// .await + /// .expect("error handshaking"); + /// + /// // Cancel the query + /// let request = CancelFlightInfoRequest::new(flight_info); + /// let result = client + /// .cancel_flight_info(request) + /// .await + /// .expect("error cancelling"); + /// # } + /// ``` + pub async fn cancel_flight_info( + &mut self, + request: CancelFlightInfoRequest, + ) -> Result { + let action = Action::new("CancelFlightInfo", request.encode_to_vec()); + let response = self.do_action(action).await?.try_next().await?; + let response = response.ok_or(FlightError::protocol( + "Received no response for cancel_flight_info call", + ))?; + CancelFlightInfoResult::decode(response) + .map_err(|e| FlightError::DecodeError(e.to_string())) + } + + /// Make a `RenewFlightEndpoint` call to the server and return + /// the renewed [`FlightEndpoint`]. 
+ /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use arrow_flight::{FlightClient, FlightDescriptor, RenewFlightEndpointRequest}; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Send a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// let flight_endpoint = client + /// .get_flight_info(request) + /// .await + /// .expect("error handshaking") + /// .endpoint[0]; + /// + /// // Renew the endpoint + /// let request = RenewFlightEndpointRequest::new(flight_endpoint); + /// let flight_endpoint = client + /// .renew_flight_endpoint(request) + /// .await + /// .expect("error renewing"); + /// # } + /// ``` + pub async fn renew_flight_endpoint( + &mut self, + request: RenewFlightEndpointRequest, + ) -> Result { + let action = Action::new("RenewFlightEndpoint", request.encode_to_vec()); + let response = self.do_action(action).await?.try_next().await?; + let response = response.ok_or(FlightError::protocol( + "Received no response for renew_flight_endpoint call", + ))?; + FlightEndpoint::decode(response).map_err(|e| FlightError::DecodeError(e.to_string())) + } + + /// return a Request, adding any configured metadata + fn make_request(&self, t: T) -> tonic::Request { + // Pass along metadata + let mut request = tonic::Request::new(t); + *request.metadata_mut() = self.metadata.clone(); + request + } +} diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs new file mode 100644 index 000000000000..5561f256ce01 --- /dev/null +++ b/arrow-flight/src/decode.rs @@ -0,0 +1,434 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_buffer::Buffer; +use arrow_schema::{Schema, SchemaRef}; +use bytes::Bytes; +use futures::{ready, stream::BoxStream, Stream, StreamExt}; +use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; +use tonic::metadata::MetadataMap; + +use crate::error::{FlightError, Result}; + +/// Decodes a [Stream] of [`FlightData`] back into +/// [`RecordBatch`]es. This can be used to decode the response from an +/// Arrow Flight server +/// +/// # Note +/// To access the lower level Flight messages (e.g. to access +/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`] +/// and use the [`FlightDataDecoder`] directly. +/// +/// # Example: +/// ```no_run +/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{ +/// # use bytes::Bytes; +/// // make a do_get request +/// use arrow_flight::{ +/// error::Result, +/// decode::FlightRecordBatchStream, +/// Ticket, +/// flight_service_client::FlightServiceClient +/// }; +/// use tonic::transport::Channel; +/// use futures::stream::{StreamExt, TryStreamExt}; +/// +/// let client: FlightServiceClient = // make client.. 
+/// # unimplemented!(); +/// +/// let request = tonic::Request::new( +/// Ticket { ticket: Bytes::new() } +/// ); +/// +/// // Get a stream of FlightData; +/// let flight_data_stream = client +/// .do_get(request) +/// .await? +/// .into_inner(); +/// +/// // Decode stream of FlightData to RecordBatches +/// let record_batch_stream = FlightRecordBatchStream::new_from_flight_data( +/// // convert tonic::Status to FlightError +/// flight_data_stream.map_err(|e| e.into()) +/// ); +/// +/// // Read back RecordBatches +/// while let Some(batch) = record_batch_stream.next().await { +/// match batch { +/// Ok(batch) => { /* process batch */ }, +/// Err(e) => { /* handle error */ }, +/// }; +/// } +/// +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct FlightRecordBatchStream { + /// Optional grpc header metadata. + headers: MetadataMap, + + /// Optional grpc trailer metadata. + trailers: Option, + + inner: FlightDataDecoder, +} + +impl FlightRecordBatchStream { + /// Create a new [`FlightRecordBatchStream`] from a decoded stream + pub fn new(inner: FlightDataDecoder) -> Self { + Self { + inner, + headers: MetadataMap::default(), + trailers: None, + } + } + + /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`] + pub fn new_from_flight_data(inner: S) -> Self + where + S: Stream> + Send + 'static, + { + Self { + inner: FlightDataDecoder::new(inner), + headers: MetadataMap::default(), + trailers: None, + } + } + + /// Record response headers. + pub fn with_headers(self, headers: MetadataMap) -> Self { + Self { headers, ..self } + } + + /// Record response trailers. + pub fn with_trailers(self, trailers: LazyTrailers) -> Self { + Self { + trailers: Some(trailers), + ..self + } + } + + /// Headers attached to this stream. + pub fn headers(&self) -> &MetadataMap { + &self.headers + } + + /// Trailers attached to this stream. + /// + /// Note that this will return `None` until the entire stream is consumed. 
+ /// Only after calling `next()` returns `None`, might any available trailers be returned. + pub fn trailers(&self) -> Option { + self.trailers.as_ref().and_then(|trailers| trailers.get()) + } + + /// Has a message defining the schema been received yet? + #[deprecated = "use schema().is_some() instead"] + pub fn got_schema(&self) -> bool { + self.schema().is_some() + } + + /// Return schema for the stream, if it has been received + pub fn schema(&self) -> Option<&SchemaRef> { + self.inner.schema() + } + + /// Consume self and return the wrapped [`FlightDataDecoder`] + pub fn into_inner(self) -> FlightDataDecoder { + self.inner + } +} + +impl futures::Stream for FlightRecordBatchStream { + type Item = Result; + + /// Returns the next [`RecordBatch`] available in this stream, or `None` if + /// there are no further results available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + let had_schema = self.schema().is_some(); + let res = ready!(self.inner.poll_next_unpin(cx)); + match res { + // Inner exhausted + None => { + return Poll::Ready(None); + } + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + // translate data + Some(Ok(data)) => match data.payload { + DecodedPayload::Schema(_) if had_schema => { + return Poll::Ready(Some(Err(FlightError::protocol( + "Unexpectedly saw multiple Schema messages in FlightData stream", + )))); + } + DecodedPayload::Schema(_) => { + // Need next message, poll inner again + } + DecodedPayload::RecordBatch(batch) => { + return Poll::Ready(Some(Ok(batch))); + } + DecodedPayload::None => { + // Need next message + } + }, + } + } + } +} + +/// Wrapper around a stream of [`FlightData`] that handles the details +/// of decoding low level Flight messages into [`Schema`] and +/// [`RecordBatch`]es, including details such as dictionaries. 
+/// +/// # Protocol Details +/// +/// The client handles flight messages as followes: +/// +/// - **None:** This message has no effect. This is useful to +/// transmit metadata without any actual payload. +/// +/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and +/// the decoded schema is returned. +/// +/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing +/// dictionary for the same column will be overwritten. This +/// message is NOT visible. +/// +/// - **Record Batch:** Record batch is created based on the current +/// schema and dictionaries. This fails if no schema was transmitted +/// yet. +/// +/// All other message types (at the time of writing: e.g. tensor and +/// sparse tensor) lead to an error. +/// +/// Example usecases +/// +/// 1. Using this low level stream it is possible to receive a steam +/// of RecordBatches in FlightData that have different schemas by +/// handling multiple schema messages separately. +pub struct FlightDataDecoder { + /// Underlying data stream + response: BoxStream<'static, Result>, + /// Decoding state + state: Option, + /// Seen the end of the inner stream? + done: bool, +} + +impl Debug for FlightDataDecoder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FlightDataDecoder") + .field("response", &"") + .field("state", &self.state) + .field("done", &self.done) + .finish() + } +} + +impl FlightDataDecoder { + /// Create a new wrapper around the stream of [`FlightData`] + pub fn new(response: S) -> Self + where + S: Stream> + Send + 'static, + { + Self { + state: None, + response: response.boxed(), + done: false, + } + } + + /// Returns the current schema for this stream + pub fn schema(&self) -> Option<&SchemaRef> { + self.state.as_ref().map(|state| &state.schema) + } + + /// Extracts flight data from the next message, updating decoding + /// state as necessary. 
+ fn extract_message(&mut self, data: FlightData) -> Result> { + use arrow_ipc::MessageHeader; + let message = arrow_ipc::root_as_message(&data.data_header[..]) + .map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?; + + match message.header_type() { + MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))), + MessageHeader::Schema => { + let schema = Schema::try_from(&data) + .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; + + let schema = Arc::new(schema); + let dictionaries_by_field = HashMap::new(); + + self.state = Some(FlightStreamState { + schema: Arc::clone(&schema), + dictionaries_by_field, + }); + Ok(Some(DecodedFlightData::new_schema(data, schema))) + } + MessageHeader::DictionaryBatch => { + let state = if let Some(state) = self.state.as_mut() { + state + } else { + return Err(FlightError::protocol( + "Received DictionaryBatch prior to Schema", + )); + }; + + let buffer = Buffer::from_bytes(data.data_body.into()); + let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| { + FlightError::protocol( + "Could not get dictionary batch from DictionaryBatch message", + ) + })?; + + arrow_ipc::reader::read_dictionary( + &buffer, + dictionary_batch, + &state.schema, + &mut state.dictionaries_by_field, + &message.version(), + ) + .map_err(|e| { + FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}")) + })?; + + // Updated internal state, but no decoded message + Ok(None) + } + MessageHeader::RecordBatch => { + let state = if let Some(state) = self.state.as_ref() { + state + } else { + return Err(FlightError::protocol( + "Received RecordBatch prior to Schema", + )); + }; + + let batch = flight_data_to_arrow_batch( + &data, + Arc::clone(&state.schema), + &state.dictionaries_by_field, + ) + .map_err(|e| { + FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}")) + })?; + + Ok(Some(DecodedFlightData::new_record_batch(data, batch))) + } + 
other => { + let name = other.variant_name().unwrap_or("UNKNOWN"); + Err(FlightError::protocol(format!("Unexpected message: {name}"))) + } + } + } +} + +impl futures::Stream for FlightDataDecoder { + type Item = Result; + /// Returns the result of decoding the next [`FlightData`] message + /// from the server, or `None` if there are no further results + /// available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if self.done { + return Poll::Ready(None); + } + loop { + let res = ready!(self.response.poll_next_unpin(cx)); + + return Poll::Ready(match res { + None => { + self.done = true; + None // inner is exhausted + } + Some(data) => Some(match data { + Err(e) => Err(e), + Ok(data) => match self.extract_message(data) { + Ok(Some(extracted)) => Ok(extracted), + Ok(None) => continue, // Need next input message + Err(e) => Err(e), + }, + }), + }); + } + } +} + +/// tracks the state needed to reconstruct [`RecordBatch`]es from a +/// streaming flight response. 
+#[derive(Debug)] +struct FlightStreamState { + schema: SchemaRef, + dictionaries_by_field: HashMap, +} + +/// FlightData and the decoded payload (Schema, RecordBatch), if any +#[derive(Debug)] +pub struct DecodedFlightData { + pub inner: FlightData, + pub payload: DecodedPayload, +} + +impl DecodedFlightData { + pub fn new_none(inner: FlightData) -> Self { + Self { + inner, + payload: DecodedPayload::None, + } + } + + pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self { + Self { + inner, + payload: DecodedPayload::Schema(schema), + } + } + + pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self { + Self { + inner, + payload: DecodedPayload::RecordBatch(batch), + } + } + + /// return the metadata field of the inner flight data + pub fn app_metadata(&self) -> Bytes { + self.inner.app_metadata.clone() + } +} + +/// The result of decoding [`FlightData`] +#[derive(Debug)] +pub enum DecodedPayload { + /// None (no data was sent in the corresponding FlightData) + None, + + /// A decoded Schema message + Schema(SchemaRef), + + /// A decoded Record batch. + RecordBatch(RecordBatch), +} diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs new file mode 100644 index 000000000000..59fa8afd58d5 --- /dev/null +++ b/arrow-flight/src/encode.rs @@ -0,0 +1,1626 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; + +use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc}; + +use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray}; +use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; + +use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode}; +use bytes::Bytes; +use futures::{ready, stream::BoxStream, Stream, StreamExt}; + +/// Creates a [`Stream`] of [`FlightData`]s from a +/// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>. +/// +/// This can be used to implement [`FlightService::do_get`] in an +/// Arrow Flight implementation; +/// +/// This structure encodes a stream of `Result`s rather than `RecordBatch`es to +/// propagate errors from streaming execution, where the generation of the +/// `RecordBatch`es is incremental, and an error may occur even after +/// several have already been successfully produced. +/// +/// # Caveats +/// 1. When [`DictionaryHandling`] is [`DictionaryHandling::Hydrate`], +/// [`DictionaryArray`]s are converted to their underlying types prior to +/// transport. +/// When [`DictionaryHandling`] is [`DictionaryHandling::Resend`], Dictionary [`FlightData`] is sent with every +/// [`RecordBatch`] that contains a [`DictionaryArray`](arrow_array::array::DictionaryArray). +/// See . 
+/// +/// [`DictionaryArray`]: arrow_array::array::DictionaryArray +/// +/// # Example +/// ```no_run +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get an input stream of Result +/// let input_stream = futures::stream::iter(vec![Ok(batch)]); +/// +/// // Build a stream of `Result` (e.g. to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// .build(input_stream); +/// +/// // Create a tonic `Response` that can be returned from a Flight server +/// let response = tonic::Response::new(flight_data_stream); +/// # } +/// ``` +/// +/// # Example: Sending `Vec` +/// +/// You can create a [`Stream`] to pass to [`Self::build`] from an existing +/// `Vec` of `RecordBatch`es like this: +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # fn make_batches() -> Vec { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// # vec![batch.clone(), batch.clone()] +/// # } +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get batches that you want to send via Flight +/// let batches: Vec = make_batches(); +/// +/// // Create an input stream of Result +/// let input_stream = futures::stream::iter( +/// batches.into_iter().map(Ok) +/// ); +/// +/// // Build a stream of `Result` (e.g. 
to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// .build(input_stream); +/// # } +/// ``` +/// +/// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get +/// [`FlightError`]: crate::error::FlightError +#[derive(Debug)] +pub struct FlightDataEncoderBuilder { + /// The maximum approximate target message size in bytes + /// (see details on [`Self::with_max_flight_data_size`]). + max_flight_data_size: usize, + /// Ipc writer options + options: IpcWriteOptions, + /// Metadata to add to the schema message + app_metadata: Bytes, + /// Optional schema, if known before data. + schema: Option, + /// Optional flight descriptor, if known before data. + descriptor: Option, + /// Deterimines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information. + dictionary_handling: DictionaryHandling, +} + +/// Default target size for encoded [`FlightData`]. +/// +/// Note this value would normally be 4MB, but the size calculation is +/// somewhat inexact, so we set it to 2MB. +pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152; + +impl Default for FlightDataEncoderBuilder { + fn default() -> Self { + Self { + max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES, + options: IpcWriteOptions::default(), + app_metadata: Bytes::new(), + schema: None, + descriptor: None, + dictionary_handling: DictionaryHandling::Hydrate, + } + } +} + +impl FlightDataEncoderBuilder { + pub fn new() -> Self { + Self::default() + } + + /// Set the (approximate) maximum size, in bytes, of the + /// [`FlightData`] produced by this encoder. Defaults to 2MB. + /// + /// Since there is often a maximum message size for gRPC messages + /// (typically around 4MB), this encoder splits up [`RecordBatch`]s + /// (preserving order) into multiple [`FlightData`] objects to + /// limit the size individual messages sent via gRPC. 
+ /// + /// The size is approximate because of the additional encoding + /// overhead on top of the underlying data buffers themselves. + pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self { + self.max_flight_data_size = max_flight_data_size; + self + } + + /// Set [`DictionaryHandling`] for encoder + pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self { + self.dictionary_handling = dictionary_handling; + self + } + + /// Specify application specific metadata included in the + /// [`FlightData::app_metadata`] field of the the first Schema + /// message + pub fn with_metadata(mut self, app_metadata: Bytes) -> Self { + self.app_metadata = app_metadata; + self + } + + /// Set the [`IpcWriteOptions`] used to encode the [`RecordBatch`]es for transport. + pub fn with_options(mut self, options: IpcWriteOptions) -> Self { + self.options = options; + self + } + + /// Specify a schema for the RecordBatches being sent. If a schema + /// is not specified, an encoded Schema message will be sent when + /// the first [`RecordBatch`], if any, is encoded. Some clients + /// expect a Schema message even if there is no data sent. + pub fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + /// Specify a flight descriptor in the first FlightData message. + pub fn with_flight_descriptor(mut self, descriptor: Option) -> Self { + self.descriptor = descriptor; + self + } + + /// Takes a [`Stream`] of [`Result`] and returns a [`Stream`] + /// of [`FlightData`], consuming self. 
+ /// + /// See example on [`Self`] and [`FlightDataEncoder`] for more details + pub fn build(self, input: S) -> FlightDataEncoder + where + S: Stream> + Send + 'static, + { + let Self { + max_flight_data_size, + options, + app_metadata, + schema, + descriptor, + dictionary_handling, + } = self; + + FlightDataEncoder::new( + input.boxed(), + schema, + max_flight_data_size, + options, + app_metadata, + descriptor, + dictionary_handling, + ) + } +} + +/// Stream that encodes a stream of record batches to flight data. +/// +/// See [`FlightDataEncoderBuilder`] for details and example. +pub struct FlightDataEncoder { + /// Input stream + inner: BoxStream<'static, Result>, + /// schema, set after the first batch + schema: Option, + /// Target maximum size of flight data + /// (see details on [`FlightDataEncoderBuilder::with_max_flight_data_size`]). + max_flight_data_size: usize, + /// do the encoding / tracking of dictionaries + encoder: FlightIpcEncoder, + /// optional metadata to add to schema FlightData + app_metadata: Option, + /// data queued up to send but not yet sent + queue: VecDeque, + /// Is this stream done (inner is empty or errored) + done: bool, + /// cleared after the first FlightData message is sent + descriptor: Option, + /// Deterimines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information. 
+ dictionary_handling: DictionaryHandling, +} + +impl FlightDataEncoder { + fn new( + inner: BoxStream<'static, Result>, + schema: Option, + max_flight_data_size: usize, + options: IpcWriteOptions, + app_metadata: Bytes, + descriptor: Option, + dictionary_handling: DictionaryHandling, + ) -> Self { + let mut encoder = Self { + inner, + schema: None, + max_flight_data_size, + encoder: FlightIpcEncoder::new( + options, + dictionary_handling != DictionaryHandling::Resend, + ), + app_metadata: Some(app_metadata), + queue: VecDeque::new(), + done: false, + descriptor, + dictionary_handling, + }; + + // If schema is known up front, enqueue it immediately + if let Some(schema) = schema { + encoder.encode_schema(&schema); + } + + encoder + } + + /// Place the `FlightData` in the queue to send + fn queue_message(&mut self, mut data: FlightData) { + if let Some(descriptor) = self.descriptor.take() { + data.flight_descriptor = Some(descriptor); + } + self.queue.push_back(data); + } + + /// Place the `FlightData` in the queue to send + fn queue_messages(&mut self, datas: impl IntoIterator) { + for data in datas { + self.queue_message(data) + } + } + + /// Encodes schema as a [`FlightData`] in self.queue. 
+ /// Updates `self.schema` and returns the new schema + fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef { + // The first message is the schema message, and all + // batches have the same schema + let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; + let schema = Arc::new(prepare_schema_for_flight( + schema, + &mut self.encoder.dictionary_tracker, + send_dictionaries, + )); + let mut schema_flight_data = self.encoder.encode_schema(&schema); + + // attach any metadata requested + if let Some(app_metadata) = self.app_metadata.take() { + schema_flight_data.app_metadata = app_metadata; + } + self.queue_message(schema_flight_data); + // remember schema + self.schema = Some(schema.clone()); + schema + } + + /// Encodes batch into one or more `FlightData` messages in self.queue + fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> { + let schema = match &self.schema { + Some(schema) => schema.clone(), + // encode the schema if this is the first time we have seen it + None => self.encode_schema(batch.schema_ref()), + }; + + let batch = match self.dictionary_handling { + DictionaryHandling::Resend => batch, + DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?, + }; + + for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { + let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?; + + self.queue_messages(flight_dictionaries); + self.queue_message(flight_batch); + } + + Ok(()) + } +} + +impl Stream for FlightDataEncoder { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + loop { + if self.done && self.queue.is_empty() { + return Poll::Ready(None); + } + + // Any messages queued to send? 
+ if let Some(data) = self.queue.pop_front() { + return Poll::Ready(Some(Ok(data))); + } + + // Get next batch + let batch = ready!(self.inner.poll_next_unpin(cx)); + + match batch { + None => { + // inner is done + self.done = true; + // queue must also be empty so we are done + assert!(self.queue.is_empty()); + return Poll::Ready(None); + } + Some(Err(e)) => { + // error from inner + self.done = true; + self.queue.clear(); + return Poll::Ready(Some(Err(e))); + } + Some(Ok(batch)) => { + // had data, encode into the queue + if let Err(e) = self.encode_batch(batch) { + self.done = true; + self.queue.clear(); + return Poll::Ready(Some(Err(e))); + } + } + } + } + } +} + +/// Defines how a [`FlightDataEncoder`] encodes [`DictionaryArray`]s +/// +/// [`DictionaryArray`]: arrow_array::DictionaryArray +/// +/// In the arrow flight protocol dictionary values and keys are sent as two separate messages. +/// When a sender is encoding a [`RecordBatch`] containing ['DictionaryArray'] columns, it will +/// first send a dictionary batch (a batch with header `MessageHeader::DictionaryBatch`) containing +/// the dictionary values. The receiver is responsible for reading this batch and maintaining state that associates +/// those dictionary values with the corresponding array using the `dict_id` as a key. +/// +/// After sending the dictionary batch the sender will send the array data in a batch with header `MessageHeader::RecordBatch`. +/// For any dictionary array batches in this message, the encoded flight message will only contain the dictionary keys. The receiver +/// is then responsible for rebuilding the `DictionaryArray` on the client side using the dictionary values from the DictionaryBatch message +/// and the keys from the RecordBatch message. 
+/// +/// For example, if we have a batch with a `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` (a dictionary array where they keys are `u32` and the +/// values are `String`), then the DictionaryBatch will contain a `StringArray` and the RecordBatch will contain a `UInt32Array`. +/// +/// Note that since `dict_id` defined in the `Schema` is used as a key to associate dictionary values to their arrays it is required that each +/// `DictionaryArray` in a `RecordBatch` have a unique `dict_id`. +/// +/// The current implementation does not support "delta" dictionaries so a new dictionary batch will be sent each time the encoder sees a +/// dictionary which is not pointer-equal to the previously observed dictionary for a given `dict_id`. +/// +/// For clients which may not support `DictionaryEncoding`, the `DictionaryHandling::Hydrate` method will bypass the process defined above +/// and "hydrate" any `DictionaryArray` in the batch to their underlying value type (e.g. `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` will +/// be sent as a `StringArray`). With this method all data will be sent in ``MessageHeader::RecordBatch` messages and the batch schema +/// will be adjusted so that all dictionary encoded fields are changed to fields of the dictionary value type. +#[derive(Debug, PartialEq)] +pub enum DictionaryHandling { + /// Expands to the underlying type (default). This likely sends more data + /// over the network but requires less memory (dictionaries are not tracked) + /// and is more compatible with other arrow flight client implementations + /// that may not support `DictionaryEncoding` + /// + /// See also: + /// * + Hydrate, + /// Send dictionary FlightData with every RecordBatch that contains a + /// [`DictionaryArray`]. See [`Self::Hydrate`] for more tradeoffs. No + /// attempt is made to skip sending the same (logical) dictionary values + /// twice. 
+ /// + /// [`DictionaryArray`]: arrow_array::DictionaryArray + /// + /// This requires identifying the different dictionaries in use and assigning + // them unique IDs + Resend, +} + +fn prepare_field_for_flight( + field: &FieldRef, + dictionary_tracker: &mut DictionaryTracker, + send_dictionaries: bool, +) -> Field { + match field.data_type() { + DataType::List(inner) => Field::new_list( + field.name(), + prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + DataType::LargeList(inner) => Field::new_list( + field.name(), + prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + DataType::Struct(fields) => { + let new_fields: Vec = fields + .iter() + .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries)) + .collect(); + Field::new_struct(field.name(), new_fields, field.is_nullable()) + .with_metadata(field.metadata().clone()) + } + DataType::Union(fields, mode) => { + let (type_ids, new_fields): (Vec, Vec) = fields + .iter() + .map(|(type_id, f)| { + ( + type_id, + prepare_field_for_flight(f, dictionary_tracker, send_dictionaries), + ) + }) + .unzip(); + + Field::new_union(field.name(), type_ids, new_fields, *mode) + } + DataType::Dictionary(_, value_type) => { + if !send_dictionaries { + Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()) + } else { + let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); + + Field::new_dict( + field.name(), + field.data_type().clone(), + field.is_nullable(), + dict_id, + field.dict_is_ordered().unwrap_or_default(), + ) + .with_metadata(field.metadata().clone()) + } + } + _ => field.as_ref().clone(), + } +} + +/// Prepare an arrow Schema for transport over the Arrow Flight protocol +/// +/// Convert dictionary types to underlying types +/// +/// See 
hydrate_dictionary for more information +fn prepare_schema_for_flight( + schema: &Schema, + dictionary_tracker: &mut DictionaryTracker, + send_dictionaries: bool, +) -> Schema { + let fields: Fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => { + if !send_dictionaries { + Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()) + } else { + let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); + Field::new_dict( + field.name(), + field.data_type().clone(), + field.is_nullable(), + dict_id, + field.dict_is_ordered().unwrap_or_default(), + ) + .with_metadata(field.metadata().clone()) + } + } + tpe if tpe.is_nested() => { + prepare_field_for_flight(field, dictionary_tracker, send_dictionaries) + } + _ => field.as_ref().clone(), + }) + .collect(); + + Schema::new(fields).with_metadata(schema.metadata().clone()) +} + +/// Split [`RecordBatch`] so it hopefully fits into a gRPC response. +/// +/// Data is zero-copy sliced into batches. +/// +/// Note: this method does not take into account already sliced +/// arrays: +fn split_batch_for_grpc_response( + batch: RecordBatch, + max_flight_data_size: usize, +) -> Vec { + let size = batch + .columns() + .iter() + .map(|col| col.get_buffer_memory_size()) + .sum::(); + + let n_batches = + (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1); + let rows_per_batch = (batch.num_rows() / n_batches).max(1); + let mut out = Vec::with_capacity(n_batches + 1); + + let mut offset = 0; + while offset < batch.num_rows() { + let length = (rows_per_batch).min(batch.num_rows() - offset); + out.push(batch.slice(offset, length)); + + offset += length; + } + + out +} + +/// The data needed to encode a stream of flight data, holding on to +/// shared Dictionaries. 
+/// +/// TODO: at allow dictionaries to be flushed / avoid building them +/// +/// TODO limit on the number of dictionaries??? +struct FlightIpcEncoder { + options: IpcWriteOptions, + data_gen: IpcDataGenerator, + dictionary_tracker: DictionaryTracker, +} + +impl FlightIpcEncoder { + fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self { + let preserve_dict_id = options.preserve_dict_id(); + Self { + options, + data_gen: IpcDataGenerator::default(), + dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( + error_on_replacement, + preserve_dict_id, + ), + } + } + + /// Encode a schema as a FlightData + fn encode_schema(&self, schema: &Schema) -> FlightData { + SchemaAsIpc::new(schema, &self.options).into() + } + + /// Convert a `RecordBatch` to a Vec of `FlightData` representing + /// dictionaries and a `FlightData` representing the batch + fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec, FlightData)> { + let (encoded_dictionaries, encoded_batch) = + self.data_gen + .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?; + + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + Ok((flight_dictionaries, flight_batch)) + } +} + +/// Hydrates any dictionaries arrays in `batch` to its underlying type. See +/// hydrate_dictionary for more information. +fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result { + let columns = schema + .fields() + .iter() + .zip(batch.columns()) + .map(|(field, c)| hydrate_dictionary(c, field.data_type())) + .collect::>>()?; + + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + + Ok(RecordBatch::try_new_with_options( + schema, columns, &options, + )?) +} + +/// Hydrates a dictionary to its underlying type. 
+fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result { + let arr = match (array.data_type(), data_type) { + (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => { + let union_arr = array.as_any().downcast_ref::().unwrap(); + + Arc::new(UnionArray::try_new( + fields.clone(), + union_arr.type_ids().clone(), + None, + fields + .iter() + .map(|(type_id, field)| { + Ok(arrow_cast::cast( + union_arr.child(type_id), + field.data_type(), + )?) + }) + .collect::>>()?, + )?) + } + (_, data_type) => arrow_cast::cast(array, data_type)?, + }; + Ok(arr) +} + +#[cfg(test)] +mod tests { + use crate::decode::{DecodedPayload, FlightDataDecoder}; + use arrow_array::builder::{ + GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder, + }; + use arrow_array::*; + use arrow_array::{cast::downcast_array, types::*}; + use arrow_buffer::ScalarBuffer; + use arrow_cast::pretty::pretty_format_batches; + use arrow_ipc::MetadataVersion; + use arrow_schema::{UnionFields, UnionMode}; + use std::collections::HashMap; + + use super::*; + + #[test] + /// ensure only the batch's used data (not the allocated data) is sent + /// + fn test_encode_flight_data() { + // use 8-byte alignment - default alignment is 64 which produces bigger ipc data + let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); + let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); + + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)]) + .expect("cannot create record batch"); + let schema = batch.schema_ref(); + + let (_, baseline_flight_batch) = make_flight_data(&batch, &options); + + let big_batch = batch.slice(0, batch.num_rows() - 1); + let optimized_big_batch = + hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize"); + let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options); + + assert_eq!( + baseline_flight_batch.data_body.len(), + 
optimized_big_flight_batch.data_body.len() + ); + + let small_batch = batch.slice(0, 1); + let optimized_small_batch = + hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize"); + let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options); + + assert!( + baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len() + ); + } + + #[tokio::test] + async fn test_dictionary_hydration() { + let arr1: DictionaryArray = vec!["a", "a", "b"].into_iter().collect(); + let arr2: DictionaryArray = vec!["c", "c", "d"].into_iter().collect(); + + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); + + let encoder = FlightDataEncoderBuilder::default().build(stream); + let mut decoder = FlightDataDecoder::new(encoder); + let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]); + let expected_schema = Arc::new(expected_schema); + let mut expected_arrays = vec![ + StringArray::from(vec!["a", "a", "b"]), + StringArray::from(vec!["c", "c", "d"]), + ] + .into_iter(); + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), expected_schema); + let expected_array = expected_arrays.next().unwrap(); + let actual_array = b.column_by_name("dict").unwrap(); + let actual_array = downcast_array::(actual_array); + + assert_eq!(actual_array, expected_array); + } + } + } + } + + #[tokio::test] + async fn test_dictionary_resend() { + let arr1: DictionaryArray = vec!["a", "a", 
"b"].into_iter().collect(); + let arr2: DictionaryArray = vec!["c", "c", "d"].into_iter().collect(); + + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + + #[tokio::test] + async fn test_multiple_dictionaries_resend() { + // Create a schema with two dictionary fields that have the same dict ID + let schema = Arc::new(Schema::new(vec![ + Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false), + Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false), + ])); + + let arr_one_1: Arc> = + Arc::new(vec!["a", "a", "b"].into_iter().collect()); + let arr_one_2: Arc> = + Arc::new(vec!["c", "c", "d"].into_iter().collect()); + let arr_two_1: Arc> = + Arc::new(vec!["b", "a", "c"].into_iter().collect()); + let arr_two_2: Arc> = + Arc::new(vec!["k", "d", "e"].into_iter().collect()); + let batch1 = + RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()]) + .unwrap(); + let batch2 = + RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()]) + .unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + + #[tokio::test] + async fn test_dictionary_list_hydration() { + let mut builder = ListBuilder::new(StringDictionaryBuilder::::new()); + + builder.append_value(vec![Some("a"), None, Some("b")]); + + let arr1 = builder.finish(); + + builder.append_value(vec![Some("c"), None, Some("d")]); + + let arr2 = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = 
RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); + + let encoder = FlightDataEncoderBuilder::default().build(stream); + + let mut decoder = FlightDataDecoder::new(encoder); + let expected_schema = Schema::new(vec![Field::new_list( + "dict_list", + Field::new("item", DataType::Utf8, true), + true, + )]); + + let expected_schema = Arc::new(expected_schema); + + let mut expected_arrays = vec![ + StringArray::from_iter(vec![Some("a"), None, Some("b")]), + StringArray::from_iter(vec![Some("c"), None, Some("d")]), + ] + .into_iter(); + + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), expected_schema); + let expected_array = expected_arrays.next().unwrap(); + let list_array = + downcast_array::(b.column_by_name("dict_list").unwrap()); + let elem_array = downcast_array::(list_array.value(0).as_ref()); + + assert_eq!(elem_array, expected_array); + } + } + } + } + + #[tokio::test] + async fn test_dictionary_list_resend() { + let mut builder = ListBuilder::new(StringDictionaryBuilder::::new()); + + builder.append_value(vec![Some("a"), None, Some("b")]); + + let arr1 = builder.finish(); + + builder.append_value(vec![Some("c"), None, Some("d")]); + + let arr2 = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + + #[tokio::test] + async fn test_dictionary_struct_hydration() { + let struct_fields = vec![Field::new_list( 
+ "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let mut struct_builder = StructBuilder::new( + struct_fields.clone(), + vec![Box::new(builder::ListBuilder::new( + StringDictionaryBuilder::::new(), + ))], + ); + + struct_builder + .field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("a"), None, Some("b")]); + + struct_builder.append(true); + + let arr1 = struct_builder.finish(); + + struct_builder + .field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("c"), None, Some("d")]); + struct_builder.append(true); + + let arr2 = struct_builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new_struct( + "struct", + struct_fields, + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); + + let encoder = FlightDataEncoderBuilder::default().build(stream); + + let mut decoder = FlightDataDecoder::new(encoder); + let expected_schema = Schema::new(vec![Field::new_struct( + "struct", + vec![Field::new_list( + "dict_list", + Field::new("item", DataType::Utf8, true), + true, + )], + true, + )]); + + let expected_schema = Arc::new(expected_schema); + + let mut expected_arrays = vec![ + StringArray::from_iter(vec![Some("a"), None, Some("b")]), + StringArray::from_iter(vec![Some("c"), None, Some("d")]), + ] + .into_iter(); + + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), expected_schema); + let expected_array = expected_arrays.next().unwrap(); + let struct_array = + downcast_array::(b.column_by_name("struct").unwrap()); + let list_array = downcast_array::(struct_array.column(0)); + + let 
elem_array = downcast_array::(list_array.value(0).as_ref()); + + assert_eq!(elem_array, expected_array); + } + } + } + } + + #[tokio::test] + async fn test_dictionary_struct_resend() { + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let mut struct_builder = StructBuilder::new( + struct_fields.clone(), + vec![Box::new(builder::ListBuilder::new( + StringDictionaryBuilder::::new(), + ))], + ); + + struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("a"), None, Some("b")]); + struct_builder.append(true); + + let arr1 = struct_builder.finish(); + + struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("c"), None, Some("d")]); + struct_builder.append(true); + + let arr2 = struct_builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new_struct( + "struct", + struct_fields, + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + + #[tokio::test] + async fn test_dictionary_union_hydration() { + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let union_fields = [ + ( + 0, + Arc::new(Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )), + ), + ( + 1, + Arc::new(Field::new_struct("struct", struct_fields.clone(), true)), + ), + (2, Arc::new(Field::new("string", DataType::Utf8, true))), + ] + .into_iter() + .collect::(); + + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::::new()); + + builder.append_value(vec![Some("a"), 
None, Some("b")]); + + let arr1 = builder.finish(); + + let type_id_buffer = [0].into_iter().collect::>(); + let arr1 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + Arc::new(arr1) as Arc, + new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + ], + ) + .unwrap(); + + builder.append_value(vec![Some("c"), None, Some("d")]); + + let arr2 = Arc::new(builder.finish()); + let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None); + + let type_id_buffer = [1].into_iter().collect::>(); + let arr2 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + Arc::new(arr2), + new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + ], + ) + .unwrap(); + + let type_id_buffer = [2].into_iter().collect::>(); + let arr3 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + Arc::new(StringArray::from(vec!["e"])), + ], + ) + .unwrap(); + + let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields + .iter() + .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone())) + .unzip(); + let schema = Arc::new(Schema::new(vec![Field::new_union( + "union", + type_ids.clone(), + union_fields.clone(), + UnionMode::Sparse, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]); + + let encoder = FlightDataEncoderBuilder::default().build(stream); + + let mut decoder = 
FlightDataDecoder::new(encoder); + + let hydrated_struct_fields = vec![Field::new_list( + "dict_list", + Field::new("item", DataType::Utf8, true), + true, + )]; + + let hydrated_union_fields = vec![ + Field::new_list("dict_list", Field::new("item", DataType::Utf8, true), true), + Field::new_struct("struct", hydrated_struct_fields.clone(), true), + Field::new("string", DataType::Utf8, true), + ]; + + let expected_schema = Schema::new(vec![Field::new_union( + "union", + type_ids.clone(), + hydrated_union_fields, + UnionMode::Sparse, + )]); + + let expected_schema = Arc::new(expected_schema); + + let mut expected_arrays = vec![ + StringArray::from_iter(vec![Some("a"), None, Some("b")]), + StringArray::from_iter(vec![Some("c"), None, Some("d")]), + StringArray::from(vec!["e"]), + ] + .into_iter(); + + let mut batch = 0; + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), expected_schema); + let expected_array = expected_arrays.next().unwrap(); + let union_arr = + downcast_array::(b.column_by_name("union").unwrap()); + + let elem_array = match batch { + 0 => { + let list_array = downcast_array::(union_arr.child(0)); + downcast_array::(list_array.value(0).as_ref()) + } + 1 => { + let struct_array = downcast_array::(union_arr.child(1)); + let list_array = downcast_array::(struct_array.column(0)); + + downcast_array::(list_array.value(0).as_ref()) + } + _ => downcast_array::(union_arr.child(2)), + }; + + batch += 1; + + assert_eq!(elem_array, expected_array); + } + } + } + } + + #[tokio::test] + async fn test_dictionary_union_resend() { + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let union_fields = [ + ( + 0, + Arc::new(Field::new_list( + "dict_list", + 
Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )), + ), + ( + 1, + Arc::new(Field::new_struct("struct", struct_fields.clone(), true)), + ), + (2, Arc::new(Field::new("string", DataType::Utf8, true))), + ] + .into_iter() + .collect::(); + + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::::new()); + + builder.append_value(vec![Some("a"), None, Some("b")]); + + let arr1 = builder.finish(); + + let type_id_buffer = [0].into_iter().collect::>(); + let arr1 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + Arc::new(arr1) as Arc, + new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + ], + ) + .unwrap(); + + builder.append_value(vec![Some("c"), None, Some("d")]); + + let arr2 = Arc::new(builder.finish()); + let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None); + + let type_id_buffer = [1].into_iter().collect::>(); + let arr2 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + Arc::new(arr2), + new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + ], + ) + .unwrap(); + + let type_id_buffer = [2].into_iter().collect::>(); + let arr3 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + Arc::new(StringArray::from(vec!["e"])), + ], + ) + .unwrap(); + + let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields + .iter() + .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone())) + .unzip(); + let schema = 
Arc::new(Schema::new(vec![Field::new_union( + "union", + type_ids.clone(), + union_fields.clone(), + UnionMode::Sparse, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2, batch3]).await; + } + + async fn verify_flight_round_trip(mut batches: Vec) { + let expected_schema = batches.first().unwrap().schema(); + + let encoder = FlightDataEncoderBuilder::default() + .with_options(IpcWriteOptions::default().with_preserve_dict_id(false)) + .with_dictionary_handling(DictionaryHandling::Resend) + .build(futures::stream::iter(batches.clone().into_iter().map(Ok))); + + let mut expected_batches = batches.drain(..); + + let mut decoder = FlightDataDecoder::new(encoder); + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + let expected_batch = expected_batches.next().unwrap(); + assert_eq!(b, expected_batch); + } + } + } + } + + #[test] + fn test_schema_metadata_encoded() { + let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata( + HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), + ); + + let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); + + let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false); + assert!(got.metadata().contains_key("some_key")); + } + + #[test] + fn test_encode_no_column_batch() { + let batch = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(10)), + ) + .expect("cannot create record batch"); + + hydrate_dictionaries(&batch, 
batch.schema()).expect("failed to optimize"); + } + + pub fn make_flight_data( + batch: &RecordBatch, + options: &IpcWriteOptions, + ) -> (Vec, FlightData) { + let data_gen = IpcDataGenerator::default(); + let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + (flight_dictionaries, flight_batch) + } + + #[test] + fn test_split_batch_for_grpc_response() { + let max_flight_data_size = 1024; + + // no split + let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) + .expect("cannot create record batch"); + let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); + assert_eq!(split.len(), 1); + assert_eq!(batch, split[0]); + + // split once + let n_rows = max_flight_data_size + 1; + assert!(n_rows % 2 == 1, "should be an odd number"); + let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::>()); + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) + .expect("cannot create record batch"); + let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); + assert_eq!(split.len(), 3); + assert_eq!( + split.iter().map(|batch| batch.num_rows()).sum::(), + n_rows + ); + let a = pretty_format_batches(&split).unwrap().to_string(); + let b = pretty_format_batches(&[batch]).unwrap().to_string(); + assert_eq!(a, b); + } + + #[test] + fn test_split_batch_for_grpc_response_sizes() { + // 2000 8 byte entries into 2k pieces: 8 chunks of 250 rows + verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]); + + // 2000 8 byte entries into 4k pieces: 4 chunks of 
500 rows + verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]); + + // 2023 8 byte entries into 3k pieces does not divide evenly + verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]); + + // 10 8 byte entries into 1 byte pieces means each row gets its own + verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]); + + // 10 8 byte entries into 1k byte pieces means one piece + verify_split(10, 1024, vec![10]); + } + + /// Creates a UInt64Array of 8 byte integers with `num_input_rows` rows, splits it into + /// `max_flight_data_size_bytes` pieces and verifies the row counts in + /// those pieces + fn verify_split( + num_input_rows: u64, + max_flight_data_size_bytes: usize, + expected_sizes: Vec, + ) { + let array: UInt64Array = (0..num_input_rows).collect(); + + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)]) + .expect("cannot create record batch"); + + let input_rows = batch.num_rows(); + + let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes); + let sizes: Vec<_> = split.iter().map(|batch| batch.num_rows()).collect(); + let output_rows: usize = sizes.iter().sum(); + + assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}"); + assert_eq!(input_rows, output_rows, "mismatch for {batch:?}"); + } + + // test sending record batches + // test sending record batches with multiple different dictionaries + + #[tokio::test] + async fn flight_data_size_even() { + let s1 = StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024)); + let i1 = Int16Array::from_iter_values(0..1024); + let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024)); + let i2 = Int64Array::from_iter_values(0..1024); + + let batch = RecordBatch::try_from_iter(vec![ + ("s1", Arc::new(s1) as _), + ("i1", Arc::new(i1) as _), + ("s2", Arc::new(s2) as _), + ("i2", Arc::new(i2) as _), + ]) + .unwrap(); + + verify_encoded_split(batch, 112).await; + } + + #[tokio::test] + async fn
flight_data_size_uneven_variable_lengths() { + // each row has a longer string than the last with increasing lengths 0 --> 1024 + let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i))); + let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 4304).await; + } + + #[tokio::test] + async fn flight_data_size_large_row() { + // batch with individual values that can each exceed the batch size + let array1 = StringArray::from_iter_values(vec![ + "*".repeat(500), + "*".repeat(500), + "*".repeat(500), + "*".repeat(500), + ]); + let array2 = StringArray::from_iter_values(vec![ + "*".to_string(), + "*".repeat(1000), + "*".repeat(2000), + "*".repeat(4000), + ]); + + let array3 = StringArray::from_iter_values(vec![ + "*".to_string(), + "*".to_string(), + "*".repeat(1000), + "*".repeat(2000), + ]); + + let batch = RecordBatch::try_from_iter(vec![ + ("a1", Arc::new(array1) as _), + ("a2", Arc::new(array2) as _), + ("a3", Arc::new(array3) as _), + ]) + .unwrap(); + + // 5k over limit (which is 2x larger than limit of 5k) + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 5800).await; + } + + #[tokio::test] + async fn flight_data_size_string_dictionary() { + // Small dictionary (only 2 distinct values ==> 2 entries in dictionary) + let array: DictionaryArray = (1..1024) + .map(|i| match i % 3 { + 0 => Some("value0"), + 1 => Some("value1"), + _ => None, + }) + .collect(); + + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + + verify_encoded_split(batch, 160).await; + } + + #[tokio::test] + async fn flight_data_size_large_dictionary() { + // large dictionary (all distinct values ==> 1024 entries in dictionary) + let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect(); + + let array: DictionaryArray =
values.iter().map(|s| Some(s.as_str())).collect(); + + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 3328).await; + } + + #[tokio::test] + async fn flight_data_size_large_dictionary_repeated_non_uniform() { + // large dictionary (1024 distinct values) that are used throughout the array + let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i))); + let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024)); + let array = DictionaryArray::new(keys, Arc::new(values)); + + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 5280).await; + } + + #[tokio::test] + async fn flight_data_size_multiple_dictionaries() { + // high cardinality + let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect(); + // highish cardinality + let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect(); + // medium cardinality + let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect(); + + let array1: DictionaryArray = values1.iter().map(|s| Some(s.as_str())).collect(); + let array2: DictionaryArray = values2.iter().map(|s| Some(s.as_str())).collect(); + let array3: DictionaryArray = values3.iter().map(|s| Some(s.as_str())).collect(); + + let batch = RecordBatch::try_from_iter(vec![ + ("a1", Arc::new(array1) as _), + ("a2", Arc::new(array2) as _), + ("a3", Arc::new(array3) as _), + ]) + .unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 4128).await; + } + + /// Return size, in memory of flight data + fn flight_data_size(d: &FlightData) -> usize { + let flight_descriptor_size = d + .flight_descriptor + .as_ref() + 
.map(|descriptor| { + let path_len: usize = descriptor.path.iter().map(|p| p.as_bytes().len()).sum(); + + std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len + }) + .unwrap_or(0); + + flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len() + } + + /// Coverage for + /// + /// Encodes the specified batch using several values of + /// `max_flight_data_size` between 1K and 5K and ensures that the + /// resulting size of the flight data stays within the limit + /// + `allowed_overage` + /// + /// `allowed_overage` is how far off the actual data encoding is + /// from the target limit that was set. It is an improvement when + /// the allowed_overage decreases. + /// + /// Note this overhead will likely always be greater than zero to + /// account for encoding overhead such as IPC headers and padding. + /// + /// + async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) { + let num_rows = batch.num_rows(); + + // Track the overall required maximum overage + let mut max_overage_seen = 0; + + for max_flight_data_size in [1024, 2021, 5000] { + println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}"); + + let mut stream = FlightDataEncoderBuilder::new() + .with_max_flight_data_size(max_flight_data_size) + // use 8-byte alignment - default alignment is 64 which produces bigger ipc data + .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()) + .build(futures::stream::iter([Ok(batch.clone())])); + + let mut i = 0; + while let Some(data) = stream.next().await.transpose().unwrap() { + let actual_data_size = flight_data_size(&data); + + let actual_overage = if actual_data_size > max_flight_data_size { + actual_data_size - max_flight_data_size + } else { + 0 + }; + + assert!( + actual_overage <= allowed_overage, + "encoded data[{i}]: actual size {actual_data_size}, \ + actual_overage: {actual_overage} \ + allowed_overage: {allowed_overage}" + ); + + i += 1; + +
max_overage_seen = max_overage_seen.max(actual_overage) + } + } + + // ensure that the specified overage is exactly the maximum so + // that when the splitting logic improves, the tests must be + // updated to reflect the better logic + assert_eq!( + allowed_overage, max_overage_seen, + "Specified overage was too high" + ); + } +} diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs new file mode 100644 index 000000000000..ba979ca9f7a6 --- /dev/null +++ b/arrow-flight/src/error.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::error::Error; + +use arrow_schema::ArrowError; + +/// Errors for the Apache Arrow Flight crate +#[derive(Debug)] +pub enum FlightError { + /// Underlying arrow error + Arrow(ArrowError), + /// Returned when functionality is not yet available. + NotYetImplemented(String), + /// Error from the underlying tonic library + Tonic(tonic::Status), + /// Some unexpected message was received + ProtocolError(String), + /// An error occurred during decoding + DecodeError(String), + /// External error that can provide source of error by calling `Error::source`.
+ ExternalError(Box), +} + +impl FlightError { + pub fn protocol(message: impl Into) -> Self { + Self::ProtocolError(message.into()) + } + + /// Wraps an external error in a `FlightError`. + pub fn from_external_error(error: Box) -> Self { + Self::ExternalError(error) + } +} + +impl std::fmt::Display for FlightError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FlightError::Arrow(source) => write!(f, "Arrow error: {}", source), + FlightError::NotYetImplemented(desc) => write!(f, "Not yet implemented: {}", desc), + FlightError::Tonic(source) => write!(f, "Tonic error: {}", source), + FlightError::ProtocolError(desc) => write!(f, "Protocol error: {}", desc), + FlightError::DecodeError(desc) => write!(f, "Decode error: {}", desc), + FlightError::ExternalError(source) => write!(f, "External error: {}", source), + } + } +} + +impl Error for FlightError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + FlightError::Arrow(source) => Some(source), + FlightError::Tonic(source) => Some(source), + FlightError::ExternalError(source) => Some(source.as_ref()), + _ => None, + } + } +} + +impl From for FlightError { + fn from(status: tonic::Status) -> Self { + Self::Tonic(status) + } +} + +impl From for FlightError { + fn from(value: ArrowError) -> Self { + Self::Arrow(value) + } +} + +// default conversion from FlightError to tonic treats everything +// other than `Status` as an internal error +impl From for tonic::Status { + fn from(value: FlightError) -> Self { + match value { + FlightError::Arrow(e) => tonic::Status::internal(e.to_string()), + FlightError::NotYetImplemented(e) => tonic::Status::internal(e), + FlightError::Tonic(status) => status, + FlightError::ProtocolError(e) => tonic::Status::internal(e), + FlightError::DecodeError(e) => tonic::Status::internal(e), + FlightError::ExternalError(e) => tonic::Status::internal(e.to_string()), + } + } +} + +pub type Result = std::result::Result; + +#[cfg(test)]
+mod test { + use super::*; + + #[test] + fn error_source() { + let e1 = FlightError::DecodeError("foo".into()); + assert!(e1.source().is_none()); + + // one level of wrapping + let e2 = FlightError::ExternalError(Box::new(e1)); + let source = e2.source().unwrap().downcast_ref::().unwrap(); + assert!(matches!(source, FlightError::DecodeError(_))); + + let e3 = FlightError::ExternalError(Box::new(e2)); + let source = e3 + .source() + .unwrap() + .downcast_ref::() + .unwrap() + .source() + .unwrap() + .downcast_ref::() + .unwrap(); + + assert!(matches!(source, FlightError::DecodeError(_))); + } + + #[test] + fn error_through_arrow() { + // flight error that wraps an arrow error that wraps a flight error + let e1 = FlightError::DecodeError("foo".into()); + let e2 = ArrowError::ExternalError(Box::new(e1)); + let e3 = FlightError::ExternalError(Box::new(e2)); + + // ensure we can find the lowest level error by following source() + let mut root_error: &dyn Error = &e3; + while let Some(source) = root_error.source() { + // walk the next level + root_error = source; + } + + let source = root_error.downcast_ref::().unwrap(); + assert!(matches!(source, FlightError::DecodeError(_))); + } +} diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 3f4f09855353..64e3ba01c5bd 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -15,43 +15,87 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::Schema; -use arrow::error::{ArrowError, Result as ArrowResult}; -use arrow::ipc::{ - convert, size_prefixed_root_as_message, writer, writer::EncodedData, - writer::IpcWriteOptions, -}; - -use std::{ - convert::{TryFrom, TryInto}, - fmt, - ops::Deref, -}; - -#[allow(clippy::derive_partial_eq_without_eq)] +//! A native Rust implementation of [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) +//! for exchanging [Arrow](https://arrow.apache.org) data between processes. +//! +//! 
Please see the [arrow-flight crates.io](https://crates.io/crates/arrow-flight) +//! page for feature flags and more information. +//! +//! # Overview +//! +//! This crate contains: +//! +//! 1. Low level [prost] generated structs +//! for Flight gRPC protobuf messages, such as [`FlightData`], [`FlightInfo`], +//! [`Location`] and [`Ticket`]. +//! +//! 2. Low level [tonic] generated [`flight_service_client`] and +//! [`flight_service_server`]. +//! +//! 3. Experimental support for [Flight SQL] in [`sql`]. Requires the +//! `flight-sql-experimental` feature of this crate to be activated. +//! +//! [Flight SQL]: https://arrow.apache.org/docs/format/FlightSql.html +#![allow(rustdoc::invalid_html_tags)] + +use arrow_ipc::{convert, writer, writer::EncodedData, writer::IpcWriteOptions}; +use arrow_schema::{ArrowError, Schema}; + +use arrow_ipc::convert::try_schema_from_ipc_buffer; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::Bytes; +use prost_types::Timestamp; +use std::{fmt, ops::Deref}; + +type ArrowResult = std::result::Result; + +#[allow(clippy::all)] mod gen { include!("arrow.flight.protocol.rs"); } +/// Defines a `Flight` for generation or retrieval. pub mod flight_descriptor { use super::gen; pub use gen::flight_descriptor::DescriptorType; } +/// Low Level [tonic] [`FlightServiceClient`](gen::flight_service_client::FlightServiceClient). pub mod flight_service_client { use super::gen; pub use gen::flight_service_client::FlightServiceClient; } +/// Low Level [tonic] [`FlightServiceServer`](gen::flight_service_server::FlightServiceServer) +/// and [`FlightService`](gen::flight_service_server::FlightService). 
pub mod flight_service_server { use super::gen; pub use gen::flight_service_server::FlightService; pub use gen::flight_service_server::FlightServiceServer; } +/// Mid Level [`FlightClient`] +pub mod client; +pub use client::FlightClient; + +/// Decoder to create [`RecordBatch`](arrow_array::RecordBatch) streams from [`FlightData`] streams. +/// See [`FlightRecordBatchStream`](decode::FlightRecordBatchStream). +pub mod decode; + +/// Encoder to create [`FlightData`] streams from [`RecordBatch`](arrow_array::RecordBatch) streams. +/// See [`FlightDataEncoderBuilder`](encode::FlightDataEncoderBuilder). +pub mod encode; + +/// Common error types +pub mod error; + pub use gen::Action; pub use gen::ActionType; pub use gen::BasicAuth; +pub use gen::CancelFlightInfoRequest; +pub use gen::CancelFlightInfoResult; +pub use gen::CancelStatus; pub use gen::Criteria; pub use gen::Empty; pub use gen::FlightData; @@ -61,15 +105,21 @@ pub use gen::FlightInfo; pub use gen::HandshakeRequest; pub use gen::HandshakeResponse; pub use gen::Location; +pub use gen::PollInfo; pub use gen::PutResult; +pub use gen::RenewFlightEndpointRequest; pub use gen::Result; pub use gen::SchemaResult; pub use gen::Ticket; +/// Helper to extract HTTP/gRPC trailers from a tonic stream. 
+mod trailers; + pub mod utils; #[cfg(feature = "flight-sql-experimental")] pub mod sql; +mod streams; use flight_descriptor::DescriptorType; @@ -81,21 +131,20 @@ pub struct SchemaAsIpc<'a> { /// IpcMessage represents a `Schema` in the format expected in /// `FlightInfo.schema` #[derive(Debug)] -pub struct IpcMessage(pub Vec); +pub struct IpcMessage(pub Bytes); // Useful conversion functions -fn flight_schema_as_encoded_data( - arrow_schema: &Schema, - options: &IpcWriteOptions, -) -> EncodedData { +fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData { let data_gen = writer::IpcDataGenerator::default(); - data_gen.schema_to_bytes(arrow_schema, options) + let mut dict_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + data_gen.schema_to_bytes_with_dictionary_tracker(arrow_schema, &mut dict_tracker, options) } fn flight_schema_as_flatbuffer(schema: &Schema, options: &IpcWriteOptions) -> IpcMessage { let encoded_data = flight_schema_as_encoded_data(schema, options); - IpcMessage(encoded_data.ipc_message) + IpcMessage(encoded_data.ipc_message.into()) } // Implement a bunch of useful traits for various conversions, displays, @@ -104,7 +153,7 @@ fn flight_schema_as_flatbuffer(schema: &Schema, options: &IpcWriteOptions) -> Ip // Deref impl Deref for IpcMessage { - type Target = Vec; + type Target = [u8]; fn deref(&self) -> &Self::Target { &self.0 @@ -135,7 +184,7 @@ impl fmt::Display for FlightData { write!(f, "FlightData {{")?; write!(f, " descriptor: ")?; match &self.flight_descriptor { - Some(d) => write!(f, "{}", d)?, + Some(d) => write!(f, "{d}")?, None => write!(f, "None")?, }; write!(f, ", header: ")?; @@ -161,7 +210,7 @@ impl fmt::Display for FlightDescriptor { write!(f, "path: [")?; let mut sep = ""; for element in &self.path { - write!(f, "{}{}", sep, element)?; + write!(f, "{sep}{element}")?; sep = ", "; } write!(f, "]")?; @@ -179,16 +228,23 @@ impl fmt::Display 
for FlightEndpoint { write!(f, "FlightEndpoint {{")?; write!(f, " ticket: ")?; match &self.ticket { - Some(value) => write!(f, "{}", value), - None => write!(f, " none"), + Some(value) => write!(f, "{value}"), + None => write!(f, " None"), }?; write!(f, ", location: [")?; let mut sep = ""; for location in &self.location { - write!(f, "{}{}", sep, location)?; + write!(f, "{sep}{location}")?; sep = ", "; } write!(f, "]")?; + write!(f, ", expiration_time:")?; + match &self.expiration_time { + Some(value) => write!(f, " {value}"), + None => write!(f, " None"), + }?; + write!(f, ", app_metadata: ")?; + limited_fmt(f, &self.app_metadata, 8)?; write!(f, " }}") } } @@ -198,20 +254,82 @@ impl fmt::Display for FlightInfo { let ipc_message = IpcMessage(self.schema.clone()); let schema: Schema = ipc_message.try_into().map_err(|_err| fmt::Error)?; write!(f, "FlightInfo {{")?; - write!(f, " schema: {}", schema)?; + write!(f, " schema: {schema}")?; write!(f, ", descriptor:")?; match &self.flight_descriptor { - Some(d) => write!(f, " {}", d), + Some(d) => write!(f, " {d}"), None => write!(f, " None"), }?; write!(f, ", endpoint: [")?; let mut sep = ""; for endpoint in &self.endpoint { - write!(f, "{}{}", sep, endpoint)?; + write!(f, "{sep}{endpoint}")?; sep = ", "; } write!(f, "], total_records: {}", self.total_records)?; write!(f, ", total_bytes: {}", self.total_bytes)?; + write!(f, ", ordered: {}", self.ordered)?; + write!(f, ", app_metadata: ")?; + limited_fmt(f, &self.app_metadata, 8)?; + write!(f, " }}") + } +} + +impl fmt::Display for PollInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PollInfo {{")?; + write!(f, " info:")?; + match &self.info { + Some(value) => write!(f, " {value}"), + None => write!(f, " None"), + }?; + write!(f, ", descriptor:")?; + match &self.flight_descriptor { + Some(d) => write!(f, " {d}"), + None => write!(f, " None"), + }?; + write!(f, ", progress:")?; + match &self.progress { + Some(value) => write!(f, " {value}"), 
+ None => write!(f, " None"), + }?; + write!(f, ", expiration_time:")?; + match &self.expiration_time { + Some(value) => write!(f, " {value}"), + None => write!(f, " None"), + }?; + write!(f, " }}") + } +} + +impl fmt::Display for CancelFlightInfoRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "CancelFlightInfoRequest {{")?; + write!(f, " info: ")?; + match &self.info { + Some(value) => write!(f, "{value}")?, + None => write!(f, "None")?, + }; + write!(f, " }}") + } +} + +impl fmt::Display for CancelFlightInfoResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "CancelFlightInfoResult {{")?; + write!(f, " status: {}", self.status().as_str_name())?; + write!(f, " }}") + } +} + +impl fmt::Display for RenewFlightEndpointRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "RenewFlightEndpointRequest {{")?; + write!(f, " endpoint: ")?; + match &self.endpoint { + Some(value) => write!(f, "{value}")?, + None => write!(f, "None")?, + }; write!(f, " }}") } } @@ -228,7 +346,7 @@ impl fmt::Display for Ticket { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Ticket {{")?; write!(f, " ticket: ")?; - write!(f, "{}", base64::encode(&self.ticket)) + write!(f, "{}", BASE64_STANDARD.encode(&self.ticket)) } } @@ -237,8 +355,8 @@ impl fmt::Display for Ticket { impl From for FlightData { fn from(data: EncodedData) -> Self { FlightData { - data_header: data.ipc_message, - data_body: data.arrow_data, + data_header: data.ipc_message.into(), + data_body: data.arrow_data.into(), ..Default::default() } } @@ -254,20 +372,17 @@ impl From> for FlightData { } } -impl From> for SchemaResult { - fn from(schema_ipc: SchemaAsIpc) -> Self { - let IpcMessage(vals) = flight_schema_as_flatbuffer(schema_ipc.0, schema_ipc.1); - SchemaResult { schema: vals } - } -} - -// TryFrom... 
- -impl TryFrom for DescriptorType { +impl TryFrom> for SchemaResult { type Error = ArrowError; - fn try_from(value: i32) -> ArrowResult { - value.try_into() + fn try_from(schema_ipc: SchemaAsIpc) -> ArrowResult { + // According to the definition from `Flight.proto` + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + let IpcMessage(vals) = schema_to_ipc_format(schema_ipc)?; + Ok(SchemaResult { schema: vals }) } } @@ -275,22 +390,25 @@ impl TryFrom> for IpcMessage { type Error = ArrowError; fn try_from(schema_ipc: SchemaAsIpc) -> ArrowResult { - let pair = *schema_ipc; - let encoded_data = flight_schema_as_encoded_data(pair.0, pair.1); - - let mut schema = vec![]; - arrow::ipc::writer::write_message(&mut schema, encoded_data, pair.1)?; - Ok(IpcMessage(schema)) + schema_to_ipc_format(schema_ipc) } } +fn schema_to_ipc_format(schema_ipc: SchemaAsIpc) -> ArrowResult { + let pair = *schema_ipc; + let encoded_data = flight_schema_as_encoded_data(pair.0, pair.1); + + let mut schema = vec![]; + writer::write_message(&mut schema, encoded_data, pair.1)?; + Ok(IpcMessage(schema.into())) +} + impl TryFrom<&FlightData> for Schema { type Error = ArrowError; fn try_from(data: &FlightData) -> ArrowResult { - convert::schema_from_bytes(&data.data_header[..]).map_err(|err| { + convert::try_schema_from_flatbuffer_bytes(&data.data_header[..]).map_err(|err| { ArrowError::ParseError(format!( - "Unable to convert flight data to Arrow schema: {}", - err + "Unable to convert flight data to Arrow schema: {err}" )) }) } @@ -300,8 +418,7 @@ impl TryFrom for Schema { type Error = ArrowError; fn try_from(value: FlightInfo) -> ArrowResult { - let msg = IpcMessage(value.schema); - msg.try_into() + value.try_decode_schema() } } @@ -309,63 +426,97 @@ impl TryFrom for Schema { type Error = ArrowError; fn try_from(value: IpcMessage) -> ArrowResult { 
- // CONTINUATION TAKES 4 BYTES - // SIZE TAKES 4 BYTES (so read msg as size prefixed) - let msg = size_prefixed_root_as_message(&value.0[4..]).map_err(|err| { - ArrowError::ParseError(format!( - "Unable to convert flight info to a message: {}", - err - )) - })?; - let ipc_schema = msg.header_as_schema().ok_or_else(|| { - ArrowError::ParseError( - "Unable to convert flight info to a schema".to_string(), - ) - })?; - Ok(convert::fb_to_schema(ipc_schema)) + try_schema_from_ipc_buffer(&value) } } impl TryFrom<&SchemaResult> for Schema { type Error = ArrowError; fn try_from(data: &SchemaResult) -> ArrowResult { - convert::schema_from_bytes(&data.schema[..]).map_err(|err| { - ArrowError::ParseError(format!( - "Unable to convert schema result to Arrow schema: {}", - err - )) - }) + try_schema_from_ipc_buffer(&data.schema) + } +} + +impl TryFrom for Schema { + type Error = ArrowError; + fn try_from(data: SchemaResult) -> ArrowResult { + (&data).try_into() } } // FlightData, FlightDescriptor, etc.. impl FlightData { - pub fn new( - flight_descriptor: Option, - message: IpcMessage, - app_metadata: Vec, - data_body: Vec, - ) -> Self { - let IpcMessage(vals) = message; - FlightData { - flight_descriptor, - data_header: vals, - app_metadata, - data_body, - } + /// Create a new [`FlightData`]. 
+ /// + /// # See Also + /// + /// See [`FlightDataEncoderBuilder`] for a higher level API to + /// convert a stream of [`RecordBatch`]es to [`FlightData`]s + /// + /// # Example: + /// + /// ``` + /// # use bytes::Bytes; + /// # use arrow_flight::{FlightData, FlightDescriptor}; + /// # fn encode_data() -> Bytes { Bytes::new() } // dummy data + /// // Get encoded Arrow IPC data: + /// let data_body: Bytes = encode_data(); + /// // Create the FlightData message + /// let flight_data = FlightData::new() + /// .with_descriptor(FlightDescriptor::new_cmd("the command")) + /// .with_app_metadata("My apps metadata") + /// .with_data_body(data_body); + /// ``` + /// + /// [`FlightDataEncoderBuilder`]: crate::encode::FlightDataEncoderBuilder + /// [`RecordBatch`]: arrow_array::RecordBatch + pub fn new() -> Self { + Default::default() + } + + /// Add a [`FlightDescriptor`] describing the data + pub fn with_descriptor(mut self, flight_descriptor: FlightDescriptor) -> Self { + self.flight_descriptor = Some(flight_descriptor); + self + } + + /// Add a data header + pub fn with_data_header(mut self, data_header: impl Into) -> Self { + self.data_header = data_header.into(); + self + } + + /// Add a data body. See [`IpcDataGenerator`] to create this data. + /// + /// [`IpcDataGenerator`]: arrow_ipc::writer::IpcDataGenerator + pub fn with_data_body(mut self, data_body: impl Into) -> Self { + self.data_body = data_body.into(); + self + } + + /// Add optional application specific metadata to the message + pub fn with_app_metadata(mut self, app_metadata: impl Into) -> Self { + self.app_metadata = app_metadata.into(); + self } } impl FlightDescriptor { - pub fn new_cmd(cmd: Vec) -> Self { + /// Create a new opaque command [`CMD`] `FlightDescriptor` to generate a dataset. 
+ /// + /// [`CMD`]: https://github.com/apache/arrow/blob/6bd31f37ae66bd35594b077cb2f830be57e08acd/format/Flight.proto#L224-L227 + pub fn new_cmd(cmd: impl Into) -> Self { FlightDescriptor { r#type: DescriptorType::Cmd.into(), - cmd, + cmd: cmd.into(), ..Default::default() } } + /// Create a new named path [`PATH`] `FlightDescriptor` that identifies a dataset + /// + /// [`PATH`]: https://github.com/apache/arrow/blob/6bd31f37ae66bd35594b077cb2f830be57e08acd/format/Flight.proto#L217-L222 pub fn new_path(path: Vec) -> Self { FlightDescriptor { r#type: DescriptorType::Path.into(), @@ -376,21 +527,159 @@ impl FlightDescriptor { } impl FlightInfo { - pub fn new( - message: IpcMessage, - flight_descriptor: Option, - endpoint: Vec, - total_records: i64, - total_bytes: i64, - ) -> Self { - let IpcMessage(vals) = message; + /// Create a new, empty `FlightInfo`, describing where to fetch flight data + /// + /// + /// # Example: + /// ``` + /// # use arrow_flight::{FlightInfo, Ticket, FlightDescriptor, FlightEndpoint}; + /// # use arrow_schema::{Schema, Field, DataType}; + /// # fn get_schema() -> Schema { + /// # Schema::new(vec![ + /// # Field::new("a", DataType::Utf8, false), + /// # ]) + /// # } + /// # + /// // Create a new FlightInfo + /// let flight_info = FlightInfo::new() + /// // Encode the Arrow schema + /// .try_with_schema(&get_schema()) + /// .expect("encoding failed") + /// .with_endpoint( + /// FlightEndpoint::new() + /// .with_ticket(Ticket::new("ticket contents") + /// ) + /// ) + /// .with_descriptor(FlightDescriptor::new_cmd("RUN QUERY")); + /// ``` + pub fn new() -> FlightInfo { FlightInfo { - schema: vals, - flight_descriptor, - endpoint, - total_records, - total_bytes, + schema: Bytes::new(), + flight_descriptor: None, + endpoint: vec![], + ordered: false, + // Flight says "Set these to -1 if unknown." 
+ // + // https://github.com/apache/arrow-rs/blob/17ca4d51d0490f9c65f5adde144f677dbc8300e7/format/Flight.proto#L287-L289 + total_records: -1, + total_bytes: -1, + app_metadata: Bytes::new(), + } + } + + /// Try and convert the data in this `FlightInfo` into a [`Schema`] + pub fn try_decode_schema(self) -> ArrowResult { + let msg = IpcMessage(self.schema); + msg.try_into() + } + + /// Specify the schema for the response. + /// + /// Note this takes the arrow [`Schema`] (not the IPC schema) and + /// encodes it using the default IPC options. + /// + /// Returns an error if `schema` can not be encoded into IPC form. + pub fn try_with_schema(mut self, schema: &Schema) -> ArrowResult { + let options = IpcWriteOptions::default(); + let IpcMessage(schema) = SchemaAsIpc::new(schema, &options).try_into()?; + self.schema = schema; + Ok(self) + } + + /// Add specific a endpoint for fetching the data + pub fn with_endpoint(mut self, endpoint: FlightEndpoint) -> Self { + self.endpoint.push(endpoint); + self + } + + /// Add a [`FlightDescriptor`] describing what this data is + pub fn with_descriptor(mut self, flight_descriptor: FlightDescriptor) -> Self { + self.flight_descriptor = Some(flight_descriptor); + self + } + + /// Set the number of records in the result, if known + pub fn with_total_records(mut self, total_records: i64) -> Self { + self.total_records = total_records; + self + } + + /// Set the number of bytes in the result, if known + pub fn with_total_bytes(mut self, total_bytes: i64) -> Self { + self.total_bytes = total_bytes; + self + } + + /// Specify if the response is [ordered] across endpoints + /// + /// [ordered]: https://github.com/apache/arrow-rs/blob/17ca4d51d0490f9c65f5adde144f677dbc8300e7/format/Flight.proto#L269-L275 + pub fn with_ordered(mut self, ordered: bool) -> Self { + self.ordered = ordered; + self + } + + /// Add optional application specific metadata to the message + pub fn with_app_metadata(mut self, app_metadata: impl Into) -> Self { + 
self.app_metadata = app_metadata.into(); + self + } +} + +impl PollInfo { + /// Create a new, empty [`PollInfo`], providing information for a long-running query + /// + /// # Example: + /// ``` + /// # use arrow_flight::{FlightInfo, PollInfo, FlightDescriptor}; + /// # use prost_types::Timestamp; + /// // Create a new PollInfo + /// let poll_info = PollInfo::new() + /// .with_info(FlightInfo::new()) + /// .with_descriptor(FlightDescriptor::new_cmd("RUN QUERY")) + /// .try_with_progress(0.5) + /// .expect("progress should've been valid") + /// .with_expiration_time( + /// "1970-01-01".parse().expect("invalid timestamp") + /// ); + /// ``` + pub fn new() -> Self { + Self { + info: None, + flight_descriptor: None, + progress: None, + expiration_time: None, + } + } + + /// Add the current available results for the poll call as a [`FlightInfo`] + pub fn with_info(mut self, info: FlightInfo) -> Self { + self.info = Some(info); + self + } + + /// Add a [`FlightDescriptor`] that the client should use for the next poll call, + /// if the query is not yet complete + pub fn with_descriptor(mut self, flight_descriptor: FlightDescriptor) -> Self { + self.flight_descriptor = Some(flight_descriptor); + self + } + + /// Set the query progress if known. Must be in the range [0.0, 1.0] else this will + /// return an error + pub fn try_with_progress(mut self, progress: f64) -> ArrowResult { + if !(0.0..=1.0).contains(&progress) { + return Err(ArrowError::InvalidArgumentError(format!( + "PollInfo progress must be in the range [0.0, 1.0], got {progress}" + ))); } + self.progress = Some(progress); + Ok(self) + } + + /// Specify expiration time for this request + pub fn with_expiration_time(mut self, expiration_time: Timestamp) -> Self { + self.expiration_time = Some(expiration_time); + self } } @@ -402,9 +691,129 @@ impl<'a> SchemaAsIpc<'a> { } } +impl CancelFlightInfoRequest { + /// Create a new [`CancelFlightInfoRequest`], providing the [`FlightInfo`] + /// of the query to cancel. 
+ pub fn new(info: FlightInfo) -> Self { + Self { info: Some(info) } + } +} + +impl CancelFlightInfoResult { + /// Create a new [`CancelFlightInfoResult`] from the provided [`CancelStatus`]. + pub fn new(status: CancelStatus) -> Self { + Self { + status: status as i32, + } + } +} + +impl RenewFlightEndpointRequest { + /// Create a new [`RenewFlightEndpointRequest`], providing the [`FlightEndpoint`] + /// for which is being requested an extension of its expiration. + pub fn new(endpoint: FlightEndpoint) -> Self { + Self { + endpoint: Some(endpoint), + } + } +} + +impl Action { + /// Create a new Action with type and body + pub fn new(action_type: impl Into, body: impl Into) -> Self { + Self { + r#type: action_type.into(), + body: body.into(), + } + } +} + +impl Result { + /// Create a new Result with the specified body + pub fn new(body: impl Into) -> Self { + Self { body: body.into() } + } +} + +impl Ticket { + /// Create a new `Ticket` + /// + /// # Example + /// + /// ``` + /// # use arrow_flight::Ticket; + /// let ticket = Ticket::new("SELECT * from FOO"); + /// ``` + pub fn new(ticket: impl Into) -> Self { + Self { + ticket: ticket.into(), + } + } +} + +impl FlightEndpoint { + /// Create a new, empty `FlightEndpoint` that represents a location + /// to retrieve Flight results. 
+ /// + /// # Example + /// ``` + /// # use arrow_flight::{FlightEndpoint, Ticket}; + /// # + /// // Specify the client should fetch results from this server + /// let endpoint = FlightEndpoint::new() + /// .with_ticket(Ticket::new("the ticket")); + /// + /// // Specify the client should fetch results from either + /// // `http://example.com` or `https://example.com` + /// let endpoint = FlightEndpoint::new() + /// .with_ticket(Ticket::new("the ticket")) + /// .with_location("http://example.com") + /// .with_location("https://example.com"); + /// ``` + pub fn new() -> FlightEndpoint { + Default::default() + } + + /// Set the [`Ticket`] used to retrieve data from the endpoint + pub fn with_ticket(mut self, ticket: Ticket) -> Self { + self.ticket = Some(ticket); + self + } + + /// Add a location `uri` to this endpoint. Note each endpoint can + /// have multiple locations. + /// + /// If no `uri` is specified, the [Flight Spec] says: + /// + /// ```text + /// * If the list is empty, the expectation is that the ticket can only + /// * be redeemed on the current service where the ticket was + /// * generated. 
+ /// ``` + /// [Flight Spec]: https://github.com/apache/arrow-rs/blob/17ca4d51d0490f9c65f5adde144f677dbc8300e7/format/Flight.proto#L307C2-L312 + pub fn with_location(mut self, uri: impl Into) -> Self { + self.location.push(Location { uri: uri.into() }); + self + } + + /// Specify expiration time for this stream + pub fn with_expiration_time(mut self, expiration_time: Timestamp) -> Self { + self.expiration_time = Some(expiration_time); + self + } + + /// Add optional application specific metadata to the message + pub fn with_app_metadata(mut self, app_metadata: impl Into) -> Self { + self.app_metadata = app_metadata.into(); + self + } +} + #[cfg(test)] mod tests { use super::*; + use arrow_ipc::MetadataVersion; + use arrow_schema::{DataType, Field, TimeUnit}; struct TestVector(Vec, usize); @@ -426,7 +835,7 @@ mod tests { fn it_accepts_equal_output() { let input = TestVector(vec![91; 10], 10); - let actual = format!("{}", input); + let actual = format!("{input}"); let expected = format!("{:?}", vec![91; 10]); assert_eq!(actual, expected); } @@ -435,7 +844,7 @@ mod tests { fn it_accepts_short_output() { let input = TestVector(vec![91; 6], 10); - let actual = format!("{}", input); + let actual = format!("{input}"); let expected = format!("{:?}", vec![91; 6]); assert_eq!(actual, expected); } @@ -444,8 +853,35 @@ mod tests { fn it_accepts_long_output() { let input = TestVector(vec![91; 10], 9); - let actual = format!("{}", input); + let actual = format!("{input}"); let expected = format!("{:?}", vec![91; 9]); assert_eq!(actual, expected); } + + #[test] + fn ser_deser_schema_result() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("c6", DataType::Time32(TimeUnit::Second), false), + ]); + // V5 with write_legacy_ipc_format 
= false + // this will write the continuation marker + let option = IpcWriteOptions::default(); + let schema_ipc = SchemaAsIpc::new(&schema, &option); + let result: SchemaResult = schema_ipc.try_into().unwrap(); + let des_schema: Schema = (&result).try_into().unwrap(); + assert_eq!(schema, des_schema); + + // V4 with write_legacy_ipc_format = true + // this will not write the continuation marker + let option = IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap(); + let schema_ipc = SchemaAsIpc::new(&schema, &option); + let result: SchemaResult = schema_ipc.try_into().unwrap(); + let des_schema: Schema = (&result).try_into().unwrap(); + assert_eq!(schema, des_schema); + } } diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs index 77221dd1a489..7a37a0b28856 100644 --- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -1,13 +1,14 @@ // This file was automatically generated through the build.rs script, and should not be edited. +// This file is @generated by prost-build. /// -/// Represents a metadata request. Used in the command member of FlightDescriptor -/// for the following RPC calls: +/// Represents a metadata request. Used in the command member of FlightDescriptor +/// for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// info_name: uint32 not null, /// value: dense_union< /// string_value: utf8, @@ -16,185 +17,260 @@ /// int32_bitmask: int32, /// string_list: list /// int32_to_int32_list_map: map> -/// > -/// where there is one row per requested piece of metadata information. +/// > +/// where there is one row per requested piece of metadata information. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetSqlInfo { /// - /// Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide - /// Flight SQL clients with basic, SQL syntax and SQL functions related information. - /// More information types can be added in future releases. - /// E.g. more SQL syntax support types, scalar functions support, type conversion support etc. + /// Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide + /// Flight SQL clients with basic, SQL syntax and SQL functions related information. + /// More information types can be added in future releases. + /// E.g. more SQL syntax support types, scalar functions support, type conversion support etc. /// - /// Note that the set of metadata may expand. + /// Note that the set of metadata may expand. /// - /// Initially, Flight SQL will support the following information types: - /// - Server Information - Range [0-500) - /// - Syntax Information - Range [500-1000) - /// Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). - /// Custom options should start at 10,000. + /// Initially, Flight SQL will support the following information types: + /// - Server Information - Range [0-500) + /// - Syntax Information - Range [500-1000) + /// Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). + /// Custom options should start at 10,000. /// - /// If omitted, then all metadata will be retrieved. - /// Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must - /// at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. - /// If additional metadata is included, the metadata IDs should start from 10,000. - #[prost(uint32, repeated, tag="1")] + /// If omitted, then all metadata will be retrieved. 
+ /// Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must + /// at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. + /// If additional metadata is included, the metadata IDs should start from 10,000. + #[prost(uint32, repeated, tag = "1")] pub info: ::prost::alloc::vec::Vec, } /// -/// Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. -/// The definition of a catalog depends on vendor/implementation. It is usually the database itself -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve information about data type supported on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned schema will be: +/// < +/// type_name: utf8 not null (The name of the data type, for example: VARCHAR, INTEGER, etc), +/// data_type: int32 not null (The SQL data type), +/// column_size: int32 (The maximum size supported by that column. +/// In case of exact numeric types, this represents the maximum precision. +/// In case of string types, this represents the character length. +/// In case of datetime data types, this represents the length in characters of the string representation. 
+/// NULL is returned for data types where column size is not applicable.), +/// literal_prefix: utf8 (Character or characters used to prefix a literal, NULL is returned for +/// data types where a literal prefix is not applicable.), +/// literal_suffix: utf8 (Character or characters used to terminate a literal, +/// NULL is returned for data types where a literal suffix is not applicable.), +/// create_params: list +/// (A list of keywords corresponding to which parameters can be used when creating +/// a column for that specific type. +/// NULL is returned if there are no parameters for the data type definition.), +/// nullable: int32 not null (Shows if the data type accepts a NULL value. The possible values can be seen in the +/// Nullable enum.), +/// case_sensitive: bool not null (Shows if a character data type is case-sensitive in collations and comparisons), +/// searchable: int32 not null (Shows how the data type is used in a WHERE clause. The possible values can be seen in the +/// Searchable enum.), +/// unsigned_attribute: bool (Shows if the data type is unsigned. NULL is returned if the attribute is +/// not applicable to the data type or the data type is not numeric.), +/// fixed_prec_scale: bool not null (Shows if the data type has predefined fixed precision and scale.), +/// auto_increment: bool (Shows if the data type is auto incremental. NULL is returned if the attribute +/// is not applicable to the data type or the data type is not numeric.), +/// local_type_name: utf8 (Localized version of the data source-dependent name of the data type. NULL +/// is returned if a localized name is not supported by the data source), +/// minimum_scale: int32 (The minimum scale of the data type on the data source. +/// If a data type has a fixed scale, the MINIMUM_SCALE and MAXIMUM_SCALE +/// columns both contain this value. NULL is returned if scale is not applicable.), +/// maximum_scale: int32 (The maximum scale of the data type on the data source. 
+/// NULL is returned if scale is not applicable.), +/// sql_data_type: int32 not null (The value of the SQL DATA TYPE which has the same values +/// as data_type value. Except for interval and datetime, which +/// uses generic values. More info about those types can be +/// obtained through datetime_subcode. The possible values can be seen +/// in the XdbcDataType enum.), +/// datetime_subcode: int32 (Only used when the SQL DATA TYPE is interval or datetime. It contains +/// its sub types. For type different from interval and datetime, this value +/// is NULL. The possible values can be seen in the XdbcDatetimeSubcode enum.), +/// num_prec_radix: int32 (If the data type is an approximate numeric type, this column contains +/// the value 2 to indicate that COLUMN_SIZE specifies a number of bits. For +/// exact numeric types, this column contains the value 10 to indicate that +/// column size specifies a number of decimal digits. Otherwise, this column is NULL.), +/// interval_precision: int32 (If the data type is an interval data type, then this column contains the value +/// of the interval leading precision. Otherwise, this column is NULL. This fields +/// is only relevant to be used by ODBC). +/// > +/// The returned data should be ordered by data_type and then by type_name. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CommandGetXdbcTypeInfo { + /// + /// Specifies the data type to search for the info. + #[prost(int32, optional, tag = "1")] + pub data_type: ::core::option::Option, +} +/// +/// Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. +/// The definition of a catalog depends on vendor/implementation. It is usually the database itself +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. 
/// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8 not null -/// > -/// The returned data should be ordered by catalog_name. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CommandGetCatalogs { -} +/// > +/// The returned data should be ordered by catalog_name. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CommandGetCatalogs {} /// -/// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. -/// The definition of a database schema depends on vendor/implementation. It is usually a collection of tables. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. +/// The definition of a database schema depends on vendor/implementation. It is usually a collection of tables. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8, /// db_schema_name: utf8 not null -/// > -/// The returned data should be ordered by catalog_name, then db_schema_name. +/// > +/// The returned data should be ordered by catalog_name, then db_schema_name. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetDbSchemas { /// - /// Specifies the Catalog to search for the tables. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. 
+ #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for schemas to search for. - /// When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. - /// In the pattern string, two special characters can be used to denote matching rules: + /// Specifies a filter pattern for schemas to search for. + /// When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: /// - "%" means to match any substring with 0 or more characters. /// - "_" means to match any one character. - #[prost(string, optional, tag="2")] + #[prost(string, optional, tag = "2")] pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, } /// -/// Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8, /// db_schema_name: utf8, /// table_name: utf8 not null, /// table_type: utf8 not null, /// \[optional\] table_schema: bytes not null (schema of the table as described in Schema.fbs::Schema, /// it is serialized as an IPC message.) -/// > -/// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. 
+/// > +/// Fields on table_schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. +/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case-sensitive, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetTables { /// - /// Specifies the Catalog to search for the tables. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for schemas to search for. - /// When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. - /// In the pattern string, two special characters can be used to denote matching rules: + /// Specifies a filter pattern for schemas to search for. 
+ /// When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: /// - "%" means to match any substring with 0 or more characters. /// - "_" means to match any one character. - #[prost(string, optional, tag="2")] + #[prost(string, optional, tag = "2")] pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for tables to search for. - /// When no table_name_filter_pattern is provided, all tables matching other filters are searched. - /// In the pattern string, two special characters can be used to denote matching rules: + /// Specifies a filter pattern for tables to search for. + /// When no table_name_filter_pattern is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: /// - "%" means to match any substring with 0 or more characters. /// - "_" means to match any one character. - #[prost(string, optional, tag="3")] - pub table_name_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, - /// - /// Specifies a filter of table types which must match. - /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. - /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. - #[prost(string, repeated, tag="4")] + #[prost(string, optional, tag = "3")] + pub table_name_filter_pattern: ::core::option::Option< + ::prost::alloc::string::String, + >, + /// + /// Specifies a filter of table types which must match. + /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. 
+ #[prost(string, repeated, tag = "4")] pub table_types: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// Specifies if the Arrow schema should be returned for found tables. - #[prost(bool, tag="5")] + /// Specifies if the Arrow schema should be returned for found tables. + #[prost(bool, tag = "5")] pub include_schema: bool, } /// -/// Represents a request to retrieve the list of table types on a Flight SQL enabled backend. -/// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. -/// TABLE, VIEW, and SYSTEM TABLE are commonly supported. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the list of table types on a Flight SQL enabled backend. +/// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. +/// TABLE, VIEW, and SYSTEM TABLE are commonly supported. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// table_type: utf8 not null -/// > -/// The returned data should be ordered by table_type. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CommandGetTableTypes { -} +/// > +/// The returned data should be ordered by table_type. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CommandGetTableTypes {} /// -/// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. 
+/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8, /// db_schema_name: utf8, /// table_name: utf8 not null, /// column_name: utf8 not null, /// key_name: utf8, -/// key_sequence: int not null -/// > -/// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. +/// key_sequence: int32 not null +/// > +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetPrimaryKeys { /// - /// Specifies the catalog to search for the table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the catalog to search for the table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="2")] + /// Specifies the schema to search for the table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the table to get the primary keys for. - #[prost(string, tag="3")] + /// Specifies the table to get the primary keys for. 
+ #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve a description of the foreign key columns that reference the given table's -/// primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve a description of the foreign key columns that reference the given table's +/// primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// pk_catalog_name: utf8, /// pk_db_schema_name: utf8, /// pk_table_name: utf8 not null, @@ -203,40 +279,40 @@ pub struct CommandGetPrimaryKeys { /// fk_db_schema_name: utf8, /// fk_table_name: utf8 not null, /// fk_column_name: utf8 not null, -/// key_sequence: int not null, +/// key_sequence: int32 not null, /// fk_key_name: utf8, /// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. +/// update_rule: uint8 not null, +/// delete_rule: uint8 not null +/// > +/// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetExportedKeys { /// - /// Specifies the catalog to search for the foreign key table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the catalog to search for the foreign key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the foreign key table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="2")] + /// Specifies the schema to search for the foreign key table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the foreign key table to get the foreign keys for. - #[prost(string, tag="3")] + /// Specifies the foreign key table to get the foreign keys for. + #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. 
/// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// pk_catalog_name: utf8, /// pk_db_schema_name: utf8, /// pk_table_name: utf8 not null, @@ -245,14 +321,14 @@ pub struct CommandGetExportedKeys { /// fk_db_schema_name: utf8, /// fk_table_name: utf8 not null, /// fk_column_name: utf8 not null, -/// key_sequence: int not null, +/// key_sequence: int32 not null, /// fk_key_name: utf8, /// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// update_rule: uint8 not null, +/// delete_rule: uint8 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions: /// - 0 = CASCADE /// - 1 = RESTRICT /// - 2 = SET NULL @@ -261,31 +337,31 @@ pub struct CommandGetExportedKeys { #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetImportedKeys { /// - /// Specifies the catalog to search for the primary key table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the catalog to search for the primary key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the primary key table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. 
- #[prost(string, optional, tag="2")] + /// Specifies the schema to search for the primary key table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the primary key table to get the foreign keys for. - #[prost(string, tag="3")] + /// Specifies the primary key table to get the foreign keys for. + #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve a description of the foreign key columns in the given foreign key table that -/// reference the primary key or the columns representing a unique constraint of the parent table (could be the same -/// or a different table) on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve a description of the foreign key columns in the given foreign key table that +/// reference the primary key or the columns representing a unique constraint of the parent table (could be the same +/// or a different table) on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. 
/// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// pk_catalog_name: utf8, /// pk_db_schema_name: utf8, /// pk_table_name: utf8 not null, @@ -294,14 +370,14 @@ pub struct CommandGetImportedKeys { /// fk_db_schema_name: utf8, /// fk_table_name: utf8 not null, /// fk_column_name: utf8 not null, -/// key_sequence: int not null, +/// key_sequence: int32 not null, /// fk_key_name: utf8, /// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// update_rule: uint8 not null, +/// delete_rule: uint8 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions: /// - 0 = CASCADE /// - 1 = RESTRICT /// - 2 = SET NULL @@ -310,697 +386,1265 @@ pub struct CommandGetImportedKeys { #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetCrossReference { /// * - /// The catalog name where the parent table is. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// The catalog name where the parent table is. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub pk_catalog: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The Schema name where the parent table is. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. 
- #[prost(string, optional, tag="2")] + /// The Schema name where the parent table is. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub pk_db_schema: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The parent table name. It cannot be null. - #[prost(string, tag="3")] + /// The parent table name. It cannot be null. + #[prost(string, tag = "3")] pub pk_table: ::prost::alloc::string::String, /// * - /// The catalog name where the foreign table is. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="4")] + /// The catalog name where the foreign table is. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "4")] pub fk_catalog: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The schema name where the foreign table is. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="5")] + /// The schema name where the foreign table is. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "5")] pub fk_db_schema: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The foreign table name. It cannot be null. - #[prost(string, tag="6")] + /// The foreign table name. It cannot be null. + #[prost(string, tag = "6")] pub fk_table: ::prost::alloc::string::String, } -// SQL Execution Action Messages - /// -/// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. 
+/// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionCreatePreparedStatementRequest { - /// The valid SQL string to create a prepared statement for. - #[prost(string, tag="1")] + /// The valid SQL string to create a prepared statement for. + #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String, + /// Create/execute the prepared statement as part of this transaction (if + /// unset, executions of the prepared statement will be auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, } /// -/// Wrap the result of a "GetPreparedStatement" action. +/// An embedded message describing a Substrait plan to execute. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SubstraitPlan { + /// The serialized substrait.Plan to create a prepared statement for. + /// XXX(ARROW-16902): this is bytes instead of an embedded message + /// because Protobuf does not really support one DLL using Protobuf + /// definitions from another DLL. + #[prost(bytes = "bytes", tag = "1")] + pub plan: ::prost::bytes::Bytes, + /// The Substrait release, e.g. "0.12.0". This information is not + /// tracked in the plan itself, so this is the only way for consumers + /// to potentially know if they can handle the plan. + #[prost(string, tag = "2")] + pub version: ::prost::alloc::string::String, +} /// -/// The resultant PreparedStatement can be closed either: -/// - Manually, through the "ClosePreparedStatement" action; -/// - Automatically, by a server timeout. +/// Request message for the "CreatePreparedSubstraitPlan" action on a Flight SQL enabled backend. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionCreatePreparedSubstraitPlanRequest { + /// The serialized substrait.Plan to create a prepared statement for. 
+ #[prost(message, optional, tag = "1")] + pub plan: ::core::option::Option, + /// Create/execute the prepared statement as part of this transaction (if + /// unset, executions of the prepared statement will be auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, +} +/// +/// Wrap the result of a "CreatePreparedStatement" or "CreatePreparedSubstraitPlan" action. +/// +/// The resultant PreparedStatement can be closed either: +/// - Manually, through the "ClosePreparedStatement" action; +/// - Automatically, by a server timeout. +/// +/// The result should be wrapped in a google.protobuf.Any message. #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionCreatePreparedStatementResult { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, - /// If a result set generating query was provided, dataset_schema contains the - /// schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message. - #[prost(bytes="vec", tag="2")] - pub dataset_schema: ::prost::alloc::vec::Vec, - /// If the query provided contained parameters, parameter_schema contains the - /// schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message. - #[prost(bytes="vec", tag="3")] - pub parameter_schema: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, + /// If a result set generating query was provided, dataset_schema contains the + /// schema of the result set. It should be an IPC-encapsulated Schema, as described in Schema.fbs. + /// For some queries, the schema of the results may depend on the schema of the parameters. The server + /// should provide its best guess as to the schema at this point. 
Clients must not assume that this + /// schema, if provided, will be accurate. + #[prost(bytes = "bytes", tag = "2")] + pub dataset_schema: ::prost::bytes::Bytes, + /// If the query provided contained parameters, parameter_schema contains the + /// schema of the expected parameters. It should be an IPC-encapsulated Schema, as described in Schema.fbs. + #[prost(bytes = "bytes", tag = "3")] + pub parameter_schema: ::prost::bytes::Bytes, } /// -/// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. -/// Closes server resources associated with the prepared statement handle. +/// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. +/// Closes server resources associated with the prepared statement handle. #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionClosePreparedStatementRequest { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, } -// SQL Execution Messages. - /// -/// Represents a SQL query. Used in the command member of FlightDescriptor -/// for the following RPC calls: +/// Request message for the "BeginTransaction" action. +/// Begins a transaction. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ActionBeginTransactionRequest {} +/// +/// Request message for the "BeginSavepoint" action. +/// Creates a savepoint within a transaction. +/// +/// Only supported if FLIGHT_SQL_TRANSACTION is +/// FLIGHT_SQL_TRANSACTION_SUPPORT_SAVEPOINT. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionBeginSavepointRequest { + /// The transaction to which a savepoint belongs. + #[prost(bytes = "bytes", tag = "1")] + pub transaction_id: ::prost::bytes::Bytes, + /// Name for the savepoint. 
+ #[prost(string, tag = "2")] + pub name: ::prost::alloc::string::String, +} +/// +/// The result of a "BeginTransaction" action. +/// +/// The transaction can be manipulated with the "EndTransaction" action, or +/// automatically via server timeout. If the transaction times out, then it is +/// automatically rolled back. +/// +/// The result should be wrapped in a google.protobuf.Any message. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionBeginTransactionResult { + /// Opaque handle for the transaction on the server. + #[prost(bytes = "bytes", tag = "1")] + pub transaction_id: ::prost::bytes::Bytes, +} +/// +/// The result of a "BeginSavepoint" action. +/// +/// The transaction can be manipulated with the "EndSavepoint" action. +/// If the associated transaction is committed, rolled back, or times +/// out, then the savepoint is also invalidated. +/// +/// The result should be wrapped in a google.protobuf.Any message. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionBeginSavepointResult { + /// Opaque handle for the savepoint on the server. + #[prost(bytes = "bytes", tag = "1")] + pub savepoint_id: ::prost::bytes::Bytes, +} +/// +/// Request message for the "EndTransaction" action. +/// +/// Commit (COMMIT) or rollback (ROLLBACK) the transaction. +/// +/// If the action completes successfully, the transaction handle is +/// invalidated, as are all associated savepoints. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionEndTransactionRequest { + /// Opaque handle for the transaction on the server. + #[prost(bytes = "bytes", tag = "1")] + pub transaction_id: ::prost::bytes::Bytes, + /// Whether to commit/rollback the given transaction. + #[prost(enumeration = "action_end_transaction_request::EndTransaction", tag = "2")] + pub action: i32, +} +/// Nested message and enum types in `ActionEndTransactionRequest`. 
+pub mod action_end_transaction_request { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum EndTransaction { + Unspecified = 0, + /// Commit the transaction. + Commit = 1, + /// Roll back the transaction. + Rollback = 2, + } + impl EndTransaction { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "END_TRANSACTION_UNSPECIFIED", + Self::Commit => "END_TRANSACTION_COMMIT", + Self::Rollback => "END_TRANSACTION_ROLLBACK", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "END_TRANSACTION_UNSPECIFIED" => Some(Self::Unspecified), + "END_TRANSACTION_COMMIT" => Some(Self::Commit), + "END_TRANSACTION_ROLLBACK" => Some(Self::Rollback), + _ => None, + } + } + } +} +/// +/// Request message for the "EndSavepoint" action. +/// +/// Release (RELEASE) the savepoint or rollback (ROLLBACK) to the +/// savepoint. +/// +/// Releasing a savepoint invalidates that savepoint. Rolling back to +/// a savepoint does not invalidate the savepoint, but invalidates all +/// savepoints created after the current savepoint. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionEndSavepointRequest { + /// Opaque handle for the savepoint on the server. + #[prost(bytes = "bytes", tag = "1")] + pub savepoint_id: ::prost::bytes::Bytes, + /// Whether to rollback/release the given savepoint. + #[prost(enumeration = "action_end_savepoint_request::EndSavepoint", tag = "2")] + pub action: i32, +} +/// Nested message and enum types in `ActionEndSavepointRequest`. 
+pub mod action_end_savepoint_request { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum EndSavepoint { + Unspecified = 0, + /// Release the savepoint. + Release = 1, + /// Roll back to a savepoint. + Rollback = 2, + } + impl EndSavepoint { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "END_SAVEPOINT_UNSPECIFIED", + Self::Release => "END_SAVEPOINT_RELEASE", + Self::Rollback => "END_SAVEPOINT_ROLLBACK", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "END_SAVEPOINT_UNSPECIFIED" => Some(Self::Unspecified), + "END_SAVEPOINT_RELEASE" => Some(Self::Release), + "END_SAVEPOINT_ROLLBACK" => Some(Self::Rollback), + _ => None, + } + } + } +} +/// +/// Represents a SQL query. Used in the command member of FlightDescriptor +/// for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. +/// Fields on this schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. +/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case-sensitive, "0" otherwise. 
+/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. /// - GetFlightInfo: execute the query. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandStatementQuery { - /// The SQL syntax. - #[prost(string, tag="1")] + /// The SQL syntax. + #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String, + /// Include the query as part of this transaction (if unset, the query is auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, +} +/// +/// Represents a Substrait plan. Used in the command member of FlightDescriptor +/// for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// Fields on this schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. +/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case-sensitive, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. +/// - GetFlightInfo: execute the query. +/// - DoPut: execute the query. 
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandStatementSubstraitPlan { + /// A serialized substrait.Plan + #[prost(message, optional, tag = "1")] + pub plan: ::core::option::Option, + /// Include the query as part of this transaction (if unset, the query is auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, } /// * -/// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. -/// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. +/// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. +/// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. #[derive(Clone, PartialEq, ::prost::Message)] pub struct TicketStatementQuery { - /// Unique identifier for the instance of the statement to execute. - #[prost(bytes="vec", tag="1")] - pub statement_handle: ::prost::alloc::vec::Vec, + /// Unique identifier for the instance of the statement to execute. + #[prost(bytes = "bytes", tag = "1")] + pub statement_handle: ::prost::bytes::Bytes, } /// -/// Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for -/// the following RPC calls: +/// Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for +/// the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// Fields on this schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. 
+/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case-sensitive, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. +/// +/// If the schema is retrieved after parameter values have been bound with DoPut, then the server should account +/// for the parameters when determining the schema. /// - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution. /// - GetFlightInfo: execute the prepared statement instance. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandPreparedStatementQuery { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, } /// -/// Represents a SQL update query. Used in the command member of FlightDescriptor -/// for the the RPC call DoPut to cause the server to execute the included SQL update. +/// Represents a SQL update query. Used in the command member of FlightDescriptor +/// for the RPC call DoPut to cause the server to execute the included SQL update. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandStatementUpdate { - /// The SQL syntax. - #[prost(string, tag="1")] + /// The SQL syntax. + #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String, + /// Include the query as part of this transaction (if unset, the query is auto-committed). 
+ #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, } /// -/// Represents a SQL update query. Used in the command member of FlightDescriptor -/// for the the RPC call DoPut to cause the server to execute the included -/// prepared statement handle as an update. +/// Represents a SQL update query. Used in the command member of FlightDescriptor +/// for the RPC call DoPut to cause the server to execute the included +/// prepared statement handle as an update. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandPreparedStatementUpdate { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, } /// -/// Returned from the RPC call DoPut when a CommandStatementUpdate -/// CommandPreparedStatementUpdate was in the request, containing -/// results from the update. +/// Represents a bulk ingestion request. Used in the command member of FlightDescriptor +/// for the RPC call DoPut to cause the server to load the contents of the stream's +/// FlightData into the target destination. #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandStatementIngest { + /// The behavior for handling the table definition. + #[prost(message, optional, tag = "1")] + pub table_definition_options: ::core::option::Option< + command_statement_ingest::TableDefinitionOptions, + >, + /// The table to load data into. + #[prost(string, tag = "2")] + pub table: ::prost::alloc::string::String, + /// The db_schema of the destination table to load data into. If unset, a backend-specific default may be used.
+ #[prost(string, optional, tag = "3")] + pub schema: ::core::option::Option<::prost::alloc::string::String>, + /// The catalog of the destination table to load data into. If unset, a backend-specific default may be used. + #[prost(string, optional, tag = "4")] + pub catalog: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Store ingested data in a temporary table. + /// The effect of setting temporary is to place the table in a backend-defined namespace, and to drop the table at the end of the session. + /// The namespacing may make use of a backend-specific schema and/or catalog. + /// The server should return an error if an explicit choice of schema or catalog is incompatible with the server's namespacing decision. + #[prost(bool, tag = "5")] + pub temporary: bool, + /// Perform the ingestion as part of this transaction. If specified, results should not be committed in the event of an error/cancellation. + #[prost(bytes = "bytes", optional, tag = "6")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, + /// Backend-specific options. + #[prost(map = "string, string", tag = "1000")] + pub options: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, +} +/// Nested message and enum types in `CommandStatementIngest`. +pub mod command_statement_ingest { + /// Options for table definition behavior + #[derive(Clone, Copy, PartialEq, ::prost::Message)] + pub struct TableDefinitionOptions { + #[prost( + enumeration = "table_definition_options::TableNotExistOption", + tag = "1" + )] + pub if_not_exist: i32, + #[prost(enumeration = "table_definition_options::TableExistsOption", tag = "2")] + pub if_exists: i32, + } + /// Nested message and enum types in `TableDefinitionOptions`. 
+ pub mod table_definition_options { + /// The action to take if the target table does not exist + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum TableNotExistOption { + /// Do not use. Servers should error if this is specified by a client. + Unspecified = 0, + /// Create the table if it does not exist + Create = 1, + /// Fail if the table does not exist + Fail = 2, + } + impl TableNotExistOption { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "TABLE_NOT_EXIST_OPTION_UNSPECIFIED", + Self::Create => "TABLE_NOT_EXIST_OPTION_CREATE", + Self::Fail => "TABLE_NOT_EXIST_OPTION_FAIL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "TABLE_NOT_EXIST_OPTION_UNSPECIFIED" => Some(Self::Unspecified), + "TABLE_NOT_EXIST_OPTION_CREATE" => Some(Self::Create), + "TABLE_NOT_EXIST_OPTION_FAIL" => Some(Self::Fail), + _ => None, + } + } + } + /// The action to take if the target table already exists + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum TableExistsOption { + /// Do not use. Servers should error if this is specified by a client. + Unspecified = 0, + /// Fail if the table already exists + Fail = 1, + /// Append to the table if it already exists + Append = 2, + /// Drop and recreate the table if it already exists + Replace = 3, + } + impl TableExistsOption { + /// String value of the enum field names used in the ProtoBuf definition. 
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "TABLE_EXISTS_OPTION_UNSPECIFIED", + Self::Fail => "TABLE_EXISTS_OPTION_FAIL", + Self::Append => "TABLE_EXISTS_OPTION_APPEND", + Self::Replace => "TABLE_EXISTS_OPTION_REPLACE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "TABLE_EXISTS_OPTION_UNSPECIFIED" => Some(Self::Unspecified), + "TABLE_EXISTS_OPTION_FAIL" => Some(Self::Fail), + "TABLE_EXISTS_OPTION_APPEND" => Some(Self::Append), + "TABLE_EXISTS_OPTION_REPLACE" => Some(Self::Replace), + _ => None, + } + } + } + } +} +/// +/// Returned from the RPC call DoPut when a CommandStatementUpdate, +/// CommandPreparedStatementUpdate, or CommandStatementIngest was +/// in the request, containing results from the update. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct DoPutUpdateResult { - /// The number of records updated. A return value of -1 represents - /// an unknown updated record count. - #[prost(int64, tag="1")] + /// The number of records updated. A return value of -1 represents + /// an unknown updated record count. + #[prost(int64, tag = "1")] pub record_count: i64, } -/// Options for CommandGetSqlInfo. +/// An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`. +/// +/// *Note on legacy behavior*: previous versions of the protocol did not return any result for +/// this command, and that behavior should still be supported by clients. In that case, the client +/// can continue as though the fields in this message were not provided or set to sensible default values. 
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DoPutPreparedStatementResult { + /// Represents a (potentially updated) opaque handle for the prepared statement on the server. + /// Because the handle could potentially be updated, any previous handles for this prepared + /// statement should be considered invalid, and all subsequent requests for this prepared + /// statement must use this new handle. + /// The updated handle allows implementing query parameters with stateless services. + /// + /// When an updated handle is not provided by the server, clients should continue + /// using the previous handle provided by `ActionCreatePreparedStatementResult`. + #[prost(bytes = "bytes", optional, tag = "1")] + pub prepared_statement_handle: ::core::option::Option<::prost::bytes::Bytes>, +} +/// +/// Request message for the "CancelQuery" action. +/// +/// Explicitly cancel a running query. +/// +/// This lets a single client explicitly cancel work, no matter how many clients +/// are involved/whether the query is distributed or not, given server support. +/// The transaction/statement is not rolled back; it is the application's job to +/// commit or rollback as appropriate. This only indicates the client no longer +/// wishes to read the remainder of the query results or continue submitting +/// data. +/// +/// This command is idempotent. +/// +/// This command is deprecated since 13.0.0. Use the "CancelFlightInfo" +/// action with DoAction instead. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionCancelQueryRequest { + /// The result of the GetFlightInfo RPC that initiated the query. + /// XXX(ARROW-16902): this must be a serialized FlightInfo, but is + /// rendered as bytes because Protobuf does not really support one + /// DLL using Protobuf definitions from another DLL. + #[prost(bytes = "bytes", tag = "1")] + pub info: ::prost::bytes::Bytes, +} +/// +/// The result of cancelling a query.
+/// +/// The result should be wrapped in a google.protobuf.Any message. +/// +/// This command is deprecated since 13.0.0. Use the "CancelFlightInfo" +/// action with DoAction instead. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ActionCancelQueryResult { + #[prost(enumeration = "action_cancel_query_result::CancelResult", tag = "1")] + pub result: i32, +} +/// Nested message and enum types in `ActionCancelQueryResult`. +pub mod action_cancel_query_result { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum CancelResult { + /// The cancellation status is unknown. Servers should avoid using + /// this value (send a NOT_FOUND error if the requested query is + /// not known). Clients can retry the request. + Unspecified = 0, + /// The cancellation request is complete. Subsequent requests with + /// the same payload may return CANCELLED or a NOT_FOUND error. + Cancelled = 1, + /// The cancellation request is in progress. The client may retry + /// the cancellation request. + Cancelling = 2, + /// The query is not cancellable. The client should not retry the + /// cancellation request. + NotCancellable = 3, + } + impl CancelResult { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "CANCEL_RESULT_UNSPECIFIED", + Self::Cancelled => "CANCEL_RESULT_CANCELLED", + Self::Cancelling => "CANCEL_RESULT_CANCELLING", + Self::NotCancellable => "CANCEL_RESULT_NOT_CANCELLABLE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "CANCEL_RESULT_UNSPECIFIED" => Some(Self::Unspecified), + "CANCEL_RESULT_CANCELLED" => Some(Self::Cancelled), + "CANCEL_RESULT_CANCELLING" => Some(Self::Cancelling), + "CANCEL_RESULT_NOT_CANCELLABLE" => Some(Self::NotCancellable), + _ => None, + } + } + } +} +/// Options for CommandGetSqlInfo. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlInfo { - // Server Information [0-500): Provides basic information about the Flight SQL Server. - - /// Retrieves a UTF-8 string with the name of the Flight SQL Server. + /// Retrieves a UTF-8 string with the name of the Flight SQL Server. FlightSqlServerName = 0, - /// Retrieves a UTF-8 string with the native version of the Flight SQL Server. + /// Retrieves a UTF-8 string with the native version of the Flight SQL Server. FlightSqlServerVersion = 1, - /// Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. + /// Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. FlightSqlServerArrowVersion = 2, - /// - /// Retrieves a boolean value indicating whether the Flight SQL Server is read only. /// - /// Returns: - /// - false: if read-write - /// - true: if read only + /// Retrieves a boolean value indicating whether the Flight SQL Server is read only. + /// + /// Returns: + /// - false: if read-write + /// - true: if read only FlightSqlServerReadOnly = 3, - // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server. - /// - /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs. + /// Retrieves a boolean value indicating whether the Flight SQL Server supports executing + /// SQL queries. 
+ /// + /// Note that the absence of this info (as opposed to a false value) does not necessarily + /// mean that SQL is not supported, as this property was not originally defined. + FlightSqlServerSql = 4, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports executing + /// Substrait plans. + FlightSqlServerSubstrait = 5, + /// + /// Retrieves a string value indicating the minimum supported Substrait version, or null + /// if Substrait is not supported. + FlightSqlServerSubstraitMinVersion = 6, + /// + /// Retrieves a string value indicating the maximum supported Substrait version, or null + /// if Substrait is not supported. + FlightSqlServerSubstraitMaxVersion = 7, + /// + /// Retrieves an int32 indicating whether the Flight SQL Server supports the + /// BeginTransaction/EndTransaction/BeginSavepoint/EndSavepoint actions. /// - /// Returns: - /// - false: if it doesn't support CREATE and DROP of catalogs. - /// - true: if it supports CREATE and DROP of catalogs. + /// Even if this is not supported, the database may still support explicit "BEGIN + /// TRANSACTION"/"COMMIT" SQL statements (see SQL_TRANSACTIONS_SUPPORTED); this property + /// is only about whether the server implements the Flight SQL API endpoints. + /// + /// The possible values are listed in `SqlSupportedTransaction`. + FlightSqlServerTransaction = 8, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports explicit + /// query cancellation (the CancelQuery action). + FlightSqlServerCancel = 9, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports executing + /// bulk ingestion. + FlightSqlServerBulkIngestion = 10, + /// + /// Retrieves a boolean value indicating whether transactions are supported for bulk ingestion. 
If not, invoking + /// the method commit in the context of a bulk ingestion is a noop, and the isolation level is + /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + /// + /// Returns: + /// - false: if bulk ingestion transactions are unsupported; + /// - true: if bulk ingestion transactions are supported. + FlightSqlServerIngestTransactionsSupported = 11, + /// + /// Retrieves an int32 indicating the timeout (in milliseconds) for prepared statement handles. + /// + /// If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + FlightSqlServerStatementTimeout = 100, + /// + /// Retrieves an int32 indicating the timeout (in milliseconds) for transactions, since transactions are not tied to a connection. + /// + /// If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + FlightSqlServerTransactionTimeout = 101, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs. + /// + /// Returns: + /// - false: if it doesn't support CREATE and DROP of catalogs. + /// - true: if it supports CREATE and DROP of catalogs. SqlDdlCatalog = 500, /// - /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas. + /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas. /// - /// Returns: - /// - false: if it doesn't support CREATE and DROP of schemas. - /// - true: if it supports CREATE and DROP of schemas. + /// Returns: + /// - false: if it doesn't support CREATE and DROP of schemas. + /// - true: if it supports CREATE and DROP of schemas. SqlDdlSchema = 501, /// - /// Indicates whether the Flight SQL Server supports CREATE and DROP of tables. + /// Indicates whether the Flight SQL Server supports CREATE and DROP of tables. /// - /// Returns: - /// - false: if it doesn't support CREATE and DROP of tables. 
- /// - true: if it supports CREATE and DROP of tables. + /// Returns: + /// - false: if it doesn't support CREATE and DROP of tables. + /// - true: if it supports CREATE and DROP of tables. SqlDdlTable = 502, /// - /// Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of catalog, table, schema and table names. + /// Retrieves a int32 ordinal representing the case sensitivity of catalog, table, schema and table names. /// - /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. SqlIdentifierCase = 503, - /// Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier. + /// Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier. SqlIdentifierQuoteChar = 504, /// - /// Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of quoted identifiers. + /// Retrieves a int32 describing the case sensitivity of quoted identifiers. /// - /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. SqlQuotedIdentifierCase = 505, /// - /// Retrieves a boolean value indicating whether all tables are selectable. + /// Retrieves a boolean value indicating whether all tables are selectable. /// - /// Returns: - /// - false: if not all tables are selectable or if none are; - /// - true: if all tables are selectable. + /// Returns: + /// - false: if not all tables are selectable or if none are; + /// - true: if all tables are selectable. SqlAllTablesAreSelectable = 506, /// - /// Retrieves the null ordering. + /// Retrieves the null ordering. /// - /// Returns a uint32 ordinal for the null ordering being used, as described in - /// `arrow.flight.protocol.sql.SqlNullOrdering`. 
+ /// Returns a int32 ordinal for the null ordering being used, as described in + /// `arrow.flight.protocol.sql.SqlNullOrdering`. SqlNullOrdering = 507, - /// Retrieves a UTF-8 string list with values of the supported keywords. + /// Retrieves a UTF-8 string list with values of the supported keywords. SqlKeywords = 508, - /// Retrieves a UTF-8 string list with values of the supported numeric functions. + /// Retrieves a UTF-8 string list with values of the supported numeric functions. SqlNumericFunctions = 509, - /// Retrieves a UTF-8 string list with values of the supported string functions. + /// Retrieves a UTF-8 string list with values of the supported string functions. SqlStringFunctions = 510, - /// Retrieves a UTF-8 string list with values of the supported system functions. + /// Retrieves a UTF-8 string list with values of the supported system functions. SqlSystemFunctions = 511, - /// Retrieves a UTF-8 string list with values of the supported datetime functions. + /// Retrieves a UTF-8 string list with values of the supported datetime functions. SqlDatetimeFunctions = 512, /// - /// Retrieves the UTF-8 string that can be used to escape wildcard characters. - /// This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern - /// (and therefore use one of the wildcard characters). - /// The '_' character represents any single character; the '%' character represents any sequence of zero or more - /// characters. + /// Retrieves the UTF-8 string that can be used to escape wildcard characters. + /// This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern + /// (and therefore use one of the wildcard characters). + /// The '_' character represents any single character; the '%' character represents any sequence of zero or more + /// characters. 
SqlSearchStringEscape = 513, /// - /// Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names - /// (those beyond a-z, A-Z, 0-9 and _). + /// Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names + /// (those beyond a-z, A-Z, 0-9 and _). SqlExtraNameCharacters = 514, /// - /// Retrieves a boolean value indicating whether column aliasing is supported. - /// If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns - /// as required. + /// Retrieves a boolean value indicating whether column aliasing is supported. + /// If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns + /// as required. /// - /// Returns: - /// - false: if column aliasing is unsupported; - /// - true: if column aliasing is supported. + /// Returns: + /// - false: if column aliasing is unsupported; + /// - true: if column aliasing is supported. SqlSupportsColumnAliasing = 515, /// - /// Retrieves a boolean value indicating whether concatenations between null and non-null values being - /// null are supported. + /// Retrieves a boolean value indicating whether concatenations between null and non-null values being + /// null are supported. /// - /// - Returns: - /// - false: if concatenations between null and non-null values being null are unsupported; - /// - true: if concatenations between null and non-null values being null are supported. + /// - Returns: + /// - false: if concatenations between null and non-null values being null are unsupported; + /// - true: if concatenations between null and non-null values being null are supported. SqlNullPlusNullIsNull = 516, /// - /// Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to, - /// indicating the supported conversions. 
Each key and each item on the list value is a value to a predefined type on - /// SqlSupportsConvert enum. - /// The returned map will be: map<int32, list<int32>> + /// Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to, + /// indicating the supported conversions. Each key and each item on the list value is a value to a predefined type on + /// SqlSupportsConvert enum. + /// The returned map will be: map<int32, list<int32>> SqlSupportsConvert = 517, /// - /// Retrieves a boolean value indicating whether, when table correlation names are supported, - /// they are restricted to being different from the names of the tables. + /// Retrieves a boolean value indicating whether, when table correlation names are supported, + /// they are restricted to being different from the names of the tables. /// - /// Returns: - /// - false: if table correlation names are unsupported; - /// - true: if table correlation names are supported. + /// Returns: + /// - false: if table correlation names are unsupported; + /// - true: if table correlation names are supported. SqlSupportsTableCorrelationNames = 518, /// - /// Retrieves a boolean value indicating whether, when table correlation names are supported, - /// they are restricted to being different from the names of the tables. + /// Retrieves a boolean value indicating whether, when table correlation names are supported, + /// they are restricted to being different from the names of the tables. /// - /// Returns: - /// - false: if different table correlation names are unsupported; - /// - true: if different table correlation names are supported + /// Returns: + /// - false: if different table correlation names are unsupported; + /// - true: if different table correlation names are supported SqlSupportsDifferentTableCorrelationNames = 519, /// - /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported.
+ /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported. /// - /// Returns: - /// - false: if expressions in ORDER BY are unsupported; - /// - true: if expressions in ORDER BY are supported; + /// Returns: + /// - false: if expressions in ORDER BY are unsupported; + /// - true: if expressions in ORDER BY are supported; SqlSupportsExpressionsInOrderBy = 520, /// - /// Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY - /// clause is supported. + /// Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY + /// clause is supported. /// - /// Returns: - /// - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported; - /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported. + /// Returns: + /// - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported; + /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported. SqlSupportsOrderByUnrelated = 521, /// - /// Retrieves the supported GROUP BY commands; + /// Retrieves the supported GROUP BY commands; /// - /// Returns an int32 bitmask value representing the supported commands. - /// The returned bitmask should be parsed in order to retrieve the supported commands. + /// Returns an int32 bitmask value representing the supported commands. + /// The returned bitmask should be parsed in order to retrieve the supported commands. /// - /// For instance: - /// - return 0 (\b0) => [] (GROUP BY is unsupported); - /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; - /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; - /// - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. - /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. 
+ /// For instance: + /// - return 0 (\b0) => \[\] (GROUP BY is unsupported); + /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; + /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; + /// - return 3 (\b11) => \[SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT\]. + /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. SqlSupportedGroupBy = 522, /// - /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. + /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. /// - /// Returns: - /// - false: if specifying a LIKE escape clause is unsupported; - /// - true: if specifying a LIKE escape clause is supported. + /// Returns: + /// - false: if specifying a LIKE escape clause is unsupported; + /// - true: if specifying a LIKE escape clause is supported. SqlSupportsLikeEscapeClause = 523, /// - /// Retrieves a boolean value indicating whether columns may be defined as non-nullable. + /// Retrieves a boolean value indicating whether columns may be defined as non-nullable. /// - /// Returns: - /// - false: if columns cannot be defined as non-nullable; - /// - true: if columns may be defined as non-nullable. + /// Returns: + /// - false: if columns cannot be defined as non-nullable; + /// - true: if columns may be defined as non-nullable. SqlSupportsNonNullableColumns = 524, /// - /// Retrieves the supported SQL grammar level as per the ODBC specification. - /// - /// Returns an int32 bitmask value representing the supported SQL grammar level. - /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. 
- /// - /// For instance: - /// - return 0 (\b0) => [] (SQL grammar is unsupported); - /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\]; - /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\]; - /// - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; - /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\]; - /// - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]. - /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. + /// Retrieves the supported SQL grammar level as per the ODBC specification. + /// + /// Returns an int32 bitmask value representing the supported SQL grammar level. + /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (SQL grammar is unsupported); + /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\]; + /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\]; + /// - return 3 (\b11) => \[SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR\]; + /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\]; + /// - return 5 (\b101) => \[SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR\]; + /// - return 6 (\b110) => \[SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR\]; + /// - return 7 (\b111) => \[SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR\]. + /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. SqlSupportedGrammar = 525, /// - /// Retrieves the supported ANSI92 SQL grammar level. - /// - /// Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. - /// The returned bitmask should be parsed in order to retrieve the supported commands. 
- /// - /// For instance: - /// - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported); - /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\]; - /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\]; - /// - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL]; - /// - return 4 (\b100) => \[ANSI92_FULL_SQL\]; - /// - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL]; - /// - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]; - /// - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]. - /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. + /// Retrieves the supported ANSI92 SQL grammar level. + /// + /// Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. + /// The returned bitmask should be parsed in order to retrieve the supported commands. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (ANSI92 SQL grammar is unsupported); + /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\]; + /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\]; + /// - return 3 (\b11) => \[ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL\]; + /// - return 4 (\b100) => \[ANSI92_FULL_SQL\]; + /// - return 5 (\b101) => \[ANSI92_ENTRY_SQL, ANSI92_FULL_SQL\]; + /// - return 6 (\b110) => \[ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL\]; + /// - return 7 (\b111) => \[ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL\]. + /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. SqlAnsi92SupportedLevel = 526, /// - /// Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported. + /// Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported. /// - /// Returns: - /// - false: if the SQL Integrity Enhancement Facility is supported; - /// - true: if the SQL Integrity Enhancement Facility is supported. 
+ /// Returns: + /// - false: if the SQL Integrity Enhancement Facility is unsupported; + /// - true: if the SQL Integrity Enhancement Facility is supported. SqlSupportsIntegrityEnhancementFacility = 527, /// - /// Retrieves the support level for SQL OUTER JOINs. + /// Retrieves the support level for SQL OUTER JOINs. /// - /// Returns a uint3 uint32 ordinal for the SQL ordering being used, as described in - /// `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. + /// Returns an int32 ordinal for the SQL OUTER JOIN support level, as described in + /// `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. SqlOuterJoinsSupportLevel = 528, - /// Retrieves a UTF-8 string with the preferred term for "schema". + /// Retrieves a UTF-8 string with the preferred term for "schema". SqlSchemaTerm = 529, - /// Retrieves a UTF-8 string with the preferred term for "procedure". + /// Retrieves a UTF-8 string with the preferred term for "procedure". SqlProcedureTerm = 530, - /// Retrieves a UTF-8 string with the preferred term for "catalog". + /// + /// Retrieves a UTF-8 string with the preferred term for "catalog". + /// If an empty string is returned, it is assumed that the server does NOT support catalogs. SqlCatalogTerm = 531, /// - /// Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. + /// Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. /// - /// - false: if a catalog does not appear at the start of a fully qualified table name; - /// - true: if a catalog appears at the start of a fully qualified table name. + /// - false: if a catalog does not appear at the start of a fully qualified table name; + /// - true: if a catalog appears at the start of a fully qualified table name. SqlCatalogAtStart = 532, /// - /// Retrieves the supported actions for a SQL schema. - /// - /// Returns an int32 bitmask value representing the supported actions for a SQL schema.
- /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL schema); - /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; - /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; - /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. - /// Valid actions for a SQL schema described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + /// Retrieves the supported actions for a SQL schema. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL schema. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported actions for SQL schema); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 6 (\b110) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 7 (\b111) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]. 
+ /// Valid actions for a SQL schema are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlSchemasSupportedActions = 533, /// - /// Retrieves the supported actions for a SQL schema. - /// - /// Returns an int32 bitmask value representing the supported actions for a SQL catalog. - /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL catalog); - /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; - /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; - /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. - /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + /// Retrieves the supported actions for a SQL catalog. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL catalog. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog.
+ /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported actions for SQL catalog); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 6 (\b110) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 7 (\b111) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]. + /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlCatalogsSupportedActions = 534, /// - /// Retrieves the supported SQL positioned commands. + /// Retrieves the supported SQL positioned commands. /// - /// Returns an int32 bitmask value representing the supported SQL positioned commands. - /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. + /// Returns an int32 bitmask value representing the supported SQL positioned commands. + /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); - /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; - /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; - /// - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. 
+ /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL positioned commands); + /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; + /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; + /// - return 3 (\b11) => \[SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE\]. + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. SqlSupportedPositionedCommands = 535, /// - /// Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. + /// Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. /// - /// Returns: - /// - false: if SELECT FOR UPDATE statements are unsupported; - /// - true: if SELECT FOR UPDATE statements are supported. + /// Returns: + /// - false: if SELECT FOR UPDATE statements are unsupported; + /// - true: if SELECT FOR UPDATE statements are supported. SqlSelectForUpdateSupported = 536, /// - /// Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax - /// are supported. + /// Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax + /// are supported. /// - /// Returns: - /// - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; - /// - true: if stored procedure calls that use the stored procedure escape syntax are supported. + /// Returns: + /// - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; + /// - true: if stored procedure calls that use the stored procedure escape syntax are supported. SqlStoredProceduresSupported = 537, /// - /// Retrieves the supported SQL subqueries. - /// - /// Returns an int32 bitmask value representing the supported SQL subqueries. - /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. 
- /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL subqueries); - /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; - /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; - /// - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; - /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; - /// - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; - /// - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; - /// - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; - /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; - /// - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - ... - /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. + /// Retrieves the supported SQL subqueries. + /// + /// Returns an int32 bitmask value representing the supported SQL subqueries. + /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. 
+ /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL subqueries); + /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; + /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 3 (\b11) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; + /// - return 5 (\b101) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS\]; + /// - return 6 (\b110) => \[SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 7 (\b111) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS\]; + /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 9 (\b1001) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 10 (\b1010) => \[SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 11 (\b1011) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 12 (\b1100) => \[SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 13 (\b1101) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 14 (\b1110) => \[SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 15 (\b1111) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - ... + /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. SqlSupportedSubqueries = 538, /// - /// Retrieves a boolean value indicating whether correlated subqueries are supported. + /// Retrieves a boolean value indicating whether correlated subqueries are supported. /// - /// Returns: - /// - false: if correlated subqueries are unsupported; - /// - true: if correlated subqueries are supported. 
+ /// Returns: + /// - false: if correlated subqueries are unsupported; + /// - true: if correlated subqueries are supported. SqlCorrelatedSubqueriesSupported = 539, /// - /// Retrieves the supported SQL UNIONs. + /// Retrieves the supported SQL UNIONs. /// - /// Returns an int32 bitmask value representing the supported SQL UNIONs. - /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. + /// Returns an int32 bitmask value representing the supported SQL UNIONs. + /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); - /// - return 1 (\b1) => \[SQL_UNION\]; - /// - return 2 (\b10) => \[SQL_UNION_ALL\]; - /// - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. + /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL UNIONs); + /// - return 1 (\b1) => \[SQL_UNION\]; + /// - return 2 (\b10) => \[SQL_UNION_ALL\]; + /// - return 3 (\b11) => \[SQL_UNION, SQL_UNION_ALL\]. + /// Valid SQL UNIONs are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. SqlSupportedUnions = 540, - /// Retrieves a uint32 value representing the maximum number of hex characters allowed in an inline binary literal. + /// Retrieves a int64 value representing the maximum number of hex characters allowed in an inline binary literal. SqlMaxBinaryLiteralLength = 541, - /// Retrieves a uint32 value representing the maximum number of characters allowed for a character literal. + /// Retrieves a int64 value representing the maximum number of characters allowed for a character literal. SqlMaxCharLiteralLength = 542, - /// Retrieves a uint32 value representing the maximum number of characters allowed for a column name.
+ /// Retrieves a int64 value representing the maximum number of characters allowed for a column name. SqlMaxColumnNameLength = 543, - /// Retrieves a uint32 value representing the the maximum number of columns allowed in a GROUP BY clause. + /// Retrieves a int64 value representing the maximum number of columns allowed in a GROUP BY clause. SqlMaxColumnsInGroupBy = 544, - /// Retrieves a uint32 value representing the maximum number of columns allowed in an index. + /// Retrieves a int64 value representing the maximum number of columns allowed in an index. SqlMaxColumnsInIndex = 545, - /// Retrieves a uint32 value representing the maximum number of columns allowed in an ORDER BY clause. + /// Retrieves a int64 value representing the maximum number of columns allowed in an ORDER BY clause. SqlMaxColumnsInOrderBy = 546, - /// Retrieves a uint32 value representing the maximum number of columns allowed in a SELECT list. + /// Retrieves a int64 value representing the maximum number of columns allowed in a SELECT list. SqlMaxColumnsInSelect = 547, - /// Retrieves a uint32 value representing the maximum number of columns allowed in a table. + /// Retrieves a int64 value representing the maximum number of columns allowed in a table. SqlMaxColumnsInTable = 548, - /// Retrieves a uint32 value representing the maximum number of concurrent connections possible. + /// Retrieves a int64 value representing the maximum number of concurrent connections possible. SqlMaxConnections = 549, - /// Retrieves a uint32 value the maximum number of characters allowed in a cursor name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a cursor name. SqlMaxCursorNameLength = 550, /// - /// Retrieves a uint32 value representing the maximum number of bytes allowed for an index, - /// including all of the parts of the index. + /// Retrieves a int64 value representing the maximum number of bytes allowed for an index, + /// including all of the parts of the index.
SqlMaxIndexLength = 551, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a schema name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a schema name. SqlDbSchemaNameLength = 552, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a procedure name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a procedure name. SqlMaxProcedureNameLength = 553, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a catalog name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a catalog name. SqlMaxCatalogNameLength = 554, - /// Retrieves a uint32 value representing the maximum number of bytes allowed in a single row. + /// Retrieves a int64 value representing the maximum number of bytes allowed in a single row. SqlMaxRowSize = 555, /// - /// Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL - /// data types LONGVARCHAR and LONGVARBINARY. + /// Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL + /// data types LONGVARCHAR and LONGVARBINARY. /// - /// Returns: - /// - false: if return value for the JDBC method getMaxRowSize does + /// Returns: + /// - false: if return value for the JDBC method getMaxRowSize does /// not include the SQL data types LONGVARCHAR and LONGVARBINARY; - /// - true: if return value for the JDBC method getMaxRowSize includes + /// - true: if return value for the JDBC method getMaxRowSize includes /// the SQL data types LONGVARCHAR and LONGVARBINARY. SqlMaxRowSizeIncludesBlobs = 556, /// - /// Retrieves a uint32 value representing the maximum number of characters allowed for an SQL statement; - /// a result of 0 (zero) means that there is no limit or the limit is not known. 
+ /// Retrieves a int64 value representing the maximum number of characters allowed for an SQL statement; + /// a result of 0 (zero) means that there is no limit or the limit is not known. SqlMaxStatementLength = 557, - /// Retrieves a uint32 value representing the maximum number of active statements that can be open at the same time. + /// Retrieves a int64 value representing the maximum number of active statements that can be open at the same time. SqlMaxStatements = 558, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a table name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a table name. SqlMaxTableNameLength = 559, - /// Retrieves a uint32 value representing the maximum number of tables allowed in a SELECT statement. + /// Retrieves a int64 value representing the maximum number of tables allowed in a SELECT statement. SqlMaxTablesInSelect = 560, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a user name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a user name. SqlMaxUsernameLength = 561, /// - /// Retrieves this database's default transaction isolation level as described in - /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + /// Retrieves this database's default transaction isolation level as described in + /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. /// - /// Returns a uint32 ordinal for the SQL transaction isolation level. + /// Returns a int32 ordinal for the SQL transaction isolation level. SqlDefaultTransactionIsolation = 562, /// - /// Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a - /// noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + /// Retrieves a boolean value indicating whether transactions are supported. 
If not, invoking the method commit is a + /// noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. /// - /// Returns: - /// - false: if transactions are unsupported; - /// - true: if transactions are supported. + /// Returns: + /// - false: if transactions are unsupported; + /// - true: if transactions are supported. SqlTransactionsSupported = 563, /// - /// Retrieves the supported transactions isolation levels. - /// - /// Returns an int32 bitmask value representing the supported transactions isolation levels. - /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL transactions isolation levels); - /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; - /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; - /// - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; - /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 14 (\b1110) => 
[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; - /// - ... - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + /// Retrieves the supported transactions isolation levels. + /// + /// Returns an int32 bitmask value representing the supported transactions isolation levels. + /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL transactions isolation levels); + /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; + /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; + /// - return 3 (\b11) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED\]; + /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 5 (\b101) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 6 (\b110) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 7 (\b111) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 9 (\b1001) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 10 (\b1010) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 11 (\b1011) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 12 (\b1100) => \[SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 13 (\b1101) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, 
SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 14 (\b1110) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 15 (\b1111) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; + /// - ... + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. SqlSupportedTransactionsIsolationLevels = 564, /// - /// Retrieves a boolean value indicating whether a data definition statement within a transaction forces - /// the transaction to commit. + /// Retrieves a boolean value indicating whether a data definition statement within a transaction forces + /// the transaction to commit. /// - /// Returns: - /// - false: if a data definition statement within a transaction does not force the transaction to commit; - /// - true: if a data definition statement within a transaction forces the transaction to commit. + /// Returns: + /// - false: if a data definition statement within a transaction does not force the transaction to commit; + /// - true: if a data definition statement within a transaction forces the transaction to commit. SqlDataDefinitionCausesTransactionCommit = 565, /// - /// Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. + /// Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. /// - /// Returns: - /// - false: if a data definition statement within a transaction is taken into account; - /// - true: a data definition statement within a transaction is ignored. + /// Returns: + /// - false: if a data definition statement within a transaction is taken into account; + /// - true: a data definition statement within a transaction is ignored. 
SqlDataDefinitionsInTransactionsIgnored = 566, /// - /// Retrieves an int32 bitmask value representing the supported result set types. - /// The returned bitmask should be parsed in order to retrieve the supported result set types. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported result set types); - /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; - /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; - /// - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; - /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; - /// - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; - /// - ... - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. + /// Retrieves an int32 bitmask value representing the supported result set types. + /// The returned bitmask should be parsed in order to retrieve the supported result set types. 
+ /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported result set types); + /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; + /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; + /// - return 3 (\b11) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; + /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 5 (\b101) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 6 (\b110) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 7 (\b111) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; + /// - ... + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. SqlSupportedResultSetTypes = 567, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. 
- /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. 
SqlSupportedConcurrenciesForResultSetUnspecified = 568, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. 
+ /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetForwardOnly = 569, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. 
- /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. 
SqlSupportedConcurrenciesForResultSetScrollSensitive = 570, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. 
+ /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetScrollInsensitive = 571, /// - /// Retrieves a boolean value indicating whether this database supports batch updates. + /// Retrieves a boolean value indicating whether this database supports batch updates. /// - /// - false: if this database does not support batch updates; - /// - true: if this database supports batch updates. + /// - false: if this database does not support batch updates; + /// - true: if this database supports batch updates. SqlBatchUpdatesSupported = 572, /// - /// Retrieves a boolean value indicating whether this database supports savepoints. + /// Retrieves a boolean value indicating whether this database supports savepoints. /// - /// Returns: - /// - false: if this database does not support savepoints; - /// - true: if this database supports savepoints. + /// Returns: + /// - false: if this database does not support savepoints; + /// - true: if this database supports savepoints. SqlSavepointsSupported = 573, /// - /// Retrieves a boolean value indicating whether named parameters are supported in callable statements. 
+ /// Retrieves a boolean value indicating whether named parameters are supported in callable statements. /// - /// Returns: - /// - false: if named parameters in callable statements are unsupported; - /// - true: if named parameters in callable statements are supported. + /// Returns: + /// - false: if named parameters in callable statements are unsupported; + /// - true: if named parameters in callable statements are supported. SqlNamedParametersSupported = 574, /// - /// Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. + /// Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. /// - /// Returns: - /// - false: if updates made to a LOB are made directly to the LOB; - /// - true: if updates made to a LOB are made on a copy. + /// Returns: + /// - false: if updates made to a LOB are made directly to the LOB; + /// - true: if updates made to a LOB are made on a copy. SqlLocatorsUpdateCopy = 575, /// - /// Retrieves a boolean value indicating whether invoking user-defined or vendor functions - /// using the stored procedure escape syntax is supported. + /// Retrieves a boolean value indicating whether invoking user-defined or vendor functions + /// using the stored procedure escape syntax is supported. /// - /// Returns: - /// - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; - /// - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. + /// Returns: + /// - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; + /// - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. 
SqlStoredFunctionsUsingCallSyntaxSupported = 576, } impl SqlInfo { @@ -1010,87 +1654,309 @@ impl SqlInfo { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlInfo::FlightSqlServerName => "FLIGHT_SQL_SERVER_NAME", - SqlInfo::FlightSqlServerVersion => "FLIGHT_SQL_SERVER_VERSION", - SqlInfo::FlightSqlServerArrowVersion => "FLIGHT_SQL_SERVER_ARROW_VERSION", - SqlInfo::FlightSqlServerReadOnly => "FLIGHT_SQL_SERVER_READ_ONLY", - SqlInfo::SqlDdlCatalog => "SQL_DDL_CATALOG", - SqlInfo::SqlDdlSchema => "SQL_DDL_SCHEMA", - SqlInfo::SqlDdlTable => "SQL_DDL_TABLE", - SqlInfo::SqlIdentifierCase => "SQL_IDENTIFIER_CASE", - SqlInfo::SqlIdentifierQuoteChar => "SQL_IDENTIFIER_QUOTE_CHAR", - SqlInfo::SqlQuotedIdentifierCase => "SQL_QUOTED_IDENTIFIER_CASE", - SqlInfo::SqlAllTablesAreSelectable => "SQL_ALL_TABLES_ARE_SELECTABLE", - SqlInfo::SqlNullOrdering => "SQL_NULL_ORDERING", - SqlInfo::SqlKeywords => "SQL_KEYWORDS", - SqlInfo::SqlNumericFunctions => "SQL_NUMERIC_FUNCTIONS", - SqlInfo::SqlStringFunctions => "SQL_STRING_FUNCTIONS", - SqlInfo::SqlSystemFunctions => "SQL_SYSTEM_FUNCTIONS", - SqlInfo::SqlDatetimeFunctions => "SQL_DATETIME_FUNCTIONS", - SqlInfo::SqlSearchStringEscape => "SQL_SEARCH_STRING_ESCAPE", - SqlInfo::SqlExtraNameCharacters => "SQL_EXTRA_NAME_CHARACTERS", - SqlInfo::SqlSupportsColumnAliasing => "SQL_SUPPORTS_COLUMN_ALIASING", - SqlInfo::SqlNullPlusNullIsNull => "SQL_NULL_PLUS_NULL_IS_NULL", - SqlInfo::SqlSupportsConvert => "SQL_SUPPORTS_CONVERT", - SqlInfo::SqlSupportsTableCorrelationNames => "SQL_SUPPORTS_TABLE_CORRELATION_NAMES", - SqlInfo::SqlSupportsDifferentTableCorrelationNames => "SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES", - SqlInfo::SqlSupportsExpressionsInOrderBy => "SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY", - SqlInfo::SqlSupportsOrderByUnrelated => "SQL_SUPPORTS_ORDER_BY_UNRELATED", - SqlInfo::SqlSupportedGroupBy => "SQL_SUPPORTED_GROUP_BY", - 
SqlInfo::SqlSupportsLikeEscapeClause => "SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE", - SqlInfo::SqlSupportsNonNullableColumns => "SQL_SUPPORTS_NON_NULLABLE_COLUMNS", - SqlInfo::SqlSupportedGrammar => "SQL_SUPPORTED_GRAMMAR", - SqlInfo::SqlAnsi92SupportedLevel => "SQL_ANSI92_SUPPORTED_LEVEL", - SqlInfo::SqlSupportsIntegrityEnhancementFacility => "SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY", - SqlInfo::SqlOuterJoinsSupportLevel => "SQL_OUTER_JOINS_SUPPORT_LEVEL", - SqlInfo::SqlSchemaTerm => "SQL_SCHEMA_TERM", - SqlInfo::SqlProcedureTerm => "SQL_PROCEDURE_TERM", - SqlInfo::SqlCatalogTerm => "SQL_CATALOG_TERM", - SqlInfo::SqlCatalogAtStart => "SQL_CATALOG_AT_START", - SqlInfo::SqlSchemasSupportedActions => "SQL_SCHEMAS_SUPPORTED_ACTIONS", - SqlInfo::SqlCatalogsSupportedActions => "SQL_CATALOGS_SUPPORTED_ACTIONS", - SqlInfo::SqlSupportedPositionedCommands => "SQL_SUPPORTED_POSITIONED_COMMANDS", - SqlInfo::SqlSelectForUpdateSupported => "SQL_SELECT_FOR_UPDATE_SUPPORTED", - SqlInfo::SqlStoredProceduresSupported => "SQL_STORED_PROCEDURES_SUPPORTED", - SqlInfo::SqlSupportedSubqueries => "SQL_SUPPORTED_SUBQUERIES", - SqlInfo::SqlCorrelatedSubqueriesSupported => "SQL_CORRELATED_SUBQUERIES_SUPPORTED", - SqlInfo::SqlSupportedUnions => "SQL_SUPPORTED_UNIONS", - SqlInfo::SqlMaxBinaryLiteralLength => "SQL_MAX_BINARY_LITERAL_LENGTH", - SqlInfo::SqlMaxCharLiteralLength => "SQL_MAX_CHAR_LITERAL_LENGTH", - SqlInfo::SqlMaxColumnNameLength => "SQL_MAX_COLUMN_NAME_LENGTH", - SqlInfo::SqlMaxColumnsInGroupBy => "SQL_MAX_COLUMNS_IN_GROUP_BY", - SqlInfo::SqlMaxColumnsInIndex => "SQL_MAX_COLUMNS_IN_INDEX", - SqlInfo::SqlMaxColumnsInOrderBy => "SQL_MAX_COLUMNS_IN_ORDER_BY", - SqlInfo::SqlMaxColumnsInSelect => "SQL_MAX_COLUMNS_IN_SELECT", - SqlInfo::SqlMaxColumnsInTable => "SQL_MAX_COLUMNS_IN_TABLE", - SqlInfo::SqlMaxConnections => "SQL_MAX_CONNECTIONS", - SqlInfo::SqlMaxCursorNameLength => "SQL_MAX_CURSOR_NAME_LENGTH", - SqlInfo::SqlMaxIndexLength => "SQL_MAX_INDEX_LENGTH", - 
SqlInfo::SqlDbSchemaNameLength => "SQL_DB_SCHEMA_NAME_LENGTH", - SqlInfo::SqlMaxProcedureNameLength => "SQL_MAX_PROCEDURE_NAME_LENGTH", - SqlInfo::SqlMaxCatalogNameLength => "SQL_MAX_CATALOG_NAME_LENGTH", - SqlInfo::SqlMaxRowSize => "SQL_MAX_ROW_SIZE", - SqlInfo::SqlMaxRowSizeIncludesBlobs => "SQL_MAX_ROW_SIZE_INCLUDES_BLOBS", - SqlInfo::SqlMaxStatementLength => "SQL_MAX_STATEMENT_LENGTH", - SqlInfo::SqlMaxStatements => "SQL_MAX_STATEMENTS", - SqlInfo::SqlMaxTableNameLength => "SQL_MAX_TABLE_NAME_LENGTH", - SqlInfo::SqlMaxTablesInSelect => "SQL_MAX_TABLES_IN_SELECT", - SqlInfo::SqlMaxUsernameLength => "SQL_MAX_USERNAME_LENGTH", - SqlInfo::SqlDefaultTransactionIsolation => "SQL_DEFAULT_TRANSACTION_ISOLATION", - SqlInfo::SqlTransactionsSupported => "SQL_TRANSACTIONS_SUPPORTED", - SqlInfo::SqlSupportedTransactionsIsolationLevels => "SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS", - SqlInfo::SqlDataDefinitionCausesTransactionCommit => "SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT", - SqlInfo::SqlDataDefinitionsInTransactionsIgnored => "SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED", - SqlInfo::SqlSupportedResultSetTypes => "SQL_SUPPORTED_RESULT_SET_TYPES", - SqlInfo::SqlSupportedConcurrenciesForResultSetUnspecified => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED", - SqlInfo::SqlSupportedConcurrenciesForResultSetForwardOnly => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY", - SqlInfo::SqlSupportedConcurrenciesForResultSetScrollSensitive => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE", - SqlInfo::SqlSupportedConcurrenciesForResultSetScrollInsensitive => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE", - SqlInfo::SqlBatchUpdatesSupported => "SQL_BATCH_UPDATES_SUPPORTED", - SqlInfo::SqlSavepointsSupported => "SQL_SAVEPOINTS_SUPPORTED", - SqlInfo::SqlNamedParametersSupported => "SQL_NAMED_PARAMETERS_SUPPORTED", - SqlInfo::SqlLocatorsUpdateCopy => "SQL_LOCATORS_UPDATE_COPY", - 
SqlInfo::SqlStoredFunctionsUsingCallSyntaxSupported => "SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED", + Self::FlightSqlServerName => "FLIGHT_SQL_SERVER_NAME", + Self::FlightSqlServerVersion => "FLIGHT_SQL_SERVER_VERSION", + Self::FlightSqlServerArrowVersion => "FLIGHT_SQL_SERVER_ARROW_VERSION", + Self::FlightSqlServerReadOnly => "FLIGHT_SQL_SERVER_READ_ONLY", + Self::FlightSqlServerSql => "FLIGHT_SQL_SERVER_SQL", + Self::FlightSqlServerSubstrait => "FLIGHT_SQL_SERVER_SUBSTRAIT", + Self::FlightSqlServerSubstraitMinVersion => { + "FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION" + } + Self::FlightSqlServerSubstraitMaxVersion => { + "FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION" + } + Self::FlightSqlServerTransaction => "FLIGHT_SQL_SERVER_TRANSACTION", + Self::FlightSqlServerCancel => "FLIGHT_SQL_SERVER_CANCEL", + Self::FlightSqlServerBulkIngestion => "FLIGHT_SQL_SERVER_BULK_INGESTION", + Self::FlightSqlServerIngestTransactionsSupported => { + "FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED" + } + Self::FlightSqlServerStatementTimeout => { + "FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT" + } + Self::FlightSqlServerTransactionTimeout => { + "FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT" + } + Self::SqlDdlCatalog => "SQL_DDL_CATALOG", + Self::SqlDdlSchema => "SQL_DDL_SCHEMA", + Self::SqlDdlTable => "SQL_DDL_TABLE", + Self::SqlIdentifierCase => "SQL_IDENTIFIER_CASE", + Self::SqlIdentifierQuoteChar => "SQL_IDENTIFIER_QUOTE_CHAR", + Self::SqlQuotedIdentifierCase => "SQL_QUOTED_IDENTIFIER_CASE", + Self::SqlAllTablesAreSelectable => "SQL_ALL_TABLES_ARE_SELECTABLE", + Self::SqlNullOrdering => "SQL_NULL_ORDERING", + Self::SqlKeywords => "SQL_KEYWORDS", + Self::SqlNumericFunctions => "SQL_NUMERIC_FUNCTIONS", + Self::SqlStringFunctions => "SQL_STRING_FUNCTIONS", + Self::SqlSystemFunctions => "SQL_SYSTEM_FUNCTIONS", + Self::SqlDatetimeFunctions => "SQL_DATETIME_FUNCTIONS", + Self::SqlSearchStringEscape => "SQL_SEARCH_STRING_ESCAPE", + Self::SqlExtraNameCharacters => "SQL_EXTRA_NAME_CHARACTERS", + 
Self::SqlSupportsColumnAliasing => "SQL_SUPPORTS_COLUMN_ALIASING", + Self::SqlNullPlusNullIsNull => "SQL_NULL_PLUS_NULL_IS_NULL", + Self::SqlSupportsConvert => "SQL_SUPPORTS_CONVERT", + Self::SqlSupportsTableCorrelationNames => { + "SQL_SUPPORTS_TABLE_CORRELATION_NAMES" + } + Self::SqlSupportsDifferentTableCorrelationNames => { + "SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES" + } + Self::SqlSupportsExpressionsInOrderBy => { + "SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY" + } + Self::SqlSupportsOrderByUnrelated => "SQL_SUPPORTS_ORDER_BY_UNRELATED", + Self::SqlSupportedGroupBy => "SQL_SUPPORTED_GROUP_BY", + Self::SqlSupportsLikeEscapeClause => "SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE", + Self::SqlSupportsNonNullableColumns => "SQL_SUPPORTS_NON_NULLABLE_COLUMNS", + Self::SqlSupportedGrammar => "SQL_SUPPORTED_GRAMMAR", + Self::SqlAnsi92SupportedLevel => "SQL_ANSI92_SUPPORTED_LEVEL", + Self::SqlSupportsIntegrityEnhancementFacility => { + "SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY" + } + Self::SqlOuterJoinsSupportLevel => "SQL_OUTER_JOINS_SUPPORT_LEVEL", + Self::SqlSchemaTerm => "SQL_SCHEMA_TERM", + Self::SqlProcedureTerm => "SQL_PROCEDURE_TERM", + Self::SqlCatalogTerm => "SQL_CATALOG_TERM", + Self::SqlCatalogAtStart => "SQL_CATALOG_AT_START", + Self::SqlSchemasSupportedActions => "SQL_SCHEMAS_SUPPORTED_ACTIONS", + Self::SqlCatalogsSupportedActions => "SQL_CATALOGS_SUPPORTED_ACTIONS", + Self::SqlSupportedPositionedCommands => "SQL_SUPPORTED_POSITIONED_COMMANDS", + Self::SqlSelectForUpdateSupported => "SQL_SELECT_FOR_UPDATE_SUPPORTED", + Self::SqlStoredProceduresSupported => "SQL_STORED_PROCEDURES_SUPPORTED", + Self::SqlSupportedSubqueries => "SQL_SUPPORTED_SUBQUERIES", + Self::SqlCorrelatedSubqueriesSupported => { + "SQL_CORRELATED_SUBQUERIES_SUPPORTED" + } + Self::SqlSupportedUnions => "SQL_SUPPORTED_UNIONS", + Self::SqlMaxBinaryLiteralLength => "SQL_MAX_BINARY_LITERAL_LENGTH", + Self::SqlMaxCharLiteralLength => "SQL_MAX_CHAR_LITERAL_LENGTH", + Self::SqlMaxColumnNameLength => 
"SQL_MAX_COLUMN_NAME_LENGTH", + Self::SqlMaxColumnsInGroupBy => "SQL_MAX_COLUMNS_IN_GROUP_BY", + Self::SqlMaxColumnsInIndex => "SQL_MAX_COLUMNS_IN_INDEX", + Self::SqlMaxColumnsInOrderBy => "SQL_MAX_COLUMNS_IN_ORDER_BY", + Self::SqlMaxColumnsInSelect => "SQL_MAX_COLUMNS_IN_SELECT", + Self::SqlMaxColumnsInTable => "SQL_MAX_COLUMNS_IN_TABLE", + Self::SqlMaxConnections => "SQL_MAX_CONNECTIONS", + Self::SqlMaxCursorNameLength => "SQL_MAX_CURSOR_NAME_LENGTH", + Self::SqlMaxIndexLength => "SQL_MAX_INDEX_LENGTH", + Self::SqlDbSchemaNameLength => "SQL_DB_SCHEMA_NAME_LENGTH", + Self::SqlMaxProcedureNameLength => "SQL_MAX_PROCEDURE_NAME_LENGTH", + Self::SqlMaxCatalogNameLength => "SQL_MAX_CATALOG_NAME_LENGTH", + Self::SqlMaxRowSize => "SQL_MAX_ROW_SIZE", + Self::SqlMaxRowSizeIncludesBlobs => "SQL_MAX_ROW_SIZE_INCLUDES_BLOBS", + Self::SqlMaxStatementLength => "SQL_MAX_STATEMENT_LENGTH", + Self::SqlMaxStatements => "SQL_MAX_STATEMENTS", + Self::SqlMaxTableNameLength => "SQL_MAX_TABLE_NAME_LENGTH", + Self::SqlMaxTablesInSelect => "SQL_MAX_TABLES_IN_SELECT", + Self::SqlMaxUsernameLength => "SQL_MAX_USERNAME_LENGTH", + Self::SqlDefaultTransactionIsolation => "SQL_DEFAULT_TRANSACTION_ISOLATION", + Self::SqlTransactionsSupported => "SQL_TRANSACTIONS_SUPPORTED", + Self::SqlSupportedTransactionsIsolationLevels => { + "SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS" + } + Self::SqlDataDefinitionCausesTransactionCommit => { + "SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT" + } + Self::SqlDataDefinitionsInTransactionsIgnored => { + "SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED" + } + Self::SqlSupportedResultSetTypes => "SQL_SUPPORTED_RESULT_SET_TYPES", + Self::SqlSupportedConcurrenciesForResultSetUnspecified => { + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED" + } + Self::SqlSupportedConcurrenciesForResultSetForwardOnly => { + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY" + } + Self::SqlSupportedConcurrenciesForResultSetScrollSensitive => { + 
"SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE" + } + Self::SqlSupportedConcurrenciesForResultSetScrollInsensitive => { + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE" + } + Self::SqlBatchUpdatesSupported => "SQL_BATCH_UPDATES_SUPPORTED", + Self::SqlSavepointsSupported => "SQL_SAVEPOINTS_SUPPORTED", + Self::SqlNamedParametersSupported => "SQL_NAMED_PARAMETERS_SUPPORTED", + Self::SqlLocatorsUpdateCopy => "SQL_LOCATORS_UPDATE_COPY", + Self::SqlStoredFunctionsUsingCallSyntaxSupported => { + "SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FLIGHT_SQL_SERVER_NAME" => Some(Self::FlightSqlServerName), + "FLIGHT_SQL_SERVER_VERSION" => Some(Self::FlightSqlServerVersion), + "FLIGHT_SQL_SERVER_ARROW_VERSION" => Some(Self::FlightSqlServerArrowVersion), + "FLIGHT_SQL_SERVER_READ_ONLY" => Some(Self::FlightSqlServerReadOnly), + "FLIGHT_SQL_SERVER_SQL" => Some(Self::FlightSqlServerSql), + "FLIGHT_SQL_SERVER_SUBSTRAIT" => Some(Self::FlightSqlServerSubstrait), + "FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION" => { + Some(Self::FlightSqlServerSubstraitMinVersion) + } + "FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION" => { + Some(Self::FlightSqlServerSubstraitMaxVersion) + } + "FLIGHT_SQL_SERVER_TRANSACTION" => Some(Self::FlightSqlServerTransaction), + "FLIGHT_SQL_SERVER_CANCEL" => Some(Self::FlightSqlServerCancel), + "FLIGHT_SQL_SERVER_BULK_INGESTION" => { + Some(Self::FlightSqlServerBulkIngestion) + } + "FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED" => { + Some(Self::FlightSqlServerIngestTransactionsSupported) + } + "FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT" => { + Some(Self::FlightSqlServerStatementTimeout) + } + "FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT" => { + Some(Self::FlightSqlServerTransactionTimeout) + } + "SQL_DDL_CATALOG" => Some(Self::SqlDdlCatalog), + "SQL_DDL_SCHEMA" => 
Some(Self::SqlDdlSchema), + "SQL_DDL_TABLE" => Some(Self::SqlDdlTable), + "SQL_IDENTIFIER_CASE" => Some(Self::SqlIdentifierCase), + "SQL_IDENTIFIER_QUOTE_CHAR" => Some(Self::SqlIdentifierQuoteChar), + "SQL_QUOTED_IDENTIFIER_CASE" => Some(Self::SqlQuotedIdentifierCase), + "SQL_ALL_TABLES_ARE_SELECTABLE" => Some(Self::SqlAllTablesAreSelectable), + "SQL_NULL_ORDERING" => Some(Self::SqlNullOrdering), + "SQL_KEYWORDS" => Some(Self::SqlKeywords), + "SQL_NUMERIC_FUNCTIONS" => Some(Self::SqlNumericFunctions), + "SQL_STRING_FUNCTIONS" => Some(Self::SqlStringFunctions), + "SQL_SYSTEM_FUNCTIONS" => Some(Self::SqlSystemFunctions), + "SQL_DATETIME_FUNCTIONS" => Some(Self::SqlDatetimeFunctions), + "SQL_SEARCH_STRING_ESCAPE" => Some(Self::SqlSearchStringEscape), + "SQL_EXTRA_NAME_CHARACTERS" => Some(Self::SqlExtraNameCharacters), + "SQL_SUPPORTS_COLUMN_ALIASING" => Some(Self::SqlSupportsColumnAliasing), + "SQL_NULL_PLUS_NULL_IS_NULL" => Some(Self::SqlNullPlusNullIsNull), + "SQL_SUPPORTS_CONVERT" => Some(Self::SqlSupportsConvert), + "SQL_SUPPORTS_TABLE_CORRELATION_NAMES" => { + Some(Self::SqlSupportsTableCorrelationNames) + } + "SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES" => { + Some(Self::SqlSupportsDifferentTableCorrelationNames) + } + "SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY" => { + Some(Self::SqlSupportsExpressionsInOrderBy) + } + "SQL_SUPPORTS_ORDER_BY_UNRELATED" => Some(Self::SqlSupportsOrderByUnrelated), + "SQL_SUPPORTED_GROUP_BY" => Some(Self::SqlSupportedGroupBy), + "SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE" => Some(Self::SqlSupportsLikeEscapeClause), + "SQL_SUPPORTS_NON_NULLABLE_COLUMNS" => { + Some(Self::SqlSupportsNonNullableColumns) + } + "SQL_SUPPORTED_GRAMMAR" => Some(Self::SqlSupportedGrammar), + "SQL_ANSI92_SUPPORTED_LEVEL" => Some(Self::SqlAnsi92SupportedLevel), + "SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY" => { + Some(Self::SqlSupportsIntegrityEnhancementFacility) + } + "SQL_OUTER_JOINS_SUPPORT_LEVEL" => Some(Self::SqlOuterJoinsSupportLevel), + "SQL_SCHEMA_TERM" => 
Some(Self::SqlSchemaTerm), + "SQL_PROCEDURE_TERM" => Some(Self::SqlProcedureTerm), + "SQL_CATALOG_TERM" => Some(Self::SqlCatalogTerm), + "SQL_CATALOG_AT_START" => Some(Self::SqlCatalogAtStart), + "SQL_SCHEMAS_SUPPORTED_ACTIONS" => Some(Self::SqlSchemasSupportedActions), + "SQL_CATALOGS_SUPPORTED_ACTIONS" => Some(Self::SqlCatalogsSupportedActions), + "SQL_SUPPORTED_POSITIONED_COMMANDS" => { + Some(Self::SqlSupportedPositionedCommands) + } + "SQL_SELECT_FOR_UPDATE_SUPPORTED" => Some(Self::SqlSelectForUpdateSupported), + "SQL_STORED_PROCEDURES_SUPPORTED" => Some(Self::SqlStoredProceduresSupported), + "SQL_SUPPORTED_SUBQUERIES" => Some(Self::SqlSupportedSubqueries), + "SQL_CORRELATED_SUBQUERIES_SUPPORTED" => { + Some(Self::SqlCorrelatedSubqueriesSupported) + } + "SQL_SUPPORTED_UNIONS" => Some(Self::SqlSupportedUnions), + "SQL_MAX_BINARY_LITERAL_LENGTH" => Some(Self::SqlMaxBinaryLiteralLength), + "SQL_MAX_CHAR_LITERAL_LENGTH" => Some(Self::SqlMaxCharLiteralLength), + "SQL_MAX_COLUMN_NAME_LENGTH" => Some(Self::SqlMaxColumnNameLength), + "SQL_MAX_COLUMNS_IN_GROUP_BY" => Some(Self::SqlMaxColumnsInGroupBy), + "SQL_MAX_COLUMNS_IN_INDEX" => Some(Self::SqlMaxColumnsInIndex), + "SQL_MAX_COLUMNS_IN_ORDER_BY" => Some(Self::SqlMaxColumnsInOrderBy), + "SQL_MAX_COLUMNS_IN_SELECT" => Some(Self::SqlMaxColumnsInSelect), + "SQL_MAX_COLUMNS_IN_TABLE" => Some(Self::SqlMaxColumnsInTable), + "SQL_MAX_CONNECTIONS" => Some(Self::SqlMaxConnections), + "SQL_MAX_CURSOR_NAME_LENGTH" => Some(Self::SqlMaxCursorNameLength), + "SQL_MAX_INDEX_LENGTH" => Some(Self::SqlMaxIndexLength), + "SQL_DB_SCHEMA_NAME_LENGTH" => Some(Self::SqlDbSchemaNameLength), + "SQL_MAX_PROCEDURE_NAME_LENGTH" => Some(Self::SqlMaxProcedureNameLength), + "SQL_MAX_CATALOG_NAME_LENGTH" => Some(Self::SqlMaxCatalogNameLength), + "SQL_MAX_ROW_SIZE" => Some(Self::SqlMaxRowSize), + "SQL_MAX_ROW_SIZE_INCLUDES_BLOBS" => Some(Self::SqlMaxRowSizeIncludesBlobs), + "SQL_MAX_STATEMENT_LENGTH" => Some(Self::SqlMaxStatementLength), + 
"SQL_MAX_STATEMENTS" => Some(Self::SqlMaxStatements), + "SQL_MAX_TABLE_NAME_LENGTH" => Some(Self::SqlMaxTableNameLength), + "SQL_MAX_TABLES_IN_SELECT" => Some(Self::SqlMaxTablesInSelect), + "SQL_MAX_USERNAME_LENGTH" => Some(Self::SqlMaxUsernameLength), + "SQL_DEFAULT_TRANSACTION_ISOLATION" => { + Some(Self::SqlDefaultTransactionIsolation) + } + "SQL_TRANSACTIONS_SUPPORTED" => Some(Self::SqlTransactionsSupported), + "SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS" => { + Some(Self::SqlSupportedTransactionsIsolationLevels) + } + "SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT" => { + Some(Self::SqlDataDefinitionCausesTransactionCommit) + } + "SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED" => { + Some(Self::SqlDataDefinitionsInTransactionsIgnored) + } + "SQL_SUPPORTED_RESULT_SET_TYPES" => Some(Self::SqlSupportedResultSetTypes), + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED" => { + Some(Self::SqlSupportedConcurrenciesForResultSetUnspecified) + } + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY" => { + Some(Self::SqlSupportedConcurrenciesForResultSetForwardOnly) + } + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE" => { + Some(Self::SqlSupportedConcurrenciesForResultSetScrollSensitive) + } + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE" => { + Some(Self::SqlSupportedConcurrenciesForResultSetScrollInsensitive) + } + "SQL_BATCH_UPDATES_SUPPORTED" => Some(Self::SqlBatchUpdatesSupported), + "SQL_SAVEPOINTS_SUPPORTED" => Some(Self::SqlSavepointsSupported), + "SQL_NAMED_PARAMETERS_SUPPORTED" => Some(Self::SqlNamedParametersSupported), + "SQL_LOCATORS_UPDATE_COPY" => Some(Self::SqlLocatorsUpdateCopy), + "SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED" => { + Some(Self::SqlStoredFunctionsUsingCallSyntaxSupported) + } + _ => None, + } + } +} +/// The level of support for Flight SQL transaction RPCs. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedTransaction { + /// Unknown/not indicated/no support + None = 0, + /// Transactions, but not savepoints. + /// A savepoint is a mark within a transaction that can be individually + /// rolled back to. Not all databases support savepoints. + Transaction = 1, + /// Transactions and savepoints + Savepoint = 2, +} +impl SqlSupportedTransaction { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::None => "SQL_SUPPORTED_TRANSACTION_NONE", + Self::Transaction => "SQL_SUPPORTED_TRANSACTION_TRANSACTION", + Self::Savepoint => "SQL_SUPPORTED_TRANSACTION_SAVEPOINT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_SUPPORTED_TRANSACTION_NONE" => Some(Self::None), + "SQL_SUPPORTED_TRANSACTION_TRANSACTION" => Some(Self::Transaction), + "SQL_SUPPORTED_TRANSACTION_SAVEPOINT" => Some(Self::Savepoint), + _ => None, } } } @@ -1109,10 +1975,24 @@ impl SqlSupportedCaseSensitivity { /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedCaseSensitivity::SqlCaseSensitivityUnknown => "SQL_CASE_SENSITIVITY_UNKNOWN", - SqlSupportedCaseSensitivity::SqlCaseSensitivityCaseInsensitive => "SQL_CASE_SENSITIVITY_CASE_INSENSITIVE", - SqlSupportedCaseSensitivity::SqlCaseSensitivityUppercase => "SQL_CASE_SENSITIVITY_UPPERCASE", - SqlSupportedCaseSensitivity::SqlCaseSensitivityLowercase => "SQL_CASE_SENSITIVITY_LOWERCASE", + Self::SqlCaseSensitivityUnknown => "SQL_CASE_SENSITIVITY_UNKNOWN", + Self::SqlCaseSensitivityCaseInsensitive => { + "SQL_CASE_SENSITIVITY_CASE_INSENSITIVE" + } + Self::SqlCaseSensitivityUppercase => "SQL_CASE_SENSITIVITY_UPPERCASE", + Self::SqlCaseSensitivityLowercase => "SQL_CASE_SENSITIVITY_LOWERCASE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_CASE_SENSITIVITY_UNKNOWN" => Some(Self::SqlCaseSensitivityUnknown), + "SQL_CASE_SENSITIVITY_CASE_INSENSITIVE" => { + Some(Self::SqlCaseSensitivityCaseInsensitive) + } + "SQL_CASE_SENSITIVITY_UPPERCASE" => Some(Self::SqlCaseSensitivityUppercase), + "SQL_CASE_SENSITIVITY_LOWERCASE" => Some(Self::SqlCaseSensitivityLowercase), + _ => None, } } } @@ -1131,10 +2011,20 @@ impl SqlNullOrdering { /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
pub fn as_str_name(&self) -> &'static str { match self { - SqlNullOrdering::SqlNullsSortedHigh => "SQL_NULLS_SORTED_HIGH", - SqlNullOrdering::SqlNullsSortedLow => "SQL_NULLS_SORTED_LOW", - SqlNullOrdering::SqlNullsSortedAtStart => "SQL_NULLS_SORTED_AT_START", - SqlNullOrdering::SqlNullsSortedAtEnd => "SQL_NULLS_SORTED_AT_END", + Self::SqlNullsSortedHigh => "SQL_NULLS_SORTED_HIGH", + Self::SqlNullsSortedLow => "SQL_NULLS_SORTED_LOW", + Self::SqlNullsSortedAtStart => "SQL_NULLS_SORTED_AT_START", + Self::SqlNullsSortedAtEnd => "SQL_NULLS_SORTED_AT_END", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_NULLS_SORTED_HIGH" => Some(Self::SqlNullsSortedHigh), + "SQL_NULLS_SORTED_LOW" => Some(Self::SqlNullsSortedLow), + "SQL_NULLS_SORTED_AT_START" => Some(Self::SqlNullsSortedAtStart), + "SQL_NULLS_SORTED_AT_END" => Some(Self::SqlNullsSortedAtEnd), + _ => None, } } } @@ -1152,9 +2042,18 @@ impl SupportedSqlGrammar { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SupportedSqlGrammar::SqlMinimumGrammar => "SQL_MINIMUM_GRAMMAR", - SupportedSqlGrammar::SqlCoreGrammar => "SQL_CORE_GRAMMAR", - SupportedSqlGrammar::SqlExtendedGrammar => "SQL_EXTENDED_GRAMMAR", + Self::SqlMinimumGrammar => "SQL_MINIMUM_GRAMMAR", + Self::SqlCoreGrammar => "SQL_CORE_GRAMMAR", + Self::SqlExtendedGrammar => "SQL_EXTENDED_GRAMMAR", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_MINIMUM_GRAMMAR" => Some(Self::SqlMinimumGrammar), + "SQL_CORE_GRAMMAR" => Some(Self::SqlCoreGrammar), + "SQL_EXTENDED_GRAMMAR" => Some(Self::SqlExtendedGrammar), + _ => None, } } } @@ -1172,9 +2071,18 @@ impl SupportedAnsi92SqlGrammarLevel { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SupportedAnsi92SqlGrammarLevel::Ansi92EntrySql => "ANSI92_ENTRY_SQL", - SupportedAnsi92SqlGrammarLevel::Ansi92IntermediateSql => "ANSI92_INTERMEDIATE_SQL", - SupportedAnsi92SqlGrammarLevel::Ansi92FullSql => "ANSI92_FULL_SQL", + Self::Ansi92EntrySql => "ANSI92_ENTRY_SQL", + Self::Ansi92IntermediateSql => "ANSI92_INTERMEDIATE_SQL", + Self::Ansi92FullSql => "ANSI92_FULL_SQL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "ANSI92_ENTRY_SQL" => Some(Self::Ansi92EntrySql), + "ANSI92_INTERMEDIATE_SQL" => Some(Self::Ansi92IntermediateSql), + "ANSI92_FULL_SQL" => Some(Self::Ansi92FullSql), + _ => None, } } } @@ -1192,9 +2100,18 @@ impl SqlOuterJoinsSupportLevel { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlOuterJoinsSupportLevel::SqlJoinsUnsupported => "SQL_JOINS_UNSUPPORTED", - SqlOuterJoinsSupportLevel::SqlLimitedOuterJoins => "SQL_LIMITED_OUTER_JOINS", - SqlOuterJoinsSupportLevel::SqlFullOuterJoins => "SQL_FULL_OUTER_JOINS", + Self::SqlJoinsUnsupported => "SQL_JOINS_UNSUPPORTED", + Self::SqlLimitedOuterJoins => "SQL_LIMITED_OUTER_JOINS", + Self::SqlFullOuterJoins => "SQL_FULL_OUTER_JOINS", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_JOINS_UNSUPPORTED" => Some(Self::SqlJoinsUnsupported), + "SQL_LIMITED_OUTER_JOINS" => Some(Self::SqlLimitedOuterJoins), + "SQL_FULL_OUTER_JOINS" => Some(Self::SqlFullOuterJoins), + _ => None, } } } @@ -1211,8 +2128,16 @@ impl SqlSupportedGroupBy { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedGroupBy::SqlGroupByUnrelated => "SQL_GROUP_BY_UNRELATED", - SqlSupportedGroupBy::SqlGroupByBeyondSelect => "SQL_GROUP_BY_BEYOND_SELECT", + Self::SqlGroupByUnrelated => "SQL_GROUP_BY_UNRELATED", + Self::SqlGroupByBeyondSelect => "SQL_GROUP_BY_BEYOND_SELECT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_GROUP_BY_UNRELATED" => Some(Self::SqlGroupByUnrelated), + "SQL_GROUP_BY_BEYOND_SELECT" => Some(Self::SqlGroupByBeyondSelect), + _ => None, } } } @@ -1230,9 +2155,24 @@ impl SqlSupportedElementActions { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedElementActions::SqlElementInProcedureCalls => "SQL_ELEMENT_IN_PROCEDURE_CALLS", - SqlSupportedElementActions::SqlElementInIndexDefinitions => "SQL_ELEMENT_IN_INDEX_DEFINITIONS", - SqlSupportedElementActions::SqlElementInPrivilegeDefinitions => "SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS", + Self::SqlElementInProcedureCalls => "SQL_ELEMENT_IN_PROCEDURE_CALLS", + Self::SqlElementInIndexDefinitions => "SQL_ELEMENT_IN_INDEX_DEFINITIONS", + Self::SqlElementInPrivilegeDefinitions => { + "SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_ELEMENT_IN_PROCEDURE_CALLS" => Some(Self::SqlElementInProcedureCalls), + "SQL_ELEMENT_IN_INDEX_DEFINITIONS" => { + Some(Self::SqlElementInIndexDefinitions) + } + "SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS" => { + Some(Self::SqlElementInPrivilegeDefinitions) + } + _ => None, } } } @@ -1249,8 +2189,16 @@ impl SqlSupportedPositionedCommands { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedPositionedCommands::SqlPositionedDelete => "SQL_POSITIONED_DELETE", - SqlSupportedPositionedCommands::SqlPositionedUpdate => "SQL_POSITIONED_UPDATE", + Self::SqlPositionedDelete => "SQL_POSITIONED_DELETE", + Self::SqlPositionedUpdate => "SQL_POSITIONED_UPDATE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_POSITIONED_DELETE" => Some(Self::SqlPositionedDelete), + "SQL_POSITIONED_UPDATE" => Some(Self::SqlPositionedUpdate), + _ => None, } } } @@ -1269,10 +2217,20 @@ impl SqlSupportedSubqueries { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedSubqueries::SqlSubqueriesInComparisons => "SQL_SUBQUERIES_IN_COMPARISONS", - SqlSupportedSubqueries::SqlSubqueriesInExists => "SQL_SUBQUERIES_IN_EXISTS", - SqlSupportedSubqueries::SqlSubqueriesInIns => "SQL_SUBQUERIES_IN_INS", - SqlSupportedSubqueries::SqlSubqueriesInQuantifieds => "SQL_SUBQUERIES_IN_QUANTIFIEDS", + Self::SqlSubqueriesInComparisons => "SQL_SUBQUERIES_IN_COMPARISONS", + Self::SqlSubqueriesInExists => "SQL_SUBQUERIES_IN_EXISTS", + Self::SqlSubqueriesInIns => "SQL_SUBQUERIES_IN_INS", + Self::SqlSubqueriesInQuantifieds => "SQL_SUBQUERIES_IN_QUANTIFIEDS", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_SUBQUERIES_IN_COMPARISONS" => Some(Self::SqlSubqueriesInComparisons), + "SQL_SUBQUERIES_IN_EXISTS" => Some(Self::SqlSubqueriesInExists), + "SQL_SUBQUERIES_IN_INS" => Some(Self::SqlSubqueriesInIns), + "SQL_SUBQUERIES_IN_QUANTIFIEDS" => Some(Self::SqlSubqueriesInQuantifieds), + _ => None, } } } @@ -1289,8 +2247,16 @@ impl SqlSupportedUnions { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedUnions::SqlUnion => "SQL_UNION", - SqlSupportedUnions::SqlUnionAll => "SQL_UNION_ALL", + Self::SqlUnion => "SQL_UNION", + Self::SqlUnionAll => "SQL_UNION_ALL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_UNION" => Some(Self::SqlUnion), + "SQL_UNION_ALL" => Some(Self::SqlUnionAll), + _ => None, } } } @@ -1310,11 +2276,24 @@ impl SqlTransactionIsolationLevel { /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
pub fn as_str_name(&self) -> &'static str { match self { - SqlTransactionIsolationLevel::SqlTransactionNone => "SQL_TRANSACTION_NONE", - SqlTransactionIsolationLevel::SqlTransactionReadUncommitted => "SQL_TRANSACTION_READ_UNCOMMITTED", - SqlTransactionIsolationLevel::SqlTransactionReadCommitted => "SQL_TRANSACTION_READ_COMMITTED", - SqlTransactionIsolationLevel::SqlTransactionRepeatableRead => "SQL_TRANSACTION_REPEATABLE_READ", - SqlTransactionIsolationLevel::SqlTransactionSerializable => "SQL_TRANSACTION_SERIALIZABLE", + Self::SqlTransactionNone => "SQL_TRANSACTION_NONE", + Self::SqlTransactionReadUncommitted => "SQL_TRANSACTION_READ_UNCOMMITTED", + Self::SqlTransactionReadCommitted => "SQL_TRANSACTION_READ_COMMITTED", + Self::SqlTransactionRepeatableRead => "SQL_TRANSACTION_REPEATABLE_READ", + Self::SqlTransactionSerializable => "SQL_TRANSACTION_SERIALIZABLE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_TRANSACTION_NONE" => Some(Self::SqlTransactionNone), + "SQL_TRANSACTION_READ_UNCOMMITTED" => { + Some(Self::SqlTransactionReadUncommitted) + } + "SQL_TRANSACTION_READ_COMMITTED" => Some(Self::SqlTransactionReadCommitted), + "SQL_TRANSACTION_REPEATABLE_READ" => Some(Self::SqlTransactionRepeatableRead), + "SQL_TRANSACTION_SERIALIZABLE" => Some(Self::SqlTransactionSerializable), + _ => None, } } } @@ -1332,9 +2311,22 @@ impl SqlSupportedTransactions { /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedTransactions::SqlTransactionUnspecified => "SQL_TRANSACTION_UNSPECIFIED", - SqlSupportedTransactions::SqlDataDefinitionTransactions => "SQL_DATA_DEFINITION_TRANSACTIONS", - SqlSupportedTransactions::SqlDataManipulationTransactions => "SQL_DATA_MANIPULATION_TRANSACTIONS", + Self::SqlTransactionUnspecified => "SQL_TRANSACTION_UNSPECIFIED", + Self::SqlDataDefinitionTransactions => "SQL_DATA_DEFINITION_TRANSACTIONS", + Self::SqlDataManipulationTransactions => "SQL_DATA_MANIPULATION_TRANSACTIONS", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_TRANSACTION_UNSPECIFIED" => Some(Self::SqlTransactionUnspecified), + "SQL_DATA_DEFINITION_TRANSACTIONS" => { + Some(Self::SqlDataDefinitionTransactions) + } + "SQL_DATA_MANIPULATION_TRANSACTIONS" => { + Some(Self::SqlDataManipulationTransactions) + } + _ => None, } } } @@ -1353,10 +2345,28 @@ impl SqlSupportedResultSetType { /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedResultSetType::SqlResultSetTypeUnspecified => "SQL_RESULT_SET_TYPE_UNSPECIFIED", - SqlSupportedResultSetType::SqlResultSetTypeForwardOnly => "SQL_RESULT_SET_TYPE_FORWARD_ONLY", - SqlSupportedResultSetType::SqlResultSetTypeScrollInsensitive => "SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE", - SqlSupportedResultSetType::SqlResultSetTypeScrollSensitive => "SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE", + Self::SqlResultSetTypeUnspecified => "SQL_RESULT_SET_TYPE_UNSPECIFIED", + Self::SqlResultSetTypeForwardOnly => "SQL_RESULT_SET_TYPE_FORWARD_ONLY", + Self::SqlResultSetTypeScrollInsensitive => { + "SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE" + } + Self::SqlResultSetTypeScrollSensitive => { + "SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_RESULT_SET_TYPE_UNSPECIFIED" => Some(Self::SqlResultSetTypeUnspecified), + "SQL_RESULT_SET_TYPE_FORWARD_ONLY" => Some(Self::SqlResultSetTypeForwardOnly), + "SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE" => { + Some(Self::SqlResultSetTypeScrollInsensitive) + } + "SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE" => { + Some(Self::SqlResultSetTypeScrollSensitive) + } + _ => None, } } } @@ -1374,9 +2384,30 @@ impl SqlSupportedResultSetConcurrency { /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUnspecified => "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED", - SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyReadOnly => "SQL_RESULT_SET_CONCURRENCY_READ_ONLY", - SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUpdatable => "SQL_RESULT_SET_CONCURRENCY_UPDATABLE", + Self::SqlResultSetConcurrencyUnspecified => { + "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED" + } + Self::SqlResultSetConcurrencyReadOnly => { + "SQL_RESULT_SET_CONCURRENCY_READ_ONLY" + } + Self::SqlResultSetConcurrencyUpdatable => { + "SQL_RESULT_SET_CONCURRENCY_UPDATABLE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED" => { + Some(Self::SqlResultSetConcurrencyUnspecified) + } + "SQL_RESULT_SET_CONCURRENCY_READ_ONLY" => { + Some(Self::SqlResultSetConcurrencyReadOnly) + } + "SQL_RESULT_SET_CONCURRENCY_UPDATABLE" => { + Some(Self::SqlResultSetConcurrencyUpdatable) + } + _ => None, } } } @@ -1411,26 +2442,354 @@ impl SqlSupportsConvert { /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportsConvert::SqlConvertBigint => "SQL_CONVERT_BIGINT", - SqlSupportsConvert::SqlConvertBinary => "SQL_CONVERT_BINARY", - SqlSupportsConvert::SqlConvertBit => "SQL_CONVERT_BIT", - SqlSupportsConvert::SqlConvertChar => "SQL_CONVERT_CHAR", - SqlSupportsConvert::SqlConvertDate => "SQL_CONVERT_DATE", - SqlSupportsConvert::SqlConvertDecimal => "SQL_CONVERT_DECIMAL", - SqlSupportsConvert::SqlConvertFloat => "SQL_CONVERT_FLOAT", - SqlSupportsConvert::SqlConvertInteger => "SQL_CONVERT_INTEGER", - SqlSupportsConvert::SqlConvertIntervalDayTime => "SQL_CONVERT_INTERVAL_DAY_TIME", - SqlSupportsConvert::SqlConvertIntervalYearMonth => "SQL_CONVERT_INTERVAL_YEAR_MONTH", - SqlSupportsConvert::SqlConvertLongvarbinary => "SQL_CONVERT_LONGVARBINARY", - SqlSupportsConvert::SqlConvertLongvarchar => "SQL_CONVERT_LONGVARCHAR", - SqlSupportsConvert::SqlConvertNumeric => "SQL_CONVERT_NUMERIC", - SqlSupportsConvert::SqlConvertReal => "SQL_CONVERT_REAL", - SqlSupportsConvert::SqlConvertSmallint => "SQL_CONVERT_SMALLINT", - SqlSupportsConvert::SqlConvertTime => "SQL_CONVERT_TIME", - SqlSupportsConvert::SqlConvertTimestamp => "SQL_CONVERT_TIMESTAMP", - SqlSupportsConvert::SqlConvertTinyint => "SQL_CONVERT_TINYINT", - SqlSupportsConvert::SqlConvertVarbinary => "SQL_CONVERT_VARBINARY", - SqlSupportsConvert::SqlConvertVarchar => "SQL_CONVERT_VARCHAR", + Self::SqlConvertBigint => "SQL_CONVERT_BIGINT", + Self::SqlConvertBinary => "SQL_CONVERT_BINARY", + Self::SqlConvertBit => "SQL_CONVERT_BIT", + Self::SqlConvertChar => "SQL_CONVERT_CHAR", + Self::SqlConvertDate => "SQL_CONVERT_DATE", + Self::SqlConvertDecimal => "SQL_CONVERT_DECIMAL", + Self::SqlConvertFloat => "SQL_CONVERT_FLOAT", + Self::SqlConvertInteger => "SQL_CONVERT_INTEGER", + Self::SqlConvertIntervalDayTime => "SQL_CONVERT_INTERVAL_DAY_TIME", + Self::SqlConvertIntervalYearMonth => "SQL_CONVERT_INTERVAL_YEAR_MONTH", + Self::SqlConvertLongvarbinary => 
"SQL_CONVERT_LONGVARBINARY", + Self::SqlConvertLongvarchar => "SQL_CONVERT_LONGVARCHAR", + Self::SqlConvertNumeric => "SQL_CONVERT_NUMERIC", + Self::SqlConvertReal => "SQL_CONVERT_REAL", + Self::SqlConvertSmallint => "SQL_CONVERT_SMALLINT", + Self::SqlConvertTime => "SQL_CONVERT_TIME", + Self::SqlConvertTimestamp => "SQL_CONVERT_TIMESTAMP", + Self::SqlConvertTinyint => "SQL_CONVERT_TINYINT", + Self::SqlConvertVarbinary => "SQL_CONVERT_VARBINARY", + Self::SqlConvertVarchar => "SQL_CONVERT_VARCHAR", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_CONVERT_BIGINT" => Some(Self::SqlConvertBigint), + "SQL_CONVERT_BINARY" => Some(Self::SqlConvertBinary), + "SQL_CONVERT_BIT" => Some(Self::SqlConvertBit), + "SQL_CONVERT_CHAR" => Some(Self::SqlConvertChar), + "SQL_CONVERT_DATE" => Some(Self::SqlConvertDate), + "SQL_CONVERT_DECIMAL" => Some(Self::SqlConvertDecimal), + "SQL_CONVERT_FLOAT" => Some(Self::SqlConvertFloat), + "SQL_CONVERT_INTEGER" => Some(Self::SqlConvertInteger), + "SQL_CONVERT_INTERVAL_DAY_TIME" => Some(Self::SqlConvertIntervalDayTime), + "SQL_CONVERT_INTERVAL_YEAR_MONTH" => Some(Self::SqlConvertIntervalYearMonth), + "SQL_CONVERT_LONGVARBINARY" => Some(Self::SqlConvertLongvarbinary), + "SQL_CONVERT_LONGVARCHAR" => Some(Self::SqlConvertLongvarchar), + "SQL_CONVERT_NUMERIC" => Some(Self::SqlConvertNumeric), + "SQL_CONVERT_REAL" => Some(Self::SqlConvertReal), + "SQL_CONVERT_SMALLINT" => Some(Self::SqlConvertSmallint), + "SQL_CONVERT_TIME" => Some(Self::SqlConvertTime), + "SQL_CONVERT_TIMESTAMP" => Some(Self::SqlConvertTimestamp), + "SQL_CONVERT_TINYINT" => Some(Self::SqlConvertTinyint), + "SQL_CONVERT_VARBINARY" => Some(Self::SqlConvertVarbinary), + "SQL_CONVERT_VARCHAR" => Some(Self::SqlConvertVarchar), + _ => None, + } + } +} +/// * +/// The JDBC/ODBC-defined type of any object. +/// All the values here are the same as in the JDBC and ODBC specs. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum XdbcDataType { + XdbcUnknownType = 0, + XdbcChar = 1, + XdbcNumeric = 2, + XdbcDecimal = 3, + XdbcInteger = 4, + XdbcSmallint = 5, + XdbcFloat = 6, + XdbcReal = 7, + XdbcDouble = 8, + XdbcDatetime = 9, + XdbcInterval = 10, + XdbcVarchar = 12, + XdbcDate = 91, + XdbcTime = 92, + XdbcTimestamp = 93, + XdbcLongvarchar = -1, + XdbcBinary = -2, + XdbcVarbinary = -3, + XdbcLongvarbinary = -4, + XdbcBigint = -5, + XdbcTinyint = -6, + XdbcBit = -7, + XdbcWchar = -8, + XdbcWvarchar = -9, +} +impl XdbcDataType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::XdbcUnknownType => "XDBC_UNKNOWN_TYPE", + Self::XdbcChar => "XDBC_CHAR", + Self::XdbcNumeric => "XDBC_NUMERIC", + Self::XdbcDecimal => "XDBC_DECIMAL", + Self::XdbcInteger => "XDBC_INTEGER", + Self::XdbcSmallint => "XDBC_SMALLINT", + Self::XdbcFloat => "XDBC_FLOAT", + Self::XdbcReal => "XDBC_REAL", + Self::XdbcDouble => "XDBC_DOUBLE", + Self::XdbcDatetime => "XDBC_DATETIME", + Self::XdbcInterval => "XDBC_INTERVAL", + Self::XdbcVarchar => "XDBC_VARCHAR", + Self::XdbcDate => "XDBC_DATE", + Self::XdbcTime => "XDBC_TIME", + Self::XdbcTimestamp => "XDBC_TIMESTAMP", + Self::XdbcLongvarchar => "XDBC_LONGVARCHAR", + Self::XdbcBinary => "XDBC_BINARY", + Self::XdbcVarbinary => "XDBC_VARBINARY", + Self::XdbcLongvarbinary => "XDBC_LONGVARBINARY", + Self::XdbcBigint => "XDBC_BIGINT", + Self::XdbcTinyint => "XDBC_TINYINT", + Self::XdbcBit => "XDBC_BIT", + Self::XdbcWchar => "XDBC_WCHAR", + Self::XdbcWvarchar => "XDBC_WVARCHAR", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "XDBC_UNKNOWN_TYPE" => Some(Self::XdbcUnknownType), + "XDBC_CHAR" => Some(Self::XdbcChar), + "XDBC_NUMERIC" => Some(Self::XdbcNumeric), + "XDBC_DECIMAL" => Some(Self::XdbcDecimal), + "XDBC_INTEGER" => Some(Self::XdbcInteger), + "XDBC_SMALLINT" => Some(Self::XdbcSmallint), + "XDBC_FLOAT" => Some(Self::XdbcFloat), + "XDBC_REAL" => Some(Self::XdbcReal), + "XDBC_DOUBLE" => Some(Self::XdbcDouble), + "XDBC_DATETIME" => Some(Self::XdbcDatetime), + "XDBC_INTERVAL" => Some(Self::XdbcInterval), + "XDBC_VARCHAR" => Some(Self::XdbcVarchar), + "XDBC_DATE" => Some(Self::XdbcDate), + "XDBC_TIME" => Some(Self::XdbcTime), + "XDBC_TIMESTAMP" => Some(Self::XdbcTimestamp), + "XDBC_LONGVARCHAR" => Some(Self::XdbcLongvarchar), + "XDBC_BINARY" => Some(Self::XdbcBinary), + "XDBC_VARBINARY" => Some(Self::XdbcVarbinary), + "XDBC_LONGVARBINARY" => Some(Self::XdbcLongvarbinary), + "XDBC_BIGINT" => Some(Self::XdbcBigint), + "XDBC_TINYINT" => Some(Self::XdbcTinyint), + "XDBC_BIT" => Some(Self::XdbcBit), + "XDBC_WCHAR" => Some(Self::XdbcWchar), + "XDBC_WVARCHAR" => Some(Self::XdbcWvarchar), + _ => None, + } + } +} +/// * +/// Detailed subtype information for XDBC_TYPE_DATETIME and XDBC_TYPE_INTERVAL. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum XdbcDatetimeSubcode { + XdbcSubcodeUnknown = 0, + XdbcSubcodeYear = 1, + XdbcSubcodeTime = 2, + XdbcSubcodeTimestamp = 3, + XdbcSubcodeTimeWithTimezone = 4, + XdbcSubcodeTimestampWithTimezone = 5, + XdbcSubcodeSecond = 6, + XdbcSubcodeYearToMonth = 7, + XdbcSubcodeDayToHour = 8, + XdbcSubcodeDayToMinute = 9, + XdbcSubcodeDayToSecond = 10, + XdbcSubcodeHourToMinute = 11, + XdbcSubcodeHourToSecond = 12, + XdbcSubcodeMinuteToSecond = 13, + XdbcSubcodeIntervalYear = 101, + XdbcSubcodeIntervalMonth = 102, + XdbcSubcodeIntervalDay = 103, + XdbcSubcodeIntervalHour = 104, + XdbcSubcodeIntervalMinute = 105, + XdbcSubcodeIntervalSecond = 106, + XdbcSubcodeIntervalYearToMonth = 107, + XdbcSubcodeIntervalDayToHour = 108, + XdbcSubcodeIntervalDayToMinute = 109, + XdbcSubcodeIntervalDayToSecond = 110, + XdbcSubcodeIntervalHourToMinute = 111, + XdbcSubcodeIntervalHourToSecond = 112, + XdbcSubcodeIntervalMinuteToSecond = 113, +} +impl XdbcDatetimeSubcode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
+ pub fn as_str_name(&self) -> &'static str { + match self { + Self::XdbcSubcodeUnknown => "XDBC_SUBCODE_UNKNOWN", + Self::XdbcSubcodeYear => "XDBC_SUBCODE_YEAR", + Self::XdbcSubcodeTime => "XDBC_SUBCODE_TIME", + Self::XdbcSubcodeTimestamp => "XDBC_SUBCODE_TIMESTAMP", + Self::XdbcSubcodeTimeWithTimezone => "XDBC_SUBCODE_TIME_WITH_TIMEZONE", + Self::XdbcSubcodeTimestampWithTimezone => { + "XDBC_SUBCODE_TIMESTAMP_WITH_TIMEZONE" + } + Self::XdbcSubcodeSecond => "XDBC_SUBCODE_SECOND", + Self::XdbcSubcodeYearToMonth => "XDBC_SUBCODE_YEAR_TO_MONTH", + Self::XdbcSubcodeDayToHour => "XDBC_SUBCODE_DAY_TO_HOUR", + Self::XdbcSubcodeDayToMinute => "XDBC_SUBCODE_DAY_TO_MINUTE", + Self::XdbcSubcodeDayToSecond => "XDBC_SUBCODE_DAY_TO_SECOND", + Self::XdbcSubcodeHourToMinute => "XDBC_SUBCODE_HOUR_TO_MINUTE", + Self::XdbcSubcodeHourToSecond => "XDBC_SUBCODE_HOUR_TO_SECOND", + Self::XdbcSubcodeMinuteToSecond => "XDBC_SUBCODE_MINUTE_TO_SECOND", + Self::XdbcSubcodeIntervalYear => "XDBC_SUBCODE_INTERVAL_YEAR", + Self::XdbcSubcodeIntervalMonth => "XDBC_SUBCODE_INTERVAL_MONTH", + Self::XdbcSubcodeIntervalDay => "XDBC_SUBCODE_INTERVAL_DAY", + Self::XdbcSubcodeIntervalHour => "XDBC_SUBCODE_INTERVAL_HOUR", + Self::XdbcSubcodeIntervalMinute => "XDBC_SUBCODE_INTERVAL_MINUTE", + Self::XdbcSubcodeIntervalSecond => "XDBC_SUBCODE_INTERVAL_SECOND", + Self::XdbcSubcodeIntervalYearToMonth => "XDBC_SUBCODE_INTERVAL_YEAR_TO_MONTH", + Self::XdbcSubcodeIntervalDayToHour => "XDBC_SUBCODE_INTERVAL_DAY_TO_HOUR", + Self::XdbcSubcodeIntervalDayToMinute => "XDBC_SUBCODE_INTERVAL_DAY_TO_MINUTE", + Self::XdbcSubcodeIntervalDayToSecond => "XDBC_SUBCODE_INTERVAL_DAY_TO_SECOND", + Self::XdbcSubcodeIntervalHourToMinute => { + "XDBC_SUBCODE_INTERVAL_HOUR_TO_MINUTE" + } + Self::XdbcSubcodeIntervalHourToSecond => { + "XDBC_SUBCODE_INTERVAL_HOUR_TO_SECOND" + } + Self::XdbcSubcodeIntervalMinuteToSecond => { + "XDBC_SUBCODE_INTERVAL_MINUTE_TO_SECOND" + } + } + } + /// Creates an enum from field names used in the ProtoBuf 
definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "XDBC_SUBCODE_UNKNOWN" => Some(Self::XdbcSubcodeUnknown), + "XDBC_SUBCODE_YEAR" => Some(Self::XdbcSubcodeYear), + "XDBC_SUBCODE_TIME" => Some(Self::XdbcSubcodeTime), + "XDBC_SUBCODE_TIMESTAMP" => Some(Self::XdbcSubcodeTimestamp), + "XDBC_SUBCODE_TIME_WITH_TIMEZONE" => Some(Self::XdbcSubcodeTimeWithTimezone), + "XDBC_SUBCODE_TIMESTAMP_WITH_TIMEZONE" => { + Some(Self::XdbcSubcodeTimestampWithTimezone) + } + "XDBC_SUBCODE_SECOND" => Some(Self::XdbcSubcodeSecond), + "XDBC_SUBCODE_YEAR_TO_MONTH" => Some(Self::XdbcSubcodeYearToMonth), + "XDBC_SUBCODE_DAY_TO_HOUR" => Some(Self::XdbcSubcodeDayToHour), + "XDBC_SUBCODE_DAY_TO_MINUTE" => Some(Self::XdbcSubcodeDayToMinute), + "XDBC_SUBCODE_DAY_TO_SECOND" => Some(Self::XdbcSubcodeDayToSecond), + "XDBC_SUBCODE_HOUR_TO_MINUTE" => Some(Self::XdbcSubcodeHourToMinute), + "XDBC_SUBCODE_HOUR_TO_SECOND" => Some(Self::XdbcSubcodeHourToSecond), + "XDBC_SUBCODE_MINUTE_TO_SECOND" => Some(Self::XdbcSubcodeMinuteToSecond), + "XDBC_SUBCODE_INTERVAL_YEAR" => Some(Self::XdbcSubcodeIntervalYear), + "XDBC_SUBCODE_INTERVAL_MONTH" => Some(Self::XdbcSubcodeIntervalMonth), + "XDBC_SUBCODE_INTERVAL_DAY" => Some(Self::XdbcSubcodeIntervalDay), + "XDBC_SUBCODE_INTERVAL_HOUR" => Some(Self::XdbcSubcodeIntervalHour), + "XDBC_SUBCODE_INTERVAL_MINUTE" => Some(Self::XdbcSubcodeIntervalMinute), + "XDBC_SUBCODE_INTERVAL_SECOND" => Some(Self::XdbcSubcodeIntervalSecond), + "XDBC_SUBCODE_INTERVAL_YEAR_TO_MONTH" => { + Some(Self::XdbcSubcodeIntervalYearToMonth) + } + "XDBC_SUBCODE_INTERVAL_DAY_TO_HOUR" => { + Some(Self::XdbcSubcodeIntervalDayToHour) + } + "XDBC_SUBCODE_INTERVAL_DAY_TO_MINUTE" => { + Some(Self::XdbcSubcodeIntervalDayToMinute) + } + "XDBC_SUBCODE_INTERVAL_DAY_TO_SECOND" => { + Some(Self::XdbcSubcodeIntervalDayToSecond) + } + "XDBC_SUBCODE_INTERVAL_HOUR_TO_MINUTE" => { + Some(Self::XdbcSubcodeIntervalHourToMinute) + } + "XDBC_SUBCODE_INTERVAL_HOUR_TO_SECOND" 
=> { + Some(Self::XdbcSubcodeIntervalHourToSecond) + } + "XDBC_SUBCODE_INTERVAL_MINUTE_TO_SECOND" => { + Some(Self::XdbcSubcodeIntervalMinuteToSecond) + } + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum Nullable { + /// * + /// Indicates that the fields does not allow the use of null values. + NullabilityNoNulls = 0, + /// * + /// Indicates that the fields allow the use of null values. + NullabilityNullable = 1, + /// * + /// Indicates that nullability of the fields cannot be determined. + NullabilityUnknown = 2, +} +impl Nullable { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::NullabilityNoNulls => "NULLABILITY_NO_NULLS", + Self::NullabilityNullable => "NULLABILITY_NULLABLE", + Self::NullabilityUnknown => "NULLABILITY_UNKNOWN", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "NULLABILITY_NO_NULLS" => Some(Self::NullabilityNoNulls), + "NULLABILITY_NULLABLE" => Some(Self::NullabilityNullable), + "NULLABILITY_UNKNOWN" => Some(Self::NullabilityUnknown), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum Searchable { + /// * + /// Indicates that column cannot be used in a WHERE clause. + None = 0, + /// * + /// Indicates that the column can be used in a WHERE clause if it is using a + /// LIKE operator. + Char = 1, + /// * + /// Indicates that the column can be used In a WHERE clause with any + /// operator other than LIKE. 
+ /// + /// - Allowed operators: comparison, quantified comparison, BETWEEN, + /// DISTINCT, IN, MATCH, and UNIQUE. + Basic = 2, + /// * + /// Indicates that the column can be used in a WHERE clause using any operator. + Full = 3, +} +impl Searchable { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::None => "SEARCHABLE_NONE", + Self::Char => "SEARCHABLE_CHAR", + Self::Basic => "SEARCHABLE_BASIC", + Self::Full => "SEARCHABLE_FULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SEARCHABLE_NONE" => Some(Self::None), + "SEARCHABLE_CHAR" => Some(Self::Char), + "SEARCHABLE_BASIC" => Some(Self::Basic), + "SEARCHABLE_FULL" => Some(Self::Full), + _ => None, } } } @@ -1450,11 +2809,22 @@ impl UpdateDeleteRules { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - UpdateDeleteRules::Cascade => "CASCADE", - UpdateDeleteRules::Restrict => "RESTRICT", - UpdateDeleteRules::SetNull => "SET_NULL", - UpdateDeleteRules::NoAction => "NO_ACTION", - UpdateDeleteRules::SetDefault => "SET_DEFAULT", + Self::Cascade => "CASCADE", + Self::Restrict => "RESTRICT", + Self::SetNull => "SET_NULL", + Self::NoAction => "NO_ACTION", + Self::SetDefault => "SET_DEFAULT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "CASCADE" => Some(Self::Cascade), + "RESTRICT" => Some(Self::Restrict), + "SET_NULL" => Some(Self::SetNull), + "NO_ACTION" => Some(Self::NoAction), + "SET_DEFAULT" => Some(Self::SetDefault), + _ => None, } } } diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs new file mode 100644 index 000000000000..ef52aa27ef50 --- /dev/null +++ b/arrow-flight/src/sql/client.rs @@ -0,0 +1,776 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
A FlightSQL Client [`FlightSqlServiceClient`] + +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::Bytes; +use std::collections::HashMap; +use std::str::FromStr; +use tonic::metadata::AsciiMetadataKey; + +use crate::decode::FlightRecordBatchStream; +use crate::encode::FlightDataEncoderBuilder; +use crate::error::FlightError; +use crate::flight_service_client::FlightServiceClient; +use crate::sql::gen::action_end_transaction_request::EndTransaction; +use crate::sql::server::{ + BEGIN_TRANSACTION, CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT, END_TRANSACTION, +}; +use crate::sql::{ + ActionBeginTransactionRequest, ActionBeginTransactionResult, + ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, ActionEndTransactionRequest, Any, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, + CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate, + DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, +}; +use crate::streams::FallibleRequestStream; +use crate::trailers::extract_lazy_trailers; +use crate::{ + Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, + IpcMessage, PutResult, Ticket, +}; +use arrow_array::RecordBatch; +use arrow_buffer::Buffer; +use arrow_ipc::convert::fb_to_schema; +use arrow_ipc::reader::read_record_batch; +use arrow_ipc::{root_as_message, MessageHeader}; +use arrow_schema::{ArrowError, Schema, SchemaRef}; +use futures::{stream, Stream, TryStreamExt}; +use prost::Message; +use tonic::transport::Channel; +use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; + +/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data +/// by 
FlightSQL protocol. +#[derive(Debug, Clone)] +pub struct FlightSqlServiceClient { + token: Option, + headers: HashMap, + flight_client: FlightServiceClient, +} + +/// A FlightSql protocol client that can run queries against FlightSql servers +/// This client is in the "experimental" stage. It is not guaranteed to follow the spec in all instances. +/// Github issues are welcomed. +impl FlightSqlServiceClient { + /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel` + pub fn new(channel: Channel) -> Self { + Self::new_from_inner(FlightServiceClient::new(channel)) + } + + /// Creates a new higher level client with the provided lower level client + pub fn new_from_inner(inner: FlightServiceClient) -> Self { + Self { + token: None, + flight_client: inner, + headers: HashMap::default(), + } + } + + /// Return a reference to the underlying [`FlightServiceClient`] + pub fn inner(&self) -> &FlightServiceClient { + &self.flight_client + } + + /// Return a mutable reference to the underlying [`FlightServiceClient`] + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + &mut self.flight_client + } + + /// Consume this client and return the underlying [`FlightServiceClient`] + pub fn into_inner(self) -> FlightServiceClient { + self.flight_client + } + + /// Set auth token to the given value. + pub fn set_token(&mut self, token: String) { + self.token = Some(token); + } + + /// Clear the auth token. + pub fn clear_token(&mut self) { + self.token = None; + } + + /// Share the bearer token with potentially different `DoGet` clients + pub fn token(&self) -> Option<&String> { + self.token.as_ref() + } + + /// Set header value. 
+ pub fn set_header(&mut self, key: impl Into, value: impl Into) { + let key: String = key.into(); + let value: String = value.into(); + self.headers.insert(key, value); + } + + async fn get_flight_info_for_command( + &mut self, + cmd: M, + ) -> Result { + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let req = self.set_request_headers(descriptor.into_request())?; + let fi = self + .flight_client + .get_flight_info(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + Ok(fi) + } + + /// Execute a query on the server. + pub async fn execute( + &mut self, + query: String, + transaction_id: Option, + ) -> Result { + let cmd = CommandStatementQuery { + query, + transaction_id, + }; + self.get_flight_info_for_command(cmd).await + } + + /// Perform a `handshake` with the server, passing credentials and establishing a session. + /// + /// If the server returns an "authorization" header, it is automatically parsed and set as + /// a token for future requests. Any other data returned by the server in the handshake + /// response is returned as a binary blob. 
+ pub async fn handshake(&mut self, username: &str, password: &str) -> Result { + let cmd = HandshakeRequest { + protocol_version: 0, + payload: Default::default(), + }; + let mut req = tonic::Request::new(stream::iter(vec![cmd])); + let val = BASE64_STANDARD.encode(format!("{username}:{password}")); + let val = format!("Basic {val}") + .parse() + .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?; + req.metadata_mut().insert("authorization", val); + let req = self.set_request_headers(req)?; + let resp = self + .flight_client + .handshake(req) + .await + .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?; + if let Some(auth) = resp.metadata().get("authorization") { + let auth = auth + .to_str() + .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?; + let bearer = "Bearer "; + if !auth.starts_with(bearer) { + Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; + } + let auth = auth[bearer.len()..].to_string(); + self.token = Some(auth); + } + let responses: Vec = resp + .into_inner() + .try_collect() + .await + .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?; + let resp = match responses.as_slice() { + [resp] => resp.payload.clone(), + [] => Bytes::new(), + _ => Err(ArrowError::ParseError( + "Multiple handshake responses".to_string(), + ))?, + }; + Ok(resp) + } + + /// Execute a update query on the server, and return the number of records affected + pub async fn execute_update( + &mut self, + query: String, + transaction_id: Option, + ) -> Result { + let cmd = CommandStatementUpdate { + query, + transaction_id, + }; + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let req = self.set_request_headers( + stream::iter(vec![FlightData { + flight_descriptor: Some(descriptor), + ..Default::default() + }]) + .into_request(), + )?; + let mut result = self + .flight_client + .do_put(req) + .await + .map_err(status_to_arrow_error)? 
+ .into_inner(); + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result: DoPutUpdateResult = any.unpack()?.unwrap(); + Ok(result.record_count) + } + + /// Execute a bulk ingest on the server and return the number of records added + pub async fn execute_ingest( + &mut self, + command: CommandStatementIngest, + stream: S, + ) -> Result + where + S: Stream> + Send + 'static, + { + let (sender, receiver) = futures::channel::oneshot::channel(); + + let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec()); + let flight_data = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .build(stream); + + // Intercept client errors and send them to the one shot channel above + let flight_data = Box::pin(flight_data); + let flight_data: FallibleRequestStream = + FallibleRequestStream::new(sender, flight_data); + + let req = self.set_request_headers(flight_data.into_streaming_request())?; + let mut result = self + .flight_client + .do_put(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + + // check if the there were any errors in the input stream provided note + // if receiver.await fails, it means the sender was dropped and there is + // no message to return. + if let Ok(msg) = receiver.await { + return Err(ArrowError::ExternalError(Box::new(msg))); + } + + let result = result + .message() + .await + .map_err(status_to_arrow_error)? 
+ .unwrap(); + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result: DoPutUpdateResult = any.unpack()?.unwrap(); + Ok(result.record_count) + } + + /// Request a list of catalogs as tabular FlightInfo results + pub async fn get_catalogs(&mut self) -> Result { + self.get_flight_info_for_command(CommandGetCatalogs {}) + .await + } + + /// Request a list of database schemas as tabular FlightInfo results + pub async fn get_db_schemas( + &mut self, + request: CommandGetDbSchemas, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Given a flight ticket, request to be sent the stream. Returns record batch stream reader + pub async fn do_get( + &mut self, + ticket: impl IntoRequest, + ) -> Result { + let req = self.set_request_headers(ticket.into_request())?; + + let (md, response_stream, _ext) = self + .flight_client + .do_get(req) + .await + .map_err(status_to_arrow_error)? + .into_parts(); + let (response_stream, trailers) = extract_lazy_trailers(response_stream); + + Ok(FlightRecordBatchStream::new_from_flight_data( + response_stream.map_err(FlightError::Tonic), + ) + .with_headers(md) + .with_trailers(trailers)) + } + + /// Push a stream to the flight service associated with a particular flight stream. + pub async fn do_put( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> Result, ArrowError> { + let req = self.set_request_headers(request.into_streaming_request())?; + Ok(self + .flight_client + .do_put(req) + .await + .map_err(status_to_arrow_error)? + .into_inner()) + } + + /// DoAction allows a flight client to do a specific action against a flight service + pub async fn do_action( + &mut self, + request: impl IntoRequest, + ) -> Result, ArrowError> { + let req = self.set_request_headers(request.into_request())?; + Ok(self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner()) + } + + /// Request a list of tables. 
+ pub async fn get_tables( + &mut self, + request: CommandGetTables, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Request the primary keys for a table. + pub async fn get_primary_keys( + &mut self, + request: CommandGetPrimaryKeys, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Retrieves a description about the foreign key columns that reference the + /// primary key columns of the given table. + pub async fn get_exported_keys( + &mut self, + request: CommandGetExportedKeys, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Retrieves the foreign key columns for the given table. + pub async fn get_imported_keys( + &mut self, + request: CommandGetImportedKeys, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Retrieves a description of the foreign key columns in the given foreign key + /// table that reference the primary key or the columns representing a unique + /// constraint of the parent table (could be the same or a different table). + pub async fn get_cross_reference( + &mut self, + request: CommandGetCrossReference, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Request a list of table types. + pub async fn get_table_types(&mut self) -> Result { + self.get_flight_info_for_command(CommandGetTableTypes {}) + .await + } + + /// Request a list of SQL information. + pub async fn get_sql_info( + &mut self, + sql_infos: Vec, + ) -> Result { + let request = CommandGetSqlInfo { + info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(), + }; + self.get_flight_info_for_command(request).await + } + + /// Request XDBC SQL information. + pub async fn get_xdbc_type_info( + &mut self, + request: CommandGetXdbcTypeInfo, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Create a prepared statement object. 
+ pub async fn prepare( + &mut self, + query: String, + transaction_id: Option, + ) -> Result, ArrowError> { + let cmd = ActionCreatePreparedStatementRequest { + query, + transaction_id, + }; + let action = Action { + r#type: CREATE_PREPARED_STATEMENT.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let req = self.set_request_headers(action.into_request())?; + let mut result = self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap(); + let dataset_schema = match prepared_result.dataset_schema.len() { + 0 => Schema::empty(), + _ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?, + }; + let parameter_schema = match prepared_result.parameter_schema.len() { + 0 => Schema::empty(), + _ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?, + }; + Ok(PreparedStatement::new( + self.clone(), + prepared_result.prepared_statement_handle, + dataset_schema, + parameter_schema, + )) + } + + /// Request to begin a transaction. + pub async fn begin_transaction(&mut self) -> Result { + let cmd = ActionBeginTransactionRequest {}; + let action = Action { + r#type: BEGIN_TRANSACTION.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let req = self.set_request_headers(action.into_request())?; + let mut result = self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + let result = result + .message() + .await + .map_err(status_to_arrow_error)? 
+ .unwrap(); + let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap(); + Ok(begin_result.transaction_id) + } + + /// Request to commit/rollback a transaction. + pub async fn end_transaction( + &mut self, + transaction_id: Bytes, + action: EndTransaction, + ) -> Result<(), ArrowError> { + let cmd = ActionEndTransactionRequest { + transaction_id, + action: action as i32, + }; + let action = Action { + r#type: END_TRANSACTION.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let req = self.set_request_headers(action.into_request())?; + let _ = self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + Ok(()) + } + + /// Explicitly shut down and clean up the client. + pub async fn close(&mut self) -> Result<(), ArrowError> { + // TODO: consume self instead of &mut self to explicitly prevent reuse? + Ok(()) + } + + fn set_request_headers( + &self, + mut req: tonic::Request, + ) -> Result, ArrowError> { + for (k, v) in &self.headers { + let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| { + ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) + })?; + let v = v.parse().map_err(|e| { + ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}")) + })?; + req.metadata_mut().insert(k, v); + } + if let Some(token) = &self.token { + let val = format!("Bearer {token}").parse().map_err(|e| { + ArrowError::ParseError(format!("Cannot convert token to header value: {e}")) + })?; + req.metadata_mut().insert("authorization", val); + } + Ok(req) + } +} + +/// A PreparedStatement +#[derive(Debug, Clone)] +pub struct PreparedStatement { + flight_sql_client: FlightSqlServiceClient, + parameter_binding: Option, + handle: Bytes, + dataset_schema: Schema, + parameter_schema: Schema, +} + +impl PreparedStatement { + pub(crate) fn new( + flight_client: FlightSqlServiceClient, + handle: impl Into, 
+ dataset_schema: Schema, + parameter_schema: Schema, + ) -> Self { + PreparedStatement { + flight_sql_client: flight_client, + parameter_binding: None, + handle: handle.into(), + dataset_schema, + parameter_schema, + } + } + + /// Executes the prepared statement query on the server. + pub async fn execute(&mut self) -> Result { + self.write_bind_params().await?; + + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle: self.handle.clone(), + }; + + let result = self + .flight_sql_client + .get_flight_info_for_command(cmd) + .await?; + Ok(result) + } + + /// Executes the prepared statement update query on the server. + pub async fn execute_update(&mut self) -> Result { + self.write_bind_params().await?; + + let cmd = CommandPreparedStatementUpdate { + prepared_statement_handle: self.handle.clone(), + }; + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let mut result = self + .flight_sql_client + .do_put(stream::iter(vec![FlightData { + flight_descriptor: Some(descriptor), + ..Default::default() + }])) + .await?; + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result: DoPutUpdateResult = any.unpack()?.unwrap(); + Ok(result.record_count) + } + + /// Retrieve the parameter schema from the query. + pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> { + Ok(&self.parameter_schema) + } + + /// Retrieve the ResultSet schema from the query. + pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> { + Ok(&self.dataset_schema) + } + + /// Set a RecordBatch that contains the parameters that will be bind. 
+ pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> { + self.parameter_binding = Some(parameter_binding); + Ok(()) + } + + /// Submit parameters to the server, if any have been set on this prepared statement instance + /// Updates our stored prepared statement handle with the handle given by the server response. + async fn write_bind_params(&mut self) -> Result<(), ArrowError> { + if let Some(ref params_batch) = self.parameter_binding { + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle: self.handle.clone(), + }; + + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let flight_stream_builder = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .with_schema(params_batch.schema()); + let flight_data = flight_stream_builder + .build(futures::stream::iter( + self.parameter_binding.clone().map(Ok), + )) + .try_collect::>() + .await + .map_err(flight_error_to_arrow_error)?; + + // Attempt to update the stored handle with any updated handle in the DoPut result. + // Older servers do not respond with a result for DoPut, so skip this step when + // the stream closes with no response. + if let Some(result) = self + .flight_sql_client + .do_put(stream::iter(flight_data)) + .await? + .message() + .await + .map_err(status_to_arrow_error)? + { + if let Some(handle) = self.unpack_prepared_statement_handle(&result)? 
{ + self.handle = handle; + } + } + } + Ok(()) + } + + /// Decodes the app_metadata stored in a [`PutResult`] as a + /// [`DoPutPreparedStatementResult`] and then returns + /// the inner prepared statement handle as [`Bytes`] + fn unpack_prepared_statement_handle( + &self, + put_result: &PutResult, + ) -> Result, ArrowError> { + let result: DoPutPreparedStatementResult = + Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?; + Ok(result.prepared_statement_handle) + } + + /// Close the prepared statement, so that this PreparedStatement can not used + /// anymore and server can free up any resources. + pub async fn close(mut self) -> Result<(), ArrowError> { + let cmd = ActionClosePreparedStatementRequest { + prepared_statement_handle: self.handle.clone(), + }; + let action = Action { + r#type: CLOSE_PREPARED_STATEMENT.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let _ = self.flight_sql_client.do_action(action).await?; + Ok(()) + } +} + +fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError { + ArrowError::IpcError(err.to_string()) +} + +fn status_to_arrow_error(status: tonic::Status) -> ArrowError { + ArrowError::IpcError(format!("{status:?}")) +} + +fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { + match err { + FlightError::Arrow(e) => e, + e => ArrowError::ExternalError(Box::new(e)), + } +} + +// A polymorphic structure to natively represent different types of data contained in `FlightData` +pub enum ArrowFlightData { + RecordBatch(RecordBatch), + Schema(Schema), +} + +/// Extract `Schema` or `RecordBatch`es from the `FlightData` wire representation +pub fn arrow_data_from_flight_data( + flight_data: FlightData, + arrow_schema_ref: &SchemaRef, +) -> Result { + let ipc_message = root_as_message(&flight_data.data_header[..]) + .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; + + match ipc_message.header_type() { + MessageHeader::RecordBatch 
=> { + let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a record batch".to_string(), + ) + })?; + + let dictionaries_by_field = HashMap::new(); + let record_batch = read_record_batch( + &Buffer::from_bytes(flight_data.data_body.into()), + ipc_record_batch, + arrow_schema_ref.clone(), + &dictionaries_by_field, + None, + &ipc_message.version(), + )?; + Ok(ArrowFlightData::RecordBatch(record_batch)) + } + MessageHeader::Schema => { + let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a schema".to_string(), + ) + })?; + + let arrow_schema = fb_to_schema(ipc_schema); + Ok(ArrowFlightData::Schema(arrow_schema)) + } + MessageHeader::DictionaryBatch => { + let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a dictionary batch".to_string(), + ) + })?; + Err(ArrowError::NotYetImplemented( + "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(), + )) + } + MessageHeader::Tensor => { + let _ = ipc_message.header_as_tensor().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a tensor".to_string(), + ) + })?; + Err(ArrowError::NotYetImplemented( + "no idea on how to convert an ipc tensor to an arrow type".to_string(), + )) + } + MessageHeader::SparseTensor => { + let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a sparse tensor".to_string(), + ) + })?; + Err(ArrowError::NotYetImplemented( + "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(), + )) + } + _ => Err(ArrowError::ComputeError(format!( + "Unable to convert message with header_type: '{:?}' to arrow data", + ipc_message.header_type() + ))), + } +} diff --git 
a/arrow-flight/src/sql/metadata/catalogs.rs b/arrow-flight/src/sql/metadata/catalogs.rs new file mode 100644 index 000000000000..327fed81077b --- /dev/null +++ b/arrow-flight/src/sql/metadata/catalogs.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::{RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use once_cell::sync::Lazy; + +use crate::error::Result; +use crate::sql::CommandGetCatalogs; + +/// A builder for a [`CommandGetCatalogs`] response. 
+/// +/// Builds rows like this: +/// +/// * catalog_name: utf8, +pub struct GetCatalogsBuilder { + catalogs: Vec, +} + +impl CommandGetCatalogs { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetCatalogsBuilder { + self.into() + } +} + +impl From for GetCatalogsBuilder { + fn from(_: CommandGetCatalogs) -> Self { + Self::new() + } +} + +impl Default for GetCatalogsBuilder { + fn default() -> Self { + Self::new() + } +} + +impl GetCatalogsBuilder { + /// Create a new instance of [`GetCatalogsBuilder`] + pub fn new() -> Self { + Self { + catalogs: Vec::new(), + } + } + + /// Append a row + pub fn append(&mut self, catalog_name: impl Into) { + self.catalogs.push(catalog_name.into()); + } + + /// builds a `RecordBatch` with the correct schema for a + /// [`CommandGetCatalogs`] response + pub fn build(self) -> Result { + let Self { catalogs } = self; + + let batch = RecordBatch::try_new( + Arc::clone(&GET_CATALOG_SCHEMA), + vec![Arc::new(StringArray::from_iter_values(catalogs)) as _], + )?; + + Ok(batch) + } + + /// Returns the schema that will result from [`CommandGetCatalogs`] + /// + /// [`CommandGetCatalogs`]: crate::sql::CommandGetCatalogs + pub fn schema(&self) -> SchemaRef { + get_catalogs_schema() + } +} + +fn get_catalogs_schema() -> SchemaRef { + Arc::clone(&GET_CATALOG_SCHEMA) +} + +/// The schema for GetCatalogs +static GET_CATALOG_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![Field::new( + "catalog_name", + DataType::Utf8, + false, + )])) +}); diff --git a/arrow-flight/src/sql/metadata/db_schemas.rs b/arrow-flight/src/sql/metadata/db_schemas.rs new file mode 100644 index 000000000000..303d11cd74ca --- /dev/null +++ b/arrow-flight/src/sql/metadata/db_schemas.rs @@ -0,0 +1,286 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GetDbSchemasBuilder`] for building responses to [`CommandGetDbSchemas`] queries. +//! +//! [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas + +use std::sync::Arc; + +use arrow_arith::boolean::and; +use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch, StringArray}; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::{filter::filter_record_batch, take::take}; +use arrow_string::like::like; +use once_cell::sync::Lazy; + +use super::lexsort_to_indices; +use crate::error::*; +use crate::sql::CommandGetDbSchemas; + +/// A builder for a [`CommandGetDbSchemas`] response. +/// +/// Builds rows like this: +/// +/// * catalog_name: utf8, +/// * db_schema_name: utf8, +pub struct GetDbSchemasBuilder { + // Specifies the Catalog to search for the tables. + // - An empty string retrieves those without a catalog. + // - If omitted the catalog name is not used to narrow the search. 
+ catalog_filter: Option, + // Optional filters to apply + db_schema_filter_pattern: Option, + // array builder for catalog names + catalog_name: StringBuilder, + // array builder for schema names + db_schema_name: StringBuilder, +} + +impl CommandGetDbSchemas { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetDbSchemasBuilder { + self.into() + } +} + +impl From for GetDbSchemasBuilder { + fn from(value: CommandGetDbSchemas) -> Self { + Self::new(value.catalog, value.db_schema_filter_pattern) + } +} + +impl GetDbSchemasBuilder { + /// Create a new instance of [`GetDbSchemasBuilder`] + /// + /// # Parameters + /// + /// - `catalog`: Specifies the Catalog to search for the tables. + /// - An empty string retrieves those without a catalog. + /// - If omitted the catalog name is not used to narrow the search. + /// - `db_schema_filter_pattern`: Specifies a filter pattern for schemas to search for. + /// When no pattern is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + /// + /// [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas + pub fn new( + catalog: Option>, + db_schema_filter_pattern: Option>, + ) -> Self { + Self { + catalog_filter: catalog.map(|v| v.into()), + db_schema_filter_pattern: db_schema_filter_pattern.map(|v| v.into()), + catalog_name: StringBuilder::new(), + db_schema_name: StringBuilder::new(), + } + } + + /// Append a row + /// + /// In case the catalog should be considered as empty, pass in an empty string '""'. 
+ pub fn append(&mut self, catalog_name: impl AsRef, schema_name: impl AsRef) { + self.catalog_name.append_value(catalog_name); + self.db_schema_name.append_value(schema_name); + } + + /// builds a `RecordBatch` with the correct schema for a `CommandGetDbSchemas` response + pub fn build(self) -> Result { + let schema = self.schema(); + let Self { + catalog_filter, + db_schema_filter_pattern, + mut catalog_name, + mut db_schema_name, + } = self; + + // Make the arrays + let catalog_name = catalog_name.finish(); + let db_schema_name = db_schema_name.finish(); + + let mut filters = vec![]; + + if let Some(db_schema_filter_pattern) = db_schema_filter_pattern { + // use like kernel to get wildcard matching + let scalar = StringArray::new_scalar(db_schema_filter_pattern); + filters.push(like(&db_schema_name, &scalar)?) + } + + if let Some(catalog_filter_name) = catalog_filter { + let scalar = StringArray::new_scalar(catalog_filter_name); + filters.push(eq(&catalog_name, &scalar)?); + } + + // `AND` any filters together + let mut total_filter = None; + while let Some(filter) = filters.pop() { + let new_filter = match total_filter { + Some(total_filter) => and(&total_filter, &filter)?, + None => filter, + }; + total_filter = Some(new_filter); + } + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(catalog_name) as ArrayRef, + Arc::new(db_schema_name) as ArrayRef, + ], + )?; + + // Apply the filters if needed + let filtered_batch = if let Some(filter) = total_filter { + filter_record_batch(&batch, &filter)? + } else { + batch + }; + + // Order filtered results by catalog_name, then db_schema_name + let indices = lexsort_to_indices(filtered_batch.columns()); + let columns = filtered_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(filtered_batch.schema(), columns)?) 
+ } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetDbSchemas`] + pub fn schema(&self) -> SchemaRef { + get_db_schemas_schema() + } +} + +fn get_db_schemas_schema() -> SchemaRef { + Arc::clone(&GET_DB_SCHEMAS_SCHEMA) +} + +/// The schema for GetDbSchemas +static GET_DB_SCHEMAS_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, false), + Field::new("db_schema_name", DataType::Utf8, false), + ])) +}); + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{StringArray, UInt32Array}; + + fn get_ref_batch() -> RecordBatch { + RecordBatch::try_new( + get_db_schemas_schema(), + vec![ + Arc::new(StringArray::from(vec![ + "a_catalog", + "a_catalog", + "b_catalog", + "b_catalog", + ])) as ArrayRef, + Arc::new(StringArray::from(vec![ + "a_schema", "b_schema", "a_schema", "b_schema", + ])) as ArrayRef, + ], + ) + .unwrap() + } + + #[test] + fn test_schemas_are_filtered() { + let ref_batch = get_ref_batch(); + + let mut builder = GetDbSchemasBuilder::new(None::, None::); + builder.append("a_catalog", "a_schema"); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch); + + let mut builder = GetDbSchemasBuilder::new(None::, Some("a%")); + builder.append("a_catalog", "a_schema"); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + let indices = UInt32Array::from(vec![0, 2]); + let ref_filtered = RecordBatch::try_new( + get_db_schemas_schema(), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + + assert_eq!(schema_batch, ref_filtered); + } + + #[test] + fn test_schemas_are_sorted() { + let ref_batch = get_ref_batch(); + + 
let mut builder = GetDbSchemasBuilder::new(None::, None::); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("a_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch) + } + + #[test] + fn test_builder_from_query() { + let ref_batch = get_ref_batch(); + let query = CommandGetDbSchemas { + catalog: Some("a_catalog".into()), + db_schema_filter_pattern: Some("b%".into()), + }; + + let mut builder = query.into_builder(); + builder.append("a_catalog", "a_schema"); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + let indices = UInt32Array::from(vec![1]); + let ref_filtered = RecordBatch::try_new( + get_db_schemas_schema(), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + + assert_eq!(schema_batch, ref_filtered); + } +} diff --git a/arrow-flight/src/sql/metadata/mod.rs b/arrow-flight/src/sql/metadata/mod.rs new file mode 100644 index 000000000000..fd71149a3180 --- /dev/null +++ b/arrow-flight/src/sql/metadata/mod.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Builders and functions for building responses to FlightSQL metadata +//! / information schema requests. +//! +//! - [`GetCatalogsBuilder`] for building responses to [`CommandGetCatalogs`] queries. +//! - [`GetDbSchemasBuilder`] for building responses to [`CommandGetDbSchemas`] queries. +//! - [`GetTablesBuilder`] for building responses to [`CommandGetTables`] queries. +//! - [`SqlInfoDataBuilder`] for building responses to [`CommandGetSqlInfo`] queries. +//! - [`XdbcTypeInfoDataBuilder`] for building responses to [`CommandGetXdbcTypeInfo`] queries. +//! +//! [`CommandGetCatalogs`]: crate::sql::CommandGetCatalogs +//! [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas +//! [`CommandGetTables`]: crate::sql::CommandGetTables +//! [`CommandGetSqlInfo`]: crate::sql::CommandGetSqlInfo +//! [`CommandGetXdbcTypeInfo`]: crate::sql::CommandGetXdbcTypeInfo + +mod catalogs; +mod db_schemas; +mod sql_info; +mod table_types; +mod tables; +mod xdbc_info; + +pub use catalogs::GetCatalogsBuilder; +pub use db_schemas::GetDbSchemasBuilder; +pub use sql_info::{SqlInfoData, SqlInfoDataBuilder}; +pub use tables::GetTablesBuilder; +pub use xdbc_info::{XdbcTypeInfo, XdbcTypeInfoData, XdbcTypeInfoDataBuilder}; + +use arrow_array::ArrayRef; +use arrow_array::UInt32Array; +use arrow_row::RowConverter; +use arrow_row::SortField; + +/// Helper function returning the indices that lexicographically sort the rows of the given columns +fn lexsort_to_indices(arrays: &[ArrayRef]) -> UInt32Array { + let fields = arrays + .iter() + .map(|a| SortField::new(a.data_type().clone())) + .collect(); + let converter = RowConverter::new(fields).unwrap(); + let rows = converter.convert_columns(arrays).unwrap(); + let mut sort: Vec<_> = rows.iter().enumerate().collect(); + sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); + UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32)) +} + +#[cfg(test)] +mod tests { + use 
arrow_array::RecordBatch; + use arrow_cast::pretty::pretty_format_batches; + pub fn assert_batches_eq(batches: &[RecordBatch], expected_lines: &[&str]) { + let formatted = pretty_format_batches(batches).unwrap().to_string(); + let actual_lines: Vec<_> = formatted.trim().lines().collect(); + assert_eq!( + &actual_lines, expected_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + } +} diff --git a/arrow-flight/src/sql/metadata/sql_info.rs b/arrow-flight/src/sql/metadata/sql_info.rs new file mode 100644 index 000000000000..97304d3c872d --- /dev/null +++ b/arrow-flight/src/sql/metadata/sql_info.rs @@ -0,0 +1,561 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Helpers for building responses to [`CommandGetSqlInfo`] metadata requests. +//! +//! - [`SqlInfoDataBuilder`] - a builder for collecting sql infos +//! and building a conformant `RecordBatch` with sql info server metadata. +//! - [`SqlInfoData`] - a helper type wrapping a `RecordBatch` +//! used for storing sql info server metadata. +//! - [`GetSqlInfoBuilder`] - a builder for constructing [`CommandGetSqlInfo`] responses. +//! 
+ +use std::collections::{BTreeMap, HashMap}; +use std::sync::Arc; + +use arrow_arith::boolean::or; +use arrow_array::array::{Array, UInt32Array, UnionArray}; +use arrow_array::builder::{ + ArrayBuilder, BooleanBuilder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, MapBuilder, + StringBuilder, UInt32Builder, +}; +use arrow_array::{RecordBatch, Scalar}; +use arrow_data::ArrayData; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, UnionFields, UnionMode}; +use arrow_select::filter::filter_record_batch; +use once_cell::sync::Lazy; + +use crate::error::Result; +use crate::sql::{CommandGetSqlInfo, SqlInfo}; + +/// Represents a dynamic value +#[derive(Debug, Clone, PartialEq)] +pub enum SqlInfoValue { + String(String), + Bool(bool), + BigInt(i64), + Bitmask(i32), + StringList(Vec), + ListMap(BTreeMap>), +} + +impl From<&str> for SqlInfoValue { + fn from(value: &str) -> Self { + Self::String(value.to_string()) + } +} + +impl From for SqlInfoValue { + fn from(value: bool) -> Self { + Self::Bool(value) + } +} + +impl From for SqlInfoValue { + fn from(value: i32) -> Self { + Self::Bitmask(value) + } +} + +impl From for SqlInfoValue { + fn from(value: i64) -> Self { + Self::BigInt(value) + } +} + +impl From<&[&str]> for SqlInfoValue { + fn from(values: &[&str]) -> Self { + let values = values.iter().map(|s| s.to_string()).collect(); + Self::StringList(values) + } +} + +impl From> for SqlInfoValue { + fn from(values: Vec) -> Self { + Self::StringList(values) + } +} + +impl From>> for SqlInfoValue { + fn from(value: BTreeMap>) -> Self { + Self::ListMap(value) + } +} + +impl From>> for SqlInfoValue { + fn from(value: HashMap>) -> Self { + Self::ListMap(value.into_iter().collect()) + } +} + +impl From<&HashMap>> for SqlInfoValue { + fn from(value: &HashMap>) -> Self { + Self::ListMap( + value + .iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect(), + ) + } +} + +/// Something that can be converted into u32 (the 
representation of a [`SqlInfo`] name) +pub trait SqlInfoName { + fn as_u32(&self) -> u32; +} + +impl SqlInfoName for SqlInfo { + fn as_u32(&self) -> u32 { + // SqlInfos are u32 in the flight spec, but for some reason + // SqlInfo repr is an i32, so convert between them + u32::try_from(i32::from(*self)).expect("SqlInfo fit into u32") + } +} + +// Allow passing u32 directly to SqlInfoDataBuilder::append +impl SqlInfoName for u32 { + fn as_u32(&self) -> u32 { + *self + } +} + +/// Handles creating the dense [`UnionArray`] described by [flightsql] +/// +/// Incrementally builds the type ids and offsets of the dense union. See [Union Spec] for details. +/// +/// ```text +/// * value: dense_union< +/// * string_value: utf8, +/// * bool_value: bool, +/// * bigint_value: int64, +/// * int32_bitmask: int32, +/// * string_list: list<string_data: utf8> +/// * int32_to_int32_list_map: map<key: int32, value: list<$data$: int32>> +/// * > +/// ``` +///[flightsql]: https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/FlightSql.proto#L32-L43 +///[Union Spec]: https://arrow.apache.org/docs/format/Columnar.html#dense-union +struct SqlInfoUnionBuilder { + // Values for each child type + string_values: StringBuilder, + bool_values: BooleanBuilder, + bigint_values: Int64Builder, + int32_bitmask_values: Int32Builder, + string_list_values: ListBuilder, + int32_to_int32_list_map_values: MapBuilder>, + type_ids: Int8Builder, + offsets: Int32Builder, +} + +/// [`DataType`] for the output union array +static UNION_TYPE: Lazy = Lazy::new(|| { + let fields = vec![ + Field::new("string_value", DataType::Utf8, false), + Field::new("bool_value", DataType::Boolean, false), + Field::new("bigint_value", DataType::Int64, false), + Field::new("int32_bitmask", DataType::Int32, false), + // treat list as nullable b/c that is what the builders make + Field::new( + "string_list", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ), + Field::new( + "int32_to_int32_list_map", + DataType::Map( + Arc::new(Field::new( + 
"entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new( + "values", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + ])), + false, + )), + false, + ), + true, + ), + ]; + + // create "type ids", one for each type, assume they go from 0 .. num_fields + let type_ids: Vec = (0..fields.len()).map(|v| v as i8).collect(); + + DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) +}); + +impl SqlInfoUnionBuilder { + pub fn new() -> Self { + Self { + string_values: StringBuilder::new(), + bool_values: BooleanBuilder::new(), + bigint_values: Int64Builder::new(), + int32_bitmask_values: Int32Builder::new(), + string_list_values: ListBuilder::new(StringBuilder::new()), + int32_to_int32_list_map_values: MapBuilder::new( + None, + Int32Builder::new(), + ListBuilder::new(Int32Builder::new()), + ), + type_ids: Int8Builder::new(), + offsets: Int32Builder::new(), + } + } + + /// Returns the DataType created by this builder + pub fn schema() -> &'static DataType { + &UNION_TYPE + } + + /// Append the specified value to this builder + pub fn append_value(&mut self, v: &SqlInfoValue) -> Result<()> { + // typeid is which child and len is the child array's length + // *after* adding the value + let (type_id, len) = match v { + SqlInfoValue::String(v) => { + self.string_values.append_value(v); + (0, self.string_values.len()) + } + SqlInfoValue::Bool(v) => { + self.bool_values.append_value(*v); + (1, self.bool_values.len()) + } + SqlInfoValue::BigInt(v) => { + self.bigint_values.append_value(*v); + (2, self.bigint_values.len()) + } + SqlInfoValue::Bitmask(v) => { + self.int32_bitmask_values.append_value(*v); + (3, self.int32_bitmask_values.len()) + } + SqlInfoValue::StringList(values) => { + // build list + for v in values { + self.string_list_values.values().append_value(v); + } + // complete the list + self.string_list_values.append(true); + (4, self.string_list_values.len()) + } + 
SqlInfoValue::ListMap(values) => { + // build map + for (k, v) in values.clone() { + self.int32_to_int32_list_map_values.keys().append_value(k); + self.int32_to_int32_list_map_values + .values() + .append_value(v.into_iter().map(Some)); + } + // complete the list + self.int32_to_int32_list_map_values.append(true)?; + (5, self.int32_to_int32_list_map_values.len()) + } + }; + + self.type_ids.append_value(type_id); + let len = i32::try_from(len).expect("offset fit in i32"); + self.offsets.append_value(len - 1); + Ok(()) + } + + /// Complete the construction and build the [`UnionArray`] + pub fn finish(self) -> UnionArray { + let Self { + mut string_values, + mut bool_values, + mut bigint_values, + mut int32_bitmask_values, + mut string_list_values, + mut int32_to_int32_list_map_values, + mut type_ids, + mut offsets, + } = self; + let type_ids = type_ids.finish(); + let offsets = offsets.finish(); + + // form the correct ArrayData + + let len = offsets.len(); + let null_bit_buffer = None; + let offset = 0; + + let buffers = vec![ + type_ids.into_data().buffers()[0].clone(), + offsets.into_data().buffers()[0].clone(), + ]; + + let child_data = vec![ + string_values.finish().into_data(), + bool_values.finish().into_data(), + bigint_values.finish().into_data(), + int32_bitmask_values.finish().into_data(), + string_list_values.finish().into_data(), + int32_to_int32_list_map_values.finish().into_data(), + ]; + + let data = ArrayData::try_new( + UNION_TYPE.clone(), + len, + null_bit_buffer, + offset, + buffers, + child_data, + ) + .expect("Correctly created UnionArray"); + + UnionArray::from(data) + } +} + +/// Helper to create [`CommandGetSqlInfo`] responses. +/// +/// [`CommandGetSqlInfo`] are metadata requests used by a Flight SQL +/// server to communicate supported capabilities to Flight SQL clients. 
+/// +/// Servers construct - usually static - [`SqlInfoData`] via the [`SqlInfoDataBuilder`], +/// and build responses using [`CommandGetSqlInfo::into_builder`] +#[derive(Debug, Clone, PartialEq)] +pub struct SqlInfoDataBuilder { + /// Use BTreeMap to ensure the entries are sorted by key so as + /// to make output consistent + /// + /// Use u32 to support "custom" sql info values that are not + /// part of the SqlInfo enum + infos: BTreeMap, +} + +impl Default for SqlInfoDataBuilder { + fn default() -> Self { + Self::new() + } +} + +impl SqlInfoDataBuilder { + pub fn new() -> Self { + Self { + infos: BTreeMap::new(), + } + } + + /// register the specific sql metadata item + pub fn append(&mut self, name: impl SqlInfoName, value: impl Into) { + self.infos.insert(name.as_u32(), value.into()); + } + + /// Encode the contents of this list according to the [FlightSQL spec] + /// + /// [FlightSQL spec]: https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/FlightSql.proto#L32-L43 + pub fn build(self) -> Result { + let mut name_builder = UInt32Builder::new(); + let mut value_builder = SqlInfoUnionBuilder::new(); + + let mut names: Vec<_> = self.infos.keys().cloned().collect(); + names.sort_unstable(); + + for key in names { + let (name, value) = self.infos.get_key_value(&key).unwrap(); + name_builder.append_value(*name); + value_builder.append_value(value)? + } + + let batch = RecordBatch::try_from_iter(vec![ + ("info_name", Arc::new(name_builder.finish()) as _), + ("value", Arc::new(value_builder.finish()) as _), + ])?; + + Ok(SqlInfoData { batch }) + } + + /// Return the [`Schema`] for a GetSchema RPC call with [`crate::sql::CommandGetSqlInfo`] + pub fn schema() -> &'static Schema { + // It is always the same + &SQL_INFO_SCHEMA + } +} + +/// A builder for [`SqlInfoData`] which is used to create [`CommandGetSqlInfo`] responses. 
+/// +/// # Example +/// ``` +/// # use arrow_flight::sql::{metadata::SqlInfoDataBuilder, SqlInfo, SqlSupportedTransaction}; +/// // Create the list of metadata describing the server +/// let mut builder = SqlInfoDataBuilder::new(); +/// builder.append(SqlInfo::FlightSqlServerName, "server name"); +/// // ... add other SqlInfo here .. +/// builder.append( +/// SqlInfo::FlightSqlServerTransaction, +/// SqlSupportedTransaction::Transaction as i32, +/// ); +/// +/// // Create the batch to send back to the client +/// let info_data = builder.build().unwrap(); +/// ``` +/// +/// [protos]: https://github.com/apache/arrow/blob/6d3d2fca2c9693231fa1e52c142ceef563fc23f9/format/FlightSql.proto#L71-L820 +pub struct SqlInfoData { + batch: RecordBatch, +} + +impl SqlInfoData { + /// Return a [`RecordBatch`] containing only the requested `u32`, if any + /// from [`CommandGetSqlInfo`] + pub fn record_batch(&self, info: impl IntoIterator) -> Result { + let arr = self.batch.column(0); + let type_filter = info + .into_iter() + .map(|tt| { + let s = UInt32Array::from(vec![tt]); + eq(arr, &Scalar::new(&s)) + }) + .collect::, _>>()? + .into_iter() + // We know the arrays are of same length as they are produced from the same root array + .reduce(|filter, arr| or(&filter, &arr).unwrap()); + if let Some(filter) = type_filter { + Ok(filter_record_batch(&self.batch, &filter)?) + } else { + Ok(self.batch.clone()) + } + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetSqlInfo`] + pub fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +/// A builder for a [`CommandGetSqlInfo`] response. +pub struct GetSqlInfoBuilder<'a> { + /// requested `SqlInfo`s. If empty means return all infos. 
+ info: Vec, + infos: &'a SqlInfoData, +} + +impl CommandGetSqlInfo { + /// Create a builder suitable for constructing a response + pub fn into_builder(self, infos: &SqlInfoData) -> GetSqlInfoBuilder { + GetSqlInfoBuilder { + info: self.info, + infos, + } + } +} + +impl GetSqlInfoBuilder<'_> { + /// Builds a `RecordBatch` with the correct schema for a [`CommandGetSqlInfo`] response + pub fn build(self) -> Result { + self.infos.record_batch(self.info) + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetSqlInfo`] + pub fn schema(&self) -> SchemaRef { + self.infos.schema() + } +} + +// The schema produced by [`SqlInfoData`] +static SQL_INFO_SCHEMA: Lazy = Lazy::new(|| { + Schema::new(vec![ + Field::new("info_name", DataType::UInt32, false), + Field::new("value", SqlInfoUnionBuilder::schema().clone(), false), + ]) +}); + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::SqlInfoDataBuilder; + use crate::sql::metadata::tests::assert_batches_eq; + use crate::sql::{SqlInfo, SqlNullOrdering, SqlSupportedTransaction, SqlSupportsConvert}; + + #[test] + fn test_sql_infos() { + let mut convert: HashMap> = HashMap::new(); + convert.insert( + SqlSupportsConvert::SqlConvertInteger as i32, + vec![ + SqlSupportsConvert::SqlConvertFloat as i32, + SqlSupportsConvert::SqlConvertReal as i32, + ], + ); + + let mut builder = SqlInfoDataBuilder::new(); + // str + builder.append(SqlInfo::SqlIdentifierQuoteChar, r#"""#); + // bool + builder.append(SqlInfo::SqlDdlCatalog, false); + // i32 + builder.append( + SqlInfo::SqlNullOrdering, + SqlNullOrdering::SqlNullsSortedHigh as i32, + ); + // i64 + builder.append(SqlInfo::SqlMaxBinaryLiteralLength, i32::MAX as i64); + // [str] + builder.append(SqlInfo::SqlKeywords, &["SELECT", "DELETE"] as &[&str]); + builder.append(SqlInfo::SqlSupportsConvert, &convert); + + let batch = builder.build().unwrap().record_batch(None).unwrap(); + + let expected = vec![ + 
"+-----------+----------------------------------------+", + "| info_name | value |", + "+-----------+----------------------------------------+", + "| 500 | {bool_value=false} |", + "| 504 | {string_value=\"} |", + "| 507 | {int32_bitmask=0} |", + "| 508 | {string_list=[SELECT, DELETE]} |", + "| 517 | {int32_to_int32_list_map={7: [6, 13]}} |", + "| 541 | {bigint_value=2147483647} |", + "+-----------+----------------------------------------+", + ]; + + assert_batches_eq(&[batch], &expected); + } + + #[test] + fn test_filter_sql_infos() { + let mut builder = SqlInfoDataBuilder::new(); + builder.append(SqlInfo::FlightSqlServerName, "server name"); + builder.append( + SqlInfo::FlightSqlServerTransaction, + SqlSupportedTransaction::Transaction as i32, + ); + let data = builder.build().unwrap(); + + let batch = data.record_batch(None).unwrap(); + assert_eq!(batch.num_rows(), 2); + + let batch = data + .record_batch([SqlInfo::FlightSqlServerTransaction as u32]) + .unwrap(); + let mut ref_builder = SqlInfoDataBuilder::new(); + ref_builder.append( + SqlInfo::FlightSqlServerTransaction, + SqlSupportedTransaction::Transaction as i32, + ); + let ref_batch = ref_builder.build().unwrap().record_batch(None).unwrap(); + + assert_eq!(batch, ref_batch); + } +} diff --git a/arrow-flight/src/sql/metadata/table_types.rs b/arrow-flight/src/sql/metadata/table_types.rs new file mode 100644 index 000000000000..54cfe6fe27a7 --- /dev/null +++ b/arrow-flight/src/sql/metadata/table_types.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GetTableTypesBuilder`] for building responses to [`CommandGetTableTypes`] queries. +//! +//! [`CommandGetTableTypes`]: crate::sql::CommandGetTableTypes + +use std::sync::Arc; + +use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::take::take; +use once_cell::sync::Lazy; + +use crate::error::*; +use crate::sql::CommandGetTableTypes; + +use super::lexsort_to_indices; + +/// A builder for a [`CommandGetTableTypes`] response. +/// +/// Builds rows like this: +/// +/// * table_type: utf8, +#[derive(Default)] +pub struct GetTableTypesBuilder { + // array builder for table types + table_type: StringBuilder, +} + +impl CommandGetTableTypes { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetTableTypesBuilder { + self.into() + } +} + +impl From for GetTableTypesBuilder { + fn from(_value: CommandGetTableTypes) -> Self { + Self::new() + } +} + +impl GetTableTypesBuilder { + /// Create a new instance of [`GetTableTypesBuilder`] + pub fn new() -> Self { + Self { + table_type: StringBuilder::new(), + } + } + + /// Append a row + pub fn append(&mut self, table_type: impl AsRef) { + self.table_type.append_value(table_type); + } + + /// builds a `RecordBatch` with the correct schema for a `CommandGetTableTypes` response + pub fn build(self) -> Result { + let schema = self.schema(); + let Self { mut table_type } = self; + + // Make the arrays + let table_type = table_type.finish(); + + let batch = RecordBatch::try_new(schema, 
vec![Arc::new(table_type) as ArrayRef])?; + + // Order filtered results by table_type + let indices = lexsort_to_indices(batch.columns()); + let columns = batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(batch.schema(), columns)?) + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetTableTypes`] + pub fn schema(&self) -> SchemaRef { + get_table_types_schema() + } +} + +fn get_table_types_schema() -> SchemaRef { + Arc::clone(&GET_TABLE_TYPES_SCHEMA) +} + +/// The schema for [`CommandGetTableTypes`]. +static GET_TABLE_TYPES_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![Field::new( + "table_type", + DataType::Utf8, + false, + )])) +}); + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::StringArray; + + fn get_ref_batch() -> RecordBatch { + RecordBatch::try_new( + get_table_types_schema(), + vec![Arc::new(StringArray::from(vec![ + "a_table_type", + "b_table_type", + "c_table_type", + "d_table_type", + ])) as ArrayRef], + ) + .unwrap() + } + + #[test] + fn test_table_types_are_sorted() { + let ref_batch = get_ref_batch(); + + let mut builder = GetTableTypesBuilder::new(); + builder.append("b_table_type"); + builder.append("a_table_type"); + builder.append("d_table_type"); + builder.append("c_table_type"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch) + } + + #[test] + fn test_builder_from_query() { + let ref_batch = get_ref_batch(); + let query = CommandGetTableTypes {}; + + let mut builder = query.into_builder(); + builder.append("a_table_type"); + builder.append("b_table_type"); + builder.append("c_table_type"); + builder.append("d_table_type"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch) + } +} diff --git a/arrow-flight/src/sql/metadata/tables.rs b/arrow-flight/src/sql/metadata/tables.rs new file mode 100644 index 000000000000..7ffb76fa1d5f --- /dev/null 
+++ b/arrow-flight/src/sql/metadata/tables.rs @@ -0,0 +1,476 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GetTablesBuilder`] for building responses to [`CommandGetTables`] queries. +//! +//! [`CommandGetTables`]: crate::sql::CommandGetTables + +use std::sync::Arc; + +use arrow_arith::boolean::{and, or}; +use arrow_array::builder::{BinaryBuilder, StringBuilder}; +use arrow_array::{ArrayRef, RecordBatch, StringArray}; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::{filter::filter_record_batch, take::take}; +use arrow_string::like::like; +use once_cell::sync::Lazy; + +use super::lexsort_to_indices; +use crate::error::*; +use crate::sql::CommandGetTables; +use crate::{IpcMessage, IpcWriteOptions, SchemaAsIpc}; + +/// A builder for a [`CommandGetTables`] response. +/// +/// Builds rows like this: +/// +/// * catalog_name: utf8, +/// * db_schema_name: utf8, +/// * table_name: utf8 not null, +/// * table_type: utf8 not null, +/// * (optional) table_schema: bytes not null (schema of the table as described +/// in Schema.fbs::Schema it is serialized as an IPC message.) 
+pub struct GetTablesBuilder { + catalog_filter: Option, + table_types_filter: Vec, + // Optional filters to apply to schemas + db_schema_filter_pattern: Option, + // Optional filters to apply to tables + table_name_filter_pattern: Option, + // array builder for catalog names + catalog_name: StringBuilder, + // array builder for db schema names + db_schema_name: StringBuilder, + // array builder for tables names + table_name: StringBuilder, + // array builder for table types + table_type: StringBuilder, + // array builder for table schemas + table_schema: Option, +} + +impl CommandGetTables { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetTablesBuilder { + self.into() + } +} + +impl From for GetTablesBuilder { + fn from(value: CommandGetTables) -> Self { + Self::new( + value.catalog, + value.db_schema_filter_pattern, + value.table_name_filter_pattern, + value.table_types, + value.include_schema, + ) + } +} + +impl GetTablesBuilder { + /// Create a new instance of [`GetTablesBuilder`] + /// + /// # Parameters + /// + /// - `catalog`: Specifies the Catalog to search for the tables. + /// - An empty string retrieves those without a catalog. + /// - If omitted the catalog name is not used to narrow the search. + /// - `db_schema_filter_pattern`: Specifies a filter pattern for schemas to search for. + /// When no pattern is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + /// - `table_name_filter_pattern`: Specifies a filter pattern for tables to search for. + /// When no pattern is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. 
+ /// - "_" means to match any one character. + /// - `table_types`: Specifies a filter of table types which must match. + /// An empty Vec matches all table types. + /// - `include_schema`: Specifies if the Arrow schema should be returned for found tables. + /// + /// [`CommandGetTables`]: crate::sql::CommandGetTables + pub fn new( + catalog: Option>, + db_schema_filter_pattern: Option>, + table_name_filter_pattern: Option>, + table_types: impl IntoIterator>, + include_schema: bool, + ) -> Self { + let table_schema = if include_schema { + Some(BinaryBuilder::new()) + } else { + None + }; + Self { + catalog_filter: catalog.map(|s| s.into()), + table_types_filter: table_types.into_iter().map(|tt| tt.into()).collect(), + db_schema_filter_pattern: db_schema_filter_pattern.map(|s| s.into()), + table_name_filter_pattern: table_name_filter_pattern.map(|t| t.into()), + catalog_name: StringBuilder::new(), + db_schema_name: StringBuilder::new(), + table_name: StringBuilder::new(), + table_type: StringBuilder::new(), + table_schema, + } + } + + /// Append a row + pub fn append( + &mut self, + catalog_name: impl AsRef, + schema_name: impl AsRef, + table_name: impl AsRef, + table_type: impl AsRef, + table_schema: &Schema, + ) -> Result<()> { + self.catalog_name.append_value(catalog_name); + self.db_schema_name.append_value(schema_name); + self.table_name.append_value(table_name); + self.table_type.append_value(table_type); + if let Some(self_table_schema) = self.table_schema.as_mut() { + let options = IpcWriteOptions::default(); + // encode the schema into the correct form + let message: std::result::Result = + SchemaAsIpc::new(table_schema, &options).try_into(); + let IpcMessage(schema) = message?; + self_table_schema.append_value(schema); + } + + Ok(()) + } + + /// builds a `RecordBatch` for `CommandGetTables` + pub fn build(self) -> Result { + let schema = self.schema(); + let Self { + catalog_filter, + table_types_filter, + db_schema_filter_pattern, + 
table_name_filter_pattern, + + mut catalog_name, + mut db_schema_name, + mut table_name, + mut table_type, + table_schema, + } = self; + + // Make the arrays + let catalog_name = catalog_name.finish(); + let db_schema_name = db_schema_name.finish(); + let table_name = table_name.finish(); + let table_type = table_type.finish(); + let table_schema = table_schema.map(|mut table_schema| table_schema.finish()); + + // apply any filters, getting a BooleanArray that represents + // the rows that passed the filter + let mut filters = vec![]; + + if let Some(catalog_filter_name) = catalog_filter { + let scalar = StringArray::new_scalar(catalog_filter_name); + filters.push(eq(&catalog_name, &scalar)?); + } + + let tt_filter = table_types_filter + .into_iter() + .map(|tt| eq(&table_type, &StringArray::new_scalar(tt))) + .collect::, _>>()? + .into_iter() + // We know the arrays are of same length as they are produced from the same root array + .reduce(|filter, arr| or(&filter, &arr).unwrap()); + if let Some(filter) = tt_filter { + filters.push(filter); + } + + if let Some(db_schema_filter_pattern) = db_schema_filter_pattern { + // use like kernel to get wildcard matching + let scalar = StringArray::new_scalar(db_schema_filter_pattern); + filters.push(like(&db_schema_name, &scalar)?) + } + + if let Some(table_name_filter_pattern) = table_name_filter_pattern { + // use like kernel to get wildcard matching + let scalar = StringArray::new_scalar(table_name_filter_pattern); + filters.push(like(&table_name, &scalar)?) 
+ } + + let batch = if let Some(table_schema) = table_schema { + RecordBatch::try_new( + schema, + vec![ + Arc::new(catalog_name) as ArrayRef, + Arc::new(db_schema_name) as ArrayRef, + Arc::new(table_name) as ArrayRef, + Arc::new(table_type) as ArrayRef, + Arc::new(table_schema) as ArrayRef, + ], + ) + } else { + RecordBatch::try_new( + // schema is different if table_schema is none + schema, + vec![ + Arc::new(catalog_name) as ArrayRef, + Arc::new(db_schema_name) as ArrayRef, + Arc::new(table_name) as ArrayRef, + Arc::new(table_type) as ArrayRef, + ], + ) + }?; + + // `AND` any filters together + let mut total_filter = None; + while let Some(filter) = filters.pop() { + let new_filter = match total_filter { + Some(total_filter) => and(&total_filter, &filter)?, + None => filter, + }; + total_filter = Some(new_filter); + } + + // Apply the filters if needed + let filtered_batch = if let Some(total_filter) = total_filter { + filter_record_batch(&batch, &total_filter)? + } else { + batch + }; + + // Order filtered results by catalog_name, then db_schema_name, then table_name, then table_type + // https://github.com/apache/arrow/blob/130f9e981aa98c25de5f5bfe55185db270cec313/format/FlightSql.proto#LL1202C1-L1202C1 + let sort_cols = filtered_batch.project(&[0, 1, 2, 3])?; + let indices = lexsort_to_indices(sort_cols.columns()); + let columns = filtered_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(filtered_batch.schema(), columns)?) 
+ } + + /// Return the schema of the RecordBatch that will be returned from [`CommandGetTables`] + /// + /// Note the schema differs based on the values of `include_schema + /// + /// [`CommandGetTables`]: crate::sql::CommandGetTables + pub fn schema(&self) -> SchemaRef { + get_tables_schema(self.include_schema()) + } + + /// Should the "schema" column be included + pub fn include_schema(&self) -> bool { + self.table_schema.is_some() + } +} + +fn get_tables_schema(include_schema: bool) -> SchemaRef { + if include_schema { + Arc::clone(&GET_TABLES_SCHEMA_WITH_TABLE_SCHEMA) + } else { + Arc::clone(&GET_TABLES_SCHEMA_WITHOUT_TABLE_SCHEMA) + } +} + +/// The schema for GetTables without `table_schema` column +static GET_TABLES_SCHEMA_WITHOUT_TABLE_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, false), + Field::new("db_schema_name", DataType::Utf8, false), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + ])) +}); + +/// The schema for GetTables with `table_schema` column +static GET_TABLES_SCHEMA_WITH_TABLE_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, false), + Field::new("db_schema_name", DataType::Utf8, false), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + Field::new("table_schema", DataType::Binary, false), + ])) +}); + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{StringArray, UInt32Array}; + + fn get_ref_batch() -> RecordBatch { + RecordBatch::try_new( + get_tables_schema(false), + vec![ + Arc::new(StringArray::from(vec![ + "a_catalog", + "a_catalog", + "a_catalog", + "a_catalog", + "b_catalog", + "b_catalog", + "b_catalog", + "b_catalog", + ])) as ArrayRef, + Arc::new(StringArray::from(vec![ + "a_schema", "a_schema", "b_schema", "b_schema", "a_schema", "a_schema", + "b_schema", "b_schema", + ])) as ArrayRef, + 
Arc::new(StringArray::from(vec![ + "a_table", "b_table", "a_table", "b_table", "a_table", "a_table", "b_table", + "b_table", + ])) as ArrayRef, + Arc::new(StringArray::from(vec![ + "TABLE", "TABLE", "TABLE", "TABLE", "TABLE", "VIEW", "TABLE", "VIEW", + ])) as ArrayRef, + ], + ) + .unwrap() + } + + fn get_ref_builder( + catalog: Option<&str>, + db_schema_filter_pattern: Option<&str>, + table_name_filter_pattern: Option<&str>, + table_types: Vec<&str>, + include_schema: bool, + ) -> GetTablesBuilder { + let dummy_schema = Schema::empty(); + let tables = [ + ("a_catalog", "a_schema", "a_table", "TABLE"), + ("a_catalog", "a_schema", "b_table", "TABLE"), + ("a_catalog", "b_schema", "a_table", "TABLE"), + ("a_catalog", "b_schema", "b_table", "TABLE"), + ("b_catalog", "a_schema", "a_table", "TABLE"), + ("b_catalog", "a_schema", "a_table", "VIEW"), + ("b_catalog", "b_schema", "b_table", "TABLE"), + ("b_catalog", "b_schema", "b_table", "VIEW"), + ]; + let mut builder = GetTablesBuilder::new( + catalog, + db_schema_filter_pattern, + table_name_filter_pattern, + table_types, + include_schema, + ); + for (catalog_name, schema_name, table_name, table_type) in tables { + builder + .append( + catalog_name, + schema_name, + table_name, + table_type, + &dummy_schema, + ) + .unwrap(); + } + builder + } + + #[test] + fn test_tables_are_filtered() { + let ref_batch = get_ref_batch(); + + let builder = get_ref_builder(None, None, None, Vec::new(), false); + let table_batch = builder.build().unwrap(); + assert_eq!(table_batch, ref_batch); + + let builder = get_ref_builder(None, Some("a%"), Some("a%"), Vec::new(), false); + let table_batch = builder.build().unwrap(); + let indices = UInt32Array::from(vec![0, 4, 5]); + let ref_filtered = RecordBatch::try_new( + get_tables_schema(false), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + assert_eq!(table_batch, ref_filtered); + + let builder = 
get_ref_builder(Some("a_catalog"), None, None, Vec::new(), false); + let table_batch = builder.build().unwrap(); + let indices = UInt32Array::from(vec![0, 1, 2, 3]); + let ref_filtered = RecordBatch::try_new( + get_tables_schema(false), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + assert_eq!(table_batch, ref_filtered); + + let builder = get_ref_builder(None, None, None, vec!["VIEW"], false); + let table_batch = builder.build().unwrap(); + let indices = UInt32Array::from(vec![5, 7]); + let ref_filtered = RecordBatch::try_new( + get_tables_schema(false), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + assert_eq!(table_batch, ref_filtered); + } + + #[test] + fn test_tables_are_sorted() { + let ref_batch = get_ref_batch(); + let dummy_schema = Schema::empty(); + + let tables = [ + ("b_catalog", "a_schema", "a_table", "TABLE"), + ("b_catalog", "b_schema", "b_table", "TABLE"), + ("b_catalog", "b_schema", "b_table", "VIEW"), + ("b_catalog", "a_schema", "a_table", "VIEW"), + ("a_catalog", "a_schema", "a_table", "TABLE"), + ("a_catalog", "b_schema", "a_table", "TABLE"), + ("a_catalog", "b_schema", "b_table", "TABLE"), + ("a_catalog", "a_schema", "b_table", "TABLE"), + ]; + let mut builder = GetTablesBuilder::new( + None::, + None::, + None::, + None::, + false, + ); + for (catalog_name, schema_name, table_name, table_type) in tables { + builder + .append( + catalog_name, + schema_name, + table_name, + table_type, + &dummy_schema, + ) + .unwrap(); + } + let table_batch = builder.build().unwrap(); + assert_eq!(table_batch, ref_batch); + } +} diff --git a/arrow-flight/src/sql/metadata/xdbc_info.rs b/arrow-flight/src/sql/metadata/xdbc_info.rs new file mode 100644 index 000000000000..2e635d3037bc --- /dev/null +++ b/arrow-flight/src/sql/metadata/xdbc_info.rs @@ -0,0 +1,428 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Helpers for [`CommandGetXdbcTypeInfo`] metadata requests. +//! +//! - [`XdbcTypeInfo`] - a typed struct that holds the xdbc info corresponding to expected schema. +//! - [`XdbcTypeInfoDataBuilder`] - a builder for collecting type infos +//! and building a conformant `RecordBatch`. +//! - [`XdbcTypeInfoData`] - a helper type wrapping a `RecordBatch` +//! used for storing xdbc server metadata. +//! - [`GetXdbcTypeInfoBuilder`] - a builder for constructing [`CommandGetXdbcTypeInfo`] responses. +//! +use std::sync::Arc; + +use arrow_array::builder::{BooleanBuilder, Int32Builder, ListBuilder, StringBuilder}; +use arrow_array::{ArrayRef, Int32Array, ListArray, RecordBatch, Scalar}; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::filter::filter_record_batch; +use arrow_select::take::take; +use once_cell::sync::Lazy; + +use super::lexsort_to_indices; +use crate::error::*; +use crate::sql::{CommandGetXdbcTypeInfo, Nullable, Searchable, XdbcDataType, XdbcDatetimeSubcode}; + +/// Data structure representing type information for xdbc types.
+#[derive(Debug, Clone, Default)] +pub struct XdbcTypeInfo { + pub type_name: String, + pub data_type: XdbcDataType, + pub column_size: Option, + pub literal_prefix: Option, + pub literal_suffix: Option, + pub create_params: Option>, + pub nullable: Nullable, + pub case_sensitive: bool, + pub searchable: Searchable, + pub unsigned_attribute: Option, + pub fixed_prec_scale: bool, + pub auto_increment: Option, + pub local_type_name: Option, + pub minimum_scale: Option, + pub maximum_scale: Option, + pub sql_data_type: XdbcDataType, + pub datetime_subcode: Option, + pub num_prec_radix: Option, + pub interval_precision: Option, +} + +/// Helper to create [`CommandGetXdbcTypeInfo`] responses. +/// +/// [`CommandGetXdbcTypeInfo`] are metadata requests used by a Flight SQL +/// server to communicate supported capabilities to Flight SQL clients. +/// +/// Servers constuct - usually static - [`XdbcTypeInfoData`] via the [`XdbcTypeInfoDataBuilder`], +/// and build responses using [`CommandGetXdbcTypeInfo::into_builder`]. +pub struct XdbcTypeInfoData { + batch: RecordBatch, +} + +impl XdbcTypeInfoData { + /// Return the raw (not encoded) RecordBatch that will be returned + /// from [`CommandGetXdbcTypeInfo`] + pub fn record_batch(&self, data_type: impl Into>) -> Result { + if let Some(dt) = data_type.into() { + let scalar = Int32Array::from(vec![dt]); + let filter = eq(self.batch.column(1), &Scalar::new(&scalar))?; + Ok(filter_record_batch(&self.batch, &filter)?) + } else { + Ok(self.batch.clone()) + } + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetXdbcTypeInfo`] + pub fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +pub struct XdbcTypeInfoDataBuilder { + infos: Vec, +} + +impl Default for XdbcTypeInfoDataBuilder { + fn default() -> Self { + Self::new() + } +} + +/// A builder for [`XdbcTypeInfoData`] which is used to create [`CommandGetXdbcTypeInfo`] responses. 
+/// +/// # Example +/// ``` +/// use arrow_flight::sql::{Nullable, Searchable, XdbcDataType}; +/// use arrow_flight::sql::metadata::{XdbcTypeInfo, XdbcTypeInfoDataBuilder}; +/// // Create the list of metadata describing the server. Since this would not change at +/// // runtime, using once_cell::Lazy or similar patterns to construct the list is a common approach. +/// let mut builder = XdbcTypeInfoDataBuilder::new(); +/// builder.append(XdbcTypeInfo { +/// type_name: "INTEGER".into(), +/// data_type: XdbcDataType::XdbcInteger, +/// column_size: Some(32), +/// literal_prefix: None, +/// literal_suffix: None, +/// create_params: None, +/// nullable: Nullable::NullabilityNullable, +/// case_sensitive: false, +/// searchable: Searchable::Full, +/// unsigned_attribute: Some(false), +/// fixed_prec_scale: false, +/// auto_increment: Some(false), +/// local_type_name: Some("INTEGER".into()), +/// minimum_scale: None, +/// maximum_scale: None, +/// sql_data_type: XdbcDataType::XdbcInteger, +/// datetime_subcode: None, +/// num_prec_radix: Some(2), +/// interval_precision: None, +/// }); +/// let info_list = builder.build().unwrap(); +/// +/// // to access the underlying record batch +/// let batch = info_list.record_batch(None); +/// ``` +impl XdbcTypeInfoDataBuilder { + /// Create a new instance of [`XdbcTypeInfoDataBuilder`]. + pub fn new() -> Self { + Self { infos: Vec::new() } + } + + /// Append a new row + pub fn append(&mut self, info: XdbcTypeInfo) { + self.infos.push(info); + } + + /// Create helper structure for handling xdbc metadata requests.
+ pub fn build(self) -> Result { + let mut type_name_builder = StringBuilder::new(); + let mut data_type_builder = Int32Builder::new(); + let mut column_size_builder = Int32Builder::new(); + let mut literal_prefix_builder = StringBuilder::new(); + let mut literal_suffix_builder = StringBuilder::new(); + let mut create_params_builder = ListBuilder::new(StringBuilder::new()); + let mut nullable_builder = Int32Builder::new(); + let mut case_sensitive_builder = BooleanBuilder::new(); + let mut searchable_builder = Int32Builder::new(); + let mut unsigned_attribute_builder = BooleanBuilder::new(); + let mut fixed_prec_scale_builder = BooleanBuilder::new(); + let mut auto_increment_builder = BooleanBuilder::new(); + let mut local_type_name_builder = StringBuilder::new(); + let mut minimum_scale_builder = Int32Builder::new(); + let mut maximum_scale_builder = Int32Builder::new(); + let mut sql_data_type_builder = Int32Builder::new(); + let mut datetime_subcode_builder = Int32Builder::new(); + let mut num_prec_radix_builder = Int32Builder::new(); + let mut interval_precision_builder = Int32Builder::new(); + + self.infos.into_iter().for_each(|info| { + type_name_builder.append_value(info.type_name); + data_type_builder.append_value(info.data_type as i32); + column_size_builder.append_option(info.column_size); + literal_prefix_builder.append_option(info.literal_prefix); + literal_suffix_builder.append_option(info.literal_suffix); + if let Some(params) = info.create_params { + if !params.is_empty() { + for param in params { + create_params_builder.values().append_value(param); + } + create_params_builder.append(true); + } else { + create_params_builder.append_null(); + } + } else { + create_params_builder.append_null(); + } + nullable_builder.append_value(info.nullable as i32); + case_sensitive_builder.append_value(info.case_sensitive); + searchable_builder.append_value(info.searchable as i32); + unsigned_attribute_builder.append_option(info.unsigned_attribute); + 
fixed_prec_scale_builder.append_value(info.fixed_prec_scale); + auto_increment_builder.append_option(info.auto_increment); + local_type_name_builder.append_option(info.local_type_name); + minimum_scale_builder.append_option(info.minimum_scale); + maximum_scale_builder.append_option(info.maximum_scale); + sql_data_type_builder.append_value(info.sql_data_type as i32); + datetime_subcode_builder.append_option(info.datetime_subcode.map(|code| code as i32)); + num_prec_radix_builder.append_option(info.num_prec_radix); + interval_precision_builder.append_option(info.interval_precision); + }); + + let type_name = Arc::new(type_name_builder.finish()); + let data_type = Arc::new(data_type_builder.finish()); + let column_size = Arc::new(column_size_builder.finish()); + let literal_prefix = Arc::new(literal_prefix_builder.finish()); + let literal_suffix = Arc::new(literal_suffix_builder.finish()); + let (field, offsets, values, nulls) = create_params_builder.finish().into_parts(); + // Re-defined the field to be non-nullable + let new_field = Arc::new(field.as_ref().clone().with_nullable(false)); + let create_params = Arc::new(ListArray::new(new_field, offsets, values, nulls)) as ArrayRef; + let nullable = Arc::new(nullable_builder.finish()); + let case_sensitive = Arc::new(case_sensitive_builder.finish()); + let searchable = Arc::new(searchable_builder.finish()); + let unsigned_attribute = Arc::new(unsigned_attribute_builder.finish()); + let fixed_prec_scale = Arc::new(fixed_prec_scale_builder.finish()); + let auto_increment = Arc::new(auto_increment_builder.finish()); + let local_type_name = Arc::new(local_type_name_builder.finish()); + let minimum_scale = Arc::new(minimum_scale_builder.finish()); + let maximum_scale = Arc::new(maximum_scale_builder.finish()); + let sql_data_type = Arc::new(sql_data_type_builder.finish()); + let datetime_subcode = Arc::new(datetime_subcode_builder.finish()); + let num_prec_radix = Arc::new(num_prec_radix_builder.finish()); + let 
interval_precision = Arc::new(interval_precision_builder.finish()); + + let batch = RecordBatch::try_new( + Arc::clone(&GET_XDBC_INFO_SCHEMA), + vec![ + type_name, + data_type, + column_size, + literal_prefix, + literal_suffix, + create_params, + nullable, + case_sensitive, + searchable, + unsigned_attribute, + fixed_prec_scale, + auto_increment, + local_type_name, + minimum_scale, + maximum_scale, + sql_data_type, + datetime_subcode, + num_prec_radix, + interval_precision, + ], + )?; + + // Order batch by data_type and then by type_name + let sort_cols = batch.project(&[1, 0])?; + let indices = lexsort_to_indices(sort_cols.columns()); + let columns = batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(XdbcTypeInfoData { + batch: RecordBatch::try_new(batch.schema(), columns)?, + }) + } + + /// Return the [`Schema`] for a GetSchema RPC call with [`CommandGetXdbcTypeInfo`] + pub fn schema(&self) -> SchemaRef { + Arc::clone(&GET_XDBC_INFO_SCHEMA) + } +} + +/// A builder for a [`CommandGetXdbcTypeInfo`] response. 
+pub struct GetXdbcTypeInfoBuilder<'a> { + data_type: Option, + infos: &'a XdbcTypeInfoData, +} + +impl CommandGetXdbcTypeInfo { + /// Create a builder suitable for constructing a response + pub fn into_builder(self, infos: &XdbcTypeInfoData) -> GetXdbcTypeInfoBuilder { + GetXdbcTypeInfoBuilder { + data_type: self.data_type, + infos, + } + } +} + +impl GetXdbcTypeInfoBuilder<'_> { + /// Builds a `RecordBatch` with the correct schema for a [`CommandGetXdbcTypeInfo`] response + pub fn build(self) -> Result { + self.infos.record_batch(self.data_type) + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetXdbcTypeInfo`] + pub fn schema(&self) -> SchemaRef { + self.infos.schema() + } +} + +/// The schema for GetXdbcTypeInfo +static GET_XDBC_INFO_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("type_name", DataType::Utf8, false), + Field::new("data_type", DataType::Int32, false), + Field::new("column_size", DataType::Int32, true), + Field::new("literal_prefix", DataType::Utf8, true), + Field::new("literal_suffix", DataType::Utf8, true), + Field::new( + "create_params", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + Field::new("nullable", DataType::Int32, false), + Field::new("case_sensitive", DataType::Boolean, false), + Field::new("searchable", DataType::Int32, false), + Field::new("unsigned_attribute", DataType::Boolean, true), + Field::new("fixed_prec_scale", DataType::Boolean, false), + Field::new("auto_increment", DataType::Boolean, true), + Field::new("local_type_name", DataType::Utf8, true), + Field::new("minimum_scale", DataType::Int32, true), + Field::new("maximum_scale", DataType::Int32, true), + Field::new("sql_data_type", DataType::Int32, false), + Field::new("datetime_subcode", DataType::Int32, true), + Field::new("num_prec_radix", DataType::Int32, true), + Field::new("interval_precision", DataType::Int32, true), + ])) +}); + +#[cfg(test)] +mod tests { + use 
super::*; + use crate::sql::metadata::tests::assert_batches_eq; + + #[test] + fn test_create_batch() { + let mut builder = XdbcTypeInfoDataBuilder::new(); + builder.append(XdbcTypeInfo { + type_name: "VARCHAR".into(), + data_type: XdbcDataType::XdbcVarchar, + column_size: Some(i32::MAX), + literal_prefix: Some("'".into()), + literal_suffix: Some("'".into()), + create_params: Some(vec!["length".into()]), + nullable: Nullable::NullabilityNullable, + case_sensitive: true, + searchable: Searchable::Full, + unsigned_attribute: None, + fixed_prec_scale: false, + auto_increment: None, + local_type_name: Some("VARCHAR".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcVarchar, + datetime_subcode: None, + num_prec_radix: None, + interval_precision: None, + }); + builder.append(XdbcTypeInfo { + type_name: "INTEGER".into(), + data_type: XdbcDataType::XdbcInteger, + column_size: Some(32), + literal_prefix: None, + literal_suffix: None, + create_params: None, + nullable: Nullable::NullabilityNullable, + case_sensitive: false, + searchable: Searchable::Full, + unsigned_attribute: Some(false), + fixed_prec_scale: false, + auto_increment: Some(false), + local_type_name: Some("INTEGER".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcInteger, + datetime_subcode: None, + num_prec_radix: Some(2), + interval_precision: None, + }); + builder.append(XdbcTypeInfo { + type_name: "INTERVAL".into(), + data_type: XdbcDataType::XdbcInterval, + column_size: Some(i32::MAX), + literal_prefix: Some("'".into()), + literal_suffix: Some("'".into()), + create_params: None, + nullable: Nullable::NullabilityNullable, + case_sensitive: false, + searchable: Searchable::Full, + unsigned_attribute: None, + fixed_prec_scale: false, + auto_increment: None, + local_type_name: Some("INTERVAL".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcInterval, + datetime_subcode: 
Some(XdbcDatetimeSubcode::XdbcSubcodeUnknown), + num_prec_radix: None, + interval_precision: None, + }); + let infos = builder.build().unwrap(); + + let batch = infos.record_batch(None).unwrap(); + let expected = vec![ + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| type_name | data_type | column_size | literal_prefix | literal_suffix | create_params | nullable | case_sensitive | searchable | unsigned_attribute | fixed_prec_scale | auto_increment | local_type_name | minimum_scale | maximum_scale | sql_data_type | datetime_subcode | num_prec_radix | interval_precision |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| INTEGER | 4 | 32 | | | | 1 | false | 3 | false | false | false | INTEGER | | | 4 | | 2 | |", + "| INTERVAL | 10 | 2147483647 | ' | ' | | 1 | false | 3 | | false | | INTERVAL | | | 10 | 0 | | |", + "| VARCHAR | 12 | 2147483647 | ' | ' | [length] | 1 | true | 3 | | false | | VARCHAR | | | 12 | | | |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + ]; + assert_batches_eq(&[batch], &expected); + + let batch = infos.record_batch(Some(10)).unwrap(); + let expected = vec![ + 
"+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| type_name | data_type | column_size | literal_prefix | literal_suffix | create_params | nullable | case_sensitive | searchable | unsigned_attribute | fixed_prec_scale | auto_increment | local_type_name | minimum_scale | maximum_scale | sql_data_type | datetime_subcode | num_prec_radix | interval_precision |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| INTERVAL | 10 | 2147483647 | ' | ' | | 1 | false | 3 | | false | | INTERVAL | | | 10 | 0 | | |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + ]; + assert_batches_eq(&[batch], &expected); + } +} diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index cd198a1401d1..453f608d353a 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -15,17 +15,57 @@ // specific language governing permissions and limitations // under the License. -use arrow::error::{ArrowError, Result as ArrowResult}; +//! Support for execute SQL queries using [Apache Arrow] [Flight SQL]. +//! +//! [Flight SQL] is built on top of Arrow Flight RPC framework, by +//! defining specific messages, encoded using the protobuf format, +//! 
sent in the [`FlightDescriptor::cmd`] field to [`FlightService`] +//! endpoints such as [`get_flight_info`] and [`do_get`]. +//! +//! This module contains: +//! 1. [prost] generated structs for FlightSQL messages such as [`CommandStatementQuery`] +//! 2. Helpers for encoding and decoding FlightSQL messages: [`Any`] and [`Command`] +//! 3. A [`FlightSqlServiceClient`] for interacting with FlightSQL servers. +//! 4. A [`FlightSqlService`] to help build FlightSQL servers from [`FlightService`]. +//! 5. Helpers to build responses for FlightSQL metadata APIs: [`metadata`] +//! +//! [Flight SQL]: https://arrow.apache.org/docs/format/FlightSql.html +//! [Apache Arrow]: https://arrow.apache.org +//! [`FlightDescriptor::cmd`]: crate::FlightDescriptor::cmd +//! [`FlightService`]: crate::flight_service_server::FlightService +//! [`get_flight_info`]: crate::flight_service_server::FlightService::get_flight_info +//! [`do_get`]: crate::flight_service_server::FlightService::do_get +//! [`FlightSqlServiceClient`]: client::FlightSqlServiceClient +//! [`FlightSqlService`]: server::FlightSqlService +//!
[`metadata`]: crate::sql::metadata +use arrow_schema::ArrowError; +use bytes::Bytes; +use paste::paste; use prost::Message; mod gen { #![allow(clippy::all)] + #![allow(rustdoc::unportable_markdown)] include!("arrow.flight.protocol.sql.rs"); } +pub use gen::action_end_transaction_request::EndTransaction; +pub use gen::command_statement_ingest::table_definition_options::{ + TableExistsOption, TableNotExistOption, +}; +pub use gen::command_statement_ingest::TableDefinitionOptions; +pub use gen::ActionBeginSavepointRequest; +pub use gen::ActionBeginSavepointResult; +pub use gen::ActionBeginTransactionRequest; +pub use gen::ActionBeginTransactionResult; +pub use gen::ActionCancelQueryRequest; +pub use gen::ActionCancelQueryResult; pub use gen::ActionClosePreparedStatementRequest; pub use gen::ActionCreatePreparedStatementRequest; pub use gen::ActionCreatePreparedStatementResult; +pub use gen::ActionCreatePreparedSubstraitPlanRequest; +pub use gen::ActionEndSavepointRequest; +pub use gen::ActionEndTransactionRequest; pub use gen::CommandGetCatalogs; pub use gen::CommandGetCrossReference; pub use gen::CommandGetDbSchemas; @@ -35,11 +75,17 @@ pub use gen::CommandGetPrimaryKeys; pub use gen::CommandGetSqlInfo; pub use gen::CommandGetTableTypes; pub use gen::CommandGetTables; +pub use gen::CommandGetXdbcTypeInfo; pub use gen::CommandPreparedStatementQuery; pub use gen::CommandPreparedStatementUpdate; +pub use gen::CommandStatementIngest; pub use gen::CommandStatementQuery; +pub use gen::CommandStatementSubstraitPlan; pub use gen::CommandStatementUpdate; +pub use gen::DoPutPreparedStatementResult; pub use gen::DoPutUpdateResult; +pub use gen::Nullable; +pub use gen::Searchable; pub use gen::SqlInfo; pub use gen::SqlNullOrdering; pub use gen::SqlOuterJoinsSupportLevel; @@ -50,14 +96,20 @@ pub use gen::SqlSupportedPositionedCommands; pub use gen::SqlSupportedResultSetConcurrency; pub use gen::SqlSupportedResultSetType; pub use gen::SqlSupportedSubqueries; +pub use 
gen::SqlSupportedTransaction; pub use gen::SqlSupportedTransactions; pub use gen::SqlSupportedUnions; pub use gen::SqlSupportsConvert; pub use gen::SqlTransactionIsolationLevel; +pub use gen::SubstraitPlan; pub use gen::SupportedSqlGrammar; pub use gen::TicketStatementQuery; pub use gen::UpdateDeleteRules; +pub use gen::XdbcDataType; +pub use gen::XdbcDatetimeSubcode; +pub mod client; +pub mod metadata; pub mod server; /// ProstMessageExt are useful utility methods for prost::Message types @@ -65,34 +117,132 @@ pub trait ProstMessageExt: prost::Message + Default { /// type_url for this Message fn type_url() -> &'static str; - /// Convert this Message to prost_types::Any - fn as_any(&self) -> prost_types::Any; + /// Convert this Message to [`Any`] + fn as_any(&self) -> Any; +} + +/// Macro to coerce a token to an item, specifically +/// to build the `Commands` enum. +/// +/// See: +macro_rules! as_item { + ($i:item) => { + $i + }; } macro_rules! prost_message_ext { - ($($name:ty,)*) => { - $( - impl ProstMessageExt for $name { - fn type_url() -> &'static str { - concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name)) + ($($name:tt,)*) => { + paste! { + $( + const [<$name:snake:upper _TYPE_URL>]: &'static str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name)); + )* + + as_item! { + /// Helper to convert to/from protobuf [`Any`] message + /// to a specific FlightSQL command message. 
+ /// + /// # Example + /// ```rust + /// # use arrow_flight::sql::{Any, CommandStatementQuery, Command}; + /// let flightsql_message = CommandStatementQuery { + /// query: "SELECT * FROM foo".to_string(), + /// transaction_id: None, + /// }; + /// + /// // Given a packed FlightSQL Any message + /// let any_message = Any::pack(&flightsql_message).unwrap(); + /// + /// // decode it to Command: + /// match Command::try_from(any_message).unwrap() { + /// Command::CommandStatementQuery(decoded) => { + /// assert_eq!(flightsql_message, decoded); + /// } + /// _ => panic!("Unexpected decoded message"), + /// } + /// ``` + #[derive(Clone, Debug, PartialEq)] + pub enum Command { + $($name($name),)* + + /// Any message that is not any FlightSQL command. + Unknown(Any), + } + } + + impl Command { + /// Convert the command to [`Any`]. + pub fn into_any(self) -> Any { + match self { + $( + Self::$name(cmd) => cmd.as_any(), + )* + Self::Unknown(any) => any, + } + } + + /// Get the URL for the command. 
+ pub fn type_url(&self) -> &str { + match self { + $( + Self::$name(_) => [<$name:snake:upper _TYPE_URL>], + )* + Self::Unknown(any) => any.type_url.as_str(), + } } + } + + impl TryFrom for Command { + type Error = ArrowError; - fn as_any(&self) -> prost_types::Any { - prost_types::Any { - type_url: <$name>::type_url().to_string(), - value: self.encode_to_vec(), + fn try_from(any: Any) -> Result { + match any.type_url.as_str() { + $( + [<$name:snake:upper _TYPE_URL>] + => { + let m: $name = Message::decode(&*any.value).map_err(|err| { + ArrowError::ParseError(format!("Unable to decode Any value: {err}")) + })?; + Ok(Self::$name(m)) + } + )* + _ => Ok(Self::Unknown(any)), } } } - )* + + $( + impl ProstMessageExt for $name { + fn type_url() -> &'static str { + [<$name:snake:upper _TYPE_URL>] + } + + fn as_any(&self) -> Any { + Any { + type_url: <$name>::type_url().to_string(), + value: self.encode_to_vec().into(), + } + } + } + )* + } }; } // Implement ProstMessageExt for all structs defined in FlightSql.proto prost_message_ext!( + ActionBeginSavepointRequest, + ActionBeginSavepointResult, + ActionBeginTransactionRequest, + ActionBeginTransactionResult, + ActionCancelQueryRequest, + ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, + ActionCreatePreparedSubstraitPlanRequest, + ActionEndSavepointRequest, + ActionEndTransactionRequest, CommandGetCatalogs, CommandGetCrossReference, CommandGetDbSchemas, @@ -102,48 +252,65 @@ prost_message_ext!( CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementIngest, CommandStatementQuery, + CommandStatementSubstraitPlan, CommandStatementUpdate, + DoPutPreparedStatementResult, DoPutUpdateResult, TicketStatementQuery, ); -/// ProstAnyExt are useful utility methods for prost_types::Any -/// The API design is inspired by 
[rust-protobuf](https://github.com/stepancheg/rust-protobuf/blob/master/protobuf/src/well_known_types_util/any.rs) -pub trait ProstAnyExt { - /// Check if `Any` contains a message of given type. - fn is(&self) -> bool; - - /// Extract a message from this `Any`. - /// - /// # Returns - /// - /// * `Ok(None)` when message type mismatch - /// * `Err` when parse failed - fn unpack(&self) -> ArrowResult>; - - /// Pack any message into `prost_types::Any` value. - fn pack(message: &M) -> ArrowResult; +/// An implementation of the protobuf [`Any`] message type +/// +/// Encoded protobuf messages are not self-describing, nor contain any information +/// on the schema of the encoded payload. Consequently to decode a protobuf a client +/// must know the exact schema of the message. +/// +/// This presents a problem for loosely typed APIs, where the exact message payloads +/// are not enumerable, and therefore cannot be enumerated as variants in a [oneof]. +/// +/// One solution is [`Any`] where the encoded payload is paired with a `type_url` +/// identifying the type of encoded message, and the resulting combination encoded. +/// +/// Clients can then decode the outer [`Any`], inspect the `type_url` and if it is +/// a type they recognise, proceed to decode the embedded message `value` +/// +/// [`Any`]: https://developers.google.com/protocol-buffers/docs/proto3#any +/// [oneof]: https://developers.google.com/protocol-buffers/docs/proto3#oneof +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Any { + /// A URL/resource name that uniquely identifies the type of the serialized + /// protocol buffer message. This string must contain at least + /// one "/" character. The last segment of the URL's path must represent + /// the fully qualified name of the type (as in + /// `path/google.protobuf.Duration`). The name should be in a canonical form + /// (e.g., leading "." is not accepted). 
+ #[prost(string, tag = "1")] + pub type_url: String, + /// Must be a valid serialized protocol buffer of the above specified type. + #[prost(bytes = "bytes", tag = "2")] + pub value: Bytes, } -impl ProstAnyExt for prost_types::Any { - fn is(&self) -> bool { +impl Any { + pub fn is(&self) -> bool { M::type_url() == self.type_url } - fn unpack(&self) -> ArrowResult> { + pub fn unpack(&self) -> Result, ArrowError> { if !self.is::() { return Ok(None); } - let m = prost::Message::decode(&*self.value).map_err(|err| { - ArrowError::ParseError(format!("Unable to decode Any value: {}", err)) - })?; + let m = Message::decode(&*self.value) + .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?; Ok(Some(m)) } - fn pack(message: &M) -> ArrowResult { + pub fn pack(message: &M) -> Result { Ok(message.as_any()) } } @@ -165,14 +332,38 @@ mod tests { } #[test] - fn test_prost_any_pack_unpack() -> ArrowResult<()> { + fn test_prost_any_pack_unpack() { let query = CommandStatementQuery { query: "select 1".to_string(), + transaction_id: None, }; - let any = prost_types::Any::pack(&query)?; + let any = Any::pack(&query).unwrap(); assert!(any.is::()); - let unpack_query: CommandStatementQuery = any.unpack()?.unwrap(); + let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap(); assert_eq!(query, unpack_query); - Ok(()) + } + + #[test] + fn test_command() { + let query = CommandStatementQuery { + query: "select 1".to_string(), + transaction_id: None, + }; + let any = Any::pack(&query).unwrap(); + let cmd: Command = any.try_into().unwrap(); + + assert!(matches!(cmd, Command::CommandStatementQuery(_))); + assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL); + + // Unknown variant + + let any = Any { + type_url: "fake_url".to_string(), + value: Default::default(), + }; + + let cmd: Command = any.try_into().unwrap(); + assert!(matches!(cmd, Command::Unknown(_))); + assert_eq!(cmd.type_url(), "fake_url"); } } diff --git 
a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index f3208d376497..37b2885b5aff 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -15,35 +15,45 @@ // specific language governing permissions and limitations // under the License. +//! Helper trait [`FlightSqlService`] for implementing a [`FlightService`] that implements FlightSQL. + use std::pin::Pin; -use futures::Stream; +use futures::{stream::Peekable, Stream, StreamExt}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; use super::{ - super::{ - flight_service_server::FlightService, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, - PutResult, SchemaResult, Ticket, - }, + ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, + ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference, - CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, + ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, + ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementUpdate, DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo, - TicketStatementQuery, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan, + CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, + SqlInfo, 
TicketStatementQuery, +}; +use crate::{ + flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty, + FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, + SchemaResult, Ticket, }; -static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement"; -static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement"; +pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement"; +pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement"; +pub(crate) static CREATE_PREPARED_SUBSTRAIT_PLAN: &str = "CreatePreparedSubstraitPlan"; +pub(crate) static BEGIN_TRANSACTION: &str = "BeginTransaction"; +pub(crate) static END_TRANSACTION: &str = "EndTransaction"; +pub(crate) static BEGIN_SAVEPOINT: &str = "BeginSavepoint"; +pub(crate) static END_SAVEPOINT: &str = "EndSavepoint"; +pub(crate) static CANCEL_QUERY: &str = "CancelQuery"; /// Implements FlightSqlService to handle the flight sql protocol #[tonic::async_trait] -pub trait FlightSqlService: - std::marker::Sync + std::marker::Send + std::marker::Sized + 'static -{ +pub trait FlightSqlService: Sync + Send + Sized + 'static { /// When impl FlightSqlService, you can always set FlightService to Self type FlightService: FlightService; @@ -65,7 +75,7 @@ pub trait FlightSqlService: async fn do_get_fallback( &self, _request: Request, - message: prost_types::Any, + message: Any, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented(format!( "do_get: The defined request is invalid: {}", @@ -76,197 +86,480 @@ pub trait FlightSqlService: /// Get a FlightInfo for executing a SQL query. 
async fn get_flight_info_statement( &self, - query: CommandStatementQuery, - request: Request, - ) -> Result, Status>; + _query: CommandStatementQuery, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_statement has no default implementation", + )) + } + + /// Get a FlightInfo for executing a substrait plan. + async fn get_flight_info_substrait_plan( + &self, + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_substrait_plan has no default implementation", + )) + } /// Get a FlightInfo for executing an already created prepared statement. async fn get_flight_info_prepared_statement( &self, - query: CommandPreparedStatementQuery, - request: Request, - ) -> Result, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_prepared_statement has no default implementation", + )) + } /// Get a FlightInfo for listing catalogs. async fn get_flight_info_catalogs( &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result, Status>; + _query: CommandGetCatalogs, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_catalogs has no default implementation", + )) + } /// Get a FlightInfo for listing schemas. async fn get_flight_info_schemas( &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result, Status>; + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_schemas has no default implementation", + )) + } /// Get a FlightInfo for listing tables. 
async fn get_flight_info_tables( &self, - query: CommandGetTables, - request: Request, - ) -> Result, Status>; + _query: CommandGetTables, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_tables has no default implementation", + )) + } /// Get a FlightInfo to extract information about the table types. async fn get_flight_info_table_types( &self, - query: CommandGetTableTypes, - request: Request, - ) -> Result, Status>; + _query: CommandGetTableTypes, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_table_types has no default implementation", + )) + } /// Get a FlightInfo for retrieving other information (See SqlInfo). async fn get_flight_info_sql_info( &self, - query: CommandGetSqlInfo, - request: Request, - ) -> Result, Status>; + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_sql_info has no default implementation", + )) + } /// Get a FlightInfo to extract information about primary and foreign keys. async fn get_flight_info_primary_keys( &self, - query: CommandGetPrimaryKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_primary_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about exported keys. async fn get_flight_info_exported_keys( &self, - query: CommandGetExportedKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_exported_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about imported keys. 
async fn get_flight_info_imported_keys( &self, - query: CommandGetImportedKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_imported_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about cross reference. async fn get_flight_info_cross_reference( &self, - query: CommandGetCrossReference, - request: Request, - ) -> Result, Status>; + _query: CommandGetCrossReference, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_cross_reference has no default implementation", + )) + } + + /// Get a FlightInfo to extract information about the supported XDBC types. + async fn get_flight_info_xdbc_type_info( + &self, + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_xdbc_type_info has no default implementation", + )) + } + + /// Implementors may override to handle additional calls to get_flight_info() + async fn get_flight_info_fallback( + &self, + cmd: Command, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented(format!( + "get_flight_info: The defined request is invalid: {}", + cmd.type_url() + ))) + } // do_get /// Get a FlightDataStream containing the query results. async fn do_get_statement( &self, - ticket: TicketStatementQuery, - request: Request, - ) -> Result::DoGetStream>, Status>; + _ticket: TicketStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_statement has no default implementation", + )) + } /// Get a FlightDataStream containing the prepared statement query results. 
async fn do_get_prepared_statement( &self, - query: CommandPreparedStatementQuery, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_prepared_statement has no default implementation", + )) + } /// Get a FlightDataStream containing the list of catalogs. async fn do_get_catalogs( &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetCatalogs, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_catalogs has no default implementation", + )) + } /// Get a FlightDataStream containing the list of schemas. async fn do_get_schemas( &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_schemas has no default implementation", + )) + } /// Get a FlightDataStream containing the list of tables. async fn do_get_tables( &self, - query: CommandGetTables, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetTables, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_tables has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the table types. async fn do_get_table_types( &self, - query: CommandGetTableTypes, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetTableTypes, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_table_types has no default implementation", + )) + } /// Get a FlightDataStream containing the list of SqlInfo results. 
async fn do_get_sql_info( &self, - query: CommandGetSqlInfo, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_sql_info has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the primary and foreign keys. async fn do_get_primary_keys( &self, - query: CommandGetPrimaryKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_primary_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the exported keys. async fn do_get_exported_keys( &self, - query: CommandGetExportedKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_exported_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the imported keys. async fn do_get_imported_keys( &self, - query: CommandGetImportedKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_imported_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the cross reference. async fn do_get_cross_reference( &self, - query: CommandGetCrossReference, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetCrossReference, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_cross_reference has no default implementation", + )) + } + + /// Get a FlightDataStream containing the data related to the supported XDBC types. 
+ async fn do_get_xdbc_type_info( + &self, + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_xdbc_type_info has no default implementation", + )) + } // do_put + /// Implementors may override to handle additional calls to do_put() + async fn do_put_fallback( + &self, + _request: Request, + message: Any, + ) -> Result::DoPutStream>, Status> { + Err(Status::unimplemented(format!( + "do_put: The defined request is invalid: {}", + message.type_url + ))) + } + /// Execute an update SQL statement. async fn do_put_statement_update( &self, - ticket: CommandStatementUpdate, - request: Request>, - ) -> Result; + _ticket: CommandStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_statement_update has no default implementation", + )) + } + + /// Execute a bulk ingestion. + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_statement_ingest has no default implementation", + )) + } /// Bind parameters to given prepared statement. + /// + /// Returns an opaque handle that the client should pass + /// back to the server during subsequent requests with this + /// prepared statement. async fn do_put_prepared_statement_query( &self, - query: CommandPreparedStatementQuery, - request: Request>, - ) -> Result::DoPutStream>, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_prepared_statement_query has no default implementation", + )) + } /// Execute an update SQL prepared statement. 
async fn do_put_prepared_statement_update( &self, - query: CommandPreparedStatementUpdate, - request: Request>, - ) -> Result; + _query: CommandPreparedStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_prepared_statement_update has no default implementation", + )) + } + + /// Execute a substrait plan + async fn do_put_substrait_plan( + &self, + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_substrait_plan has no default implementation", + )) + } // do_action + /// Implementors may override to handle additional calls to do_action() + async fn do_action_fallback( + &self, + request: Request, + ) -> Result::DoActionStream>, Status> { + Err(Status::invalid_argument(format!( + "do_action: The defined request is invalid: {:?}", + request.get_ref().r#type + ))) + } + + /// Add custom actions to list_actions() result + async fn list_custom_actions(&self) -> Option>> { + None + } + /// Create a prepared statement from given SQL statement. async fn do_action_create_prepared_statement( &self, - query: ActionCreatePreparedStatementRequest, - request: Request, - ) -> Result; + _query: ActionCreatePreparedStatementRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_create_prepared_statement has no default implementation", + )) + } /// Close a prepared statement. async fn do_action_close_prepared_statement( &self, - query: ActionClosePreparedStatementRequest, - request: Request, - ); + _query: ActionClosePreparedStatementRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_close_prepared_statement has no default implementation", + )) + } + + /// Create a prepared substrait plan. 
+ async fn do_action_create_prepared_substrait_plan( + &self, + _query: ActionCreatePreparedSubstraitPlanRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_create_prepared_substrait_plan has no default implementation", + )) + } + + /// Begin a transaction + async fn do_action_begin_transaction( + &self, + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_begin_transaction has no default implementation", + )) + } + + /// End a transaction + async fn do_action_end_transaction( + &self, + _query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_end_transaction has no default implementation", + )) + } + + /// Begin a savepoint + async fn do_action_begin_savepoint( + &self, + _query: ActionBeginSavepointRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_begin_savepoint has no default implementation", + )) + } + + /// End a savepoint + async fn do_action_end_savepoint( + &self, + _query: ActionEndSavepointRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_end_savepoint has no default implementation", + )) + } + + /// Cancel a query + async fn do_action_cancel_query( + &self, + _query: ActionCancelQueryRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_cancel_query has no default implementation", + )) + } + + /// do_exchange + + /// Implementors may override to handle additional calls to do_exchange() + async fn do_exchange_fallback( + &self, + _request: Request>, + ) -> Result::DoExchangeStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } /// Register a new SqlInfo result, making it available when calling GetSqlInfo. 
async fn register_sql_info(&self, id: i32, result: &SqlInfo); @@ -276,19 +569,16 @@ pub trait FlightSqlService: #[tonic::async_trait] impl FlightService for T where - T: FlightSqlService + std::marker::Send, + T: FlightSqlService + Send, { type HandshakeStream = Pin> + Send + 'static>>; type ListFlightsStream = Pin> + Send + 'static>>; - type DoGetStream = - Pin> + Send + 'static>>; - type DoPutStream = - Pin> + Send + 'static>>; - type DoActionStream = Pin< - Box> + Send + 'static>, - >; + type DoGetStream = Pin> + Send + 'static>>; + type DoPutStream = Pin> + Send + 'static>>; + type DoActionStream = + Pin> + Send + 'static>>; type ListActionsStream = Pin> + Send + 'static>>; type DoExchangeStream = @@ -313,93 +603,56 @@ where &self, request: Request, ) -> Result, Status> { - let message: prost_types::Any = - Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; + let message = Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_statement(token, request).await; - } - if message.is::() { - let handle = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self - .get_flight_info_prepared_statement(handle, request) - .await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_catalogs(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_schemas(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? 
- .expect("unreachable"); - return self.get_flight_info_tables(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_table_types(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_sql_info(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_primary_keys(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_exported_keys(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_imported_keys(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_cross_reference(token, request).await; + match Command::try_from(message).map_err(arrow_error_to_status)? 
{ + Command::CommandStatementQuery(token) => { + self.get_flight_info_statement(token, request).await + } + Command::CommandPreparedStatementQuery(handle) => { + self.get_flight_info_prepared_statement(handle, request) + .await + } + Command::CommandStatementSubstraitPlan(handle) => { + self.get_flight_info_substrait_plan(handle, request).await + } + Command::CommandGetCatalogs(token) => { + self.get_flight_info_catalogs(token, request).await + } + Command::CommandGetDbSchemas(token) => { + return self.get_flight_info_schemas(token, request).await + } + Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await, + Command::CommandGetTableTypes(token) => { + self.get_flight_info_table_types(token, request).await + } + Command::CommandGetSqlInfo(token) => { + self.get_flight_info_sql_info(token, request).await + } + Command::CommandGetPrimaryKeys(token) => { + self.get_flight_info_primary_keys(token, request).await + } + Command::CommandGetExportedKeys(token) => { + self.get_flight_info_exported_keys(token, request).await + } + Command::CommandGetImportedKeys(token) => { + self.get_flight_info_imported_keys(token, request).await + } + Command::CommandGetCrossReference(token) => { + self.get_flight_info_cross_reference(token, request).await + } + Command::CommandGetXdbcTypeInfo(token) => { + self.get_flight_info_xdbc_type_info(token, request).await + } + cmd => self.get_flight_info_fallback(cmd, request).await, } + } - Err(Status::unimplemented(format!( - "get_flight_info: The defined request is invalid: {}", - message.type_url - ))) + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) } async fn get_schema( @@ -413,98 +666,101 @@ where &self, request: Request, ) -> Result, Status> { - let msg: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) - .map_err(decode_error_to_status)?; + let msg: Any = + 
Message::decode(&*request.get_ref().ticket).map_err(decode_error_to_status)?; - fn unpack(msg: prost_types::Any) -> Result { - msg.unpack() - .map_err(arrow_error_to_status)? - .ok_or_else(|| Status::internal("Expected a command, but found none.")) - } - - if msg.is::() { - return self.do_get_statement(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_prepared_statement(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_catalogs(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_schemas(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_tables(unpack(msg)?, request).await; + match Command::try_from(msg).map_err(arrow_error_to_status)? { + Command::TicketStatementQuery(command) => self.do_get_statement(command, request).await, + Command::CommandPreparedStatementQuery(command) => { + self.do_get_prepared_statement(command, request).await + } + Command::CommandGetCatalogs(command) => self.do_get_catalogs(command, request).await, + Command::CommandGetDbSchemas(command) => self.do_get_schemas(command, request).await, + Command::CommandGetTables(command) => self.do_get_tables(command, request).await, + Command::CommandGetTableTypes(command) => { + self.do_get_table_types(command, request).await + } + Command::CommandGetSqlInfo(command) => self.do_get_sql_info(command, request).await, + Command::CommandGetPrimaryKeys(command) => { + self.do_get_primary_keys(command, request).await + } + Command::CommandGetExportedKeys(command) => { + self.do_get_exported_keys(command, request).await + } + Command::CommandGetImportedKeys(command) => { + self.do_get_imported_keys(command, request).await + } + Command::CommandGetCrossReference(command) => { + self.do_get_cross_reference(command, request).await + } + Command::CommandGetXdbcTypeInfo(command) => { + self.do_get_xdbc_type_info(command, request).await + } + cmd => self.do_get_fallback(request, cmd.into_any()).await, } - if msg.is::() { - 
return self.do_get_table_types(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_sql_info(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_primary_keys(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_exported_keys(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_imported_keys(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_cross_reference(unpack(msg)?, request).await; - } - - self.do_get_fallback(request, msg).await } async fn do_put( &self, - mut request: Request>, + request: Request>, ) -> Result, Status> { - let cmd = request.get_mut().message().await?.unwrap(); - let message: prost_types::Any = - prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) - .map_err(decode_error_to_status)?; - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - let record_count = self.do_put_statement_update(token, request).await?; - let result = DoPutUpdateResult { record_count }; - let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { - app_metadata: result.encode_to_vec(), - })]); - return Ok(Response::new(Box::pin(output))); - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_put_prepared_statement_query(token, request).await; - } - if message.is::() { - let handle = message - .unpack() - .map_err(arrow_error_to_status)? 
- .expect("unreachable"); - let record_count = self - .do_put_prepared_statement_update(handle, request) - .await?; - let result = DoPutUpdateResult { record_count }; - let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { - app_metadata: result.encode_to_vec(), - })]); - return Ok(Response::new(Box::pin(output))); - } + // See issue #4658: https://github.com/apache/arrow-rs/issues/4658 + // To dispatch to the correct `do_put` method, we cannot discard the first message, + // as it may contain the Arrow schema, which the `do_put` handler may need. + // To allow the first message to be reused by the `do_put` handler, + // we wrap this stream in a `Peekable` one, which allows us to peek at + // the first message without discarding it. + let mut request = request.map(PeekableFlightDataStream::new); + let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?; - Err(Status::invalid_argument(format!( - "do_put: The defined request is invalid: {}", - message.type_url - ))) + let message = + Any::decode(&*cmd.flight_descriptor.unwrap().cmd).map_err(decode_error_to_status)?; + match Command::try_from(message).map_err(arrow_error_to_status)? 
{ + Command::CommandStatementUpdate(command) => { + let record_count = self.do_put_statement_update(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + Command::CommandStatementIngest(command) => { + let record_count = self.do_put_statement_ingest(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + Command::CommandPreparedStatementQuery(command) => { + let result = self + .do_put_prepared_statement_query(command, request) + .await?; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + Command::CommandStatementSubstraitPlan(command) => { + let record_count = self.do_put_substrait_plan(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + Command::CommandPreparedStatementUpdate(command) => { + let record_count = self + .do_put_prepared_statement_update(command, request) + .await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + cmd => self.do_put_fallback(request, cmd.into_any()).await, + } } async fn list_actions( @@ -525,10 +781,63 @@ where Response Message: N/A" .into(), }; - let actions: Vec> = vec![ + let create_prepared_substrait_plan_action_type = ActionType { + r#type: CREATE_PREPARED_SUBSTRAIT_PLAN.to_string(), + description: "Creates 
a reusable prepared substrait plan resource on the server.\n + Request Message: ActionCreatePreparedSubstraitPlanRequest\n + Response Message: ActionCreatePreparedStatementResult" + .into(), + }; + let begin_transaction_action_type = ActionType { + r#type: BEGIN_TRANSACTION.to_string(), + description: "Begins a transaction.\n + Request Message: ActionBeginTransactionRequest\n + Response Message: ActionBeginTransactionResult" + .into(), + }; + let end_transaction_action_type = ActionType { + r#type: END_TRANSACTION.to_string(), + description: "Ends a transaction\n + Request Message: ActionEndTransactionRequest\n + Response Message: N/A" + .into(), + }; + let begin_savepoint_action_type = ActionType { + r#type: BEGIN_SAVEPOINT.to_string(), + description: "Begins a savepoint.\n + Request Message: ActionBeginSavepointRequest\n + Response Message: ActionBeginSavepointResult" + .into(), + }; + let end_savepoint_action_type = ActionType { + r#type: END_SAVEPOINT.to_string(), + description: "Ends a savepoint\n + Request Message: ActionEndSavepointRequest\n + Response Message: N/A" + .into(), + }; + let cancel_query_action_type = ActionType { + r#type: CANCEL_QUERY.to_string(), + description: "Cancels a query\n + Request Message: ActionCancelQueryRequest\n + Response Message: ActionCancelQueryResult" + .into(), + }; + let mut actions: Vec> = vec![ Ok(create_prepared_statement_action_type), Ok(close_prepared_statement_action_type), + Ok(create_prepared_substrait_plan_action_type), + Ok(begin_transaction_action_type), + Ok(end_transaction_action_type), + Ok(begin_savepoint_action_type), + Ok(end_savepoint_action_type), + Ok(cancel_query_action_type), ]; + + if let Some(mut custom_actions) = self.list_custom_actions().await { + actions.append(&mut custom_actions); + } + let output = futures::stream::iter(actions); Ok(Response::new(Box::pin(output) as Self::ListActionsStream)) } @@ -538,8 +847,7 @@ where request: Request, ) -> Result, Status> { if request.get_ref().r#type == 
CREATE_PREPARED_STATEMENT { - let any: prost_types::Any = Message::decode(&*request.get_ref().body) - .map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionCreatePreparedStatementRequest = any .unpack() @@ -553,13 +861,11 @@ where .do_action_create_prepared_statement(cmd, request) .await?; let output = futures::stream::iter(vec![Ok(super::super::gen::Result { - body: stmt.as_any().encode_to_vec(), + body: stmt.as_any().encode_to_vec().into(), })]); return Ok(Response::new(Box::pin(output))); - } - if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT { - let any: prost_types::Any = Message::decode(&*request.get_ref().body) - .map_err(decode_error_to_status)?; + } else if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionClosePreparedStatementRequest = any .unpack() @@ -569,28 +875,191 @@ where "Unable to unpack ActionClosePreparedStatementRequest.", ) })?; - self.do_action_close_prepared_statement(cmd, request).await; + self.do_action_close_prepared_statement(cmd, request) + .await?; + return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == CREATE_PREPARED_SUBSTRAIT_PLAN { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionCreatePreparedSubstraitPlanRequest = any + .unpack() + .map_err(arrow_error_to_status)? 
+ .ok_or_else(|| { + Status::invalid_argument( + "Unable to unpack ActionCreatePreparedSubstraitPlanRequest.", + ) + })?; + self.do_action_create_prepared_substrait_plan(cmd, request) + .await?; + return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == BEGIN_TRANSACTION { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionBeginTransactionRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.") + })?; + let stmt = self.do_action_begin_transaction(cmd, request).await?; + let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + body: stmt.as_any().encode_to_vec().into(), + })]); + return Ok(Response::new(Box::pin(output))); + } else if request.get_ref().r#type == END_TRANSACTION { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionEndTransactionRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionEndTransactionRequest.") + })?; + self.do_action_end_transaction(cmd, request).await?; + return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == BEGIN_SAVEPOINT { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionBeginSavepointRequest = any + .unpack() + .map_err(arrow_error_to_status)? 
+ .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.") + })?; + let stmt = self.do_action_begin_savepoint(cmd, request).await?; + let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + body: stmt.as_any().encode_to_vec().into(), + })]); + return Ok(Response::new(Box::pin(output))); + } else if request.get_ref().r#type == END_SAVEPOINT { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionEndSavepointRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionEndSavepointRequest.") + })?; + self.do_action_end_savepoint(cmd, request).await?; return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == CANCEL_QUERY { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionCancelQueryRequest = any + .unpack() + .map_err(arrow_error_to_status)? 
+ .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionCancelQueryRequest.") + })?; + let stmt = self.do_action_cancel_query(cmd, request).await?; + let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + body: stmt.as_any().encode_to_vec().into(), + })]); + return Ok(Response::new(Box::pin(output))); } - Err(Status::invalid_argument(format!( - "do_action: The defined request is invalid: {:?}", - request.get_ref().r#type - ))) + self.do_action_fallback(request).await } async fn do_exchange( &self, - _request: Request>, + request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + self.do_exchange_fallback(request).await } } -fn decode_error_to_status(err: prost::DecodeError) -> tonic::Status { - tonic::Status::invalid_argument(format!("{:?}", err)) +fn decode_error_to_status(err: prost::DecodeError) -> Status { + Status::invalid_argument(format!("{err:?}")) +} + +fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status { + Status::internal(format!("{err:?}")) +} + +/// A wrapper around [`Streaming`] that allows "peeking" at the +/// message at the front of the stream without consuming it. +/// +/// This is needed because sometimes the first message in the stream will contain +/// a [`FlightDescriptor`] in addition to potentially any data, and the dispatch logic +/// must inspect this information. +/// +/// # Example +/// +/// [`PeekableFlightDataStream::peek`] can be used to peek at the first message without +/// discarding it; otherwise, `PeekableFlightDataStream` can be used as a regular stream. 
+/// See the following example: +/// +/// ```no_run +/// use arrow_array::RecordBatch; +/// use arrow_flight::decode::FlightRecordBatchStream; +/// use arrow_flight::FlightDescriptor; +/// use arrow_flight::error::FlightError; +/// use arrow_flight::sql::server::PeekableFlightDataStream; +/// use tonic::{Request, Status}; +/// use futures::TryStreamExt; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Status> { +/// let request: Request = todo!(); +/// let stream: PeekableFlightDataStream = request.into_inner(); +/// +/// // The first message contains the flight descriptor and the schema. +/// // Read the flight descriptor without discarding the schema: +/// let flight_descriptor: FlightDescriptor = stream +/// .peek() +/// .await +/// .cloned() +/// .transpose()? +/// .and_then(|data| data.flight_descriptor) +/// .expect("first message should contain flight descriptor"); +/// +/// // Pass the stream through a decoder +/// let batches: Vec = FlightRecordBatchStream::new_from_flight_data( +/// request.into_inner().map_err(|e| e.into()), +/// ) +/// .try_collect() +/// .await?; +/// } +/// ``` +pub struct PeekableFlightDataStream { + inner: Peekable>, } -fn arrow_error_to_status(err: arrow::error::ArrowError) -> tonic::Status { - tonic::Status::internal(format!("{:?}", err)) +impl PeekableFlightDataStream { + fn new(stream: Streaming) -> Self { + Self { + inner: stream.peekable(), + } + } + + /// Convert this stream into a `Streaming`. + /// Any messages observed through [`Self::peek`] will be lost + /// after the conversion. + pub fn into_inner(self) -> Streaming { + self.inner.into_inner() + } + + /// Convert this stream into a `Peekable>`. + /// Preserves the state of the stream, so that calls to [`Self::peek`] + /// and [`Self::poll_next`] are the same. + pub fn into_peekable(self) -> Peekable> { + self.inner + } + + /// Peek at the head of this stream without advancing it. 
+ pub async fn peek(&mut self) -> Option<&Result> { + Pin::new(&mut self.inner).peek().await + } +} + +impl Stream for PeekableFlightDataStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_next_unpin(cx) + } } diff --git a/arrow-flight/src/streams.rs b/arrow-flight/src/streams.rs new file mode 100644 index 000000000000..e532a80e1ebb --- /dev/null +++ b/arrow-flight/src/streams.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`FallibleRequestStream`] and [`FallibleTonicResponseStream`] adapters + +use crate::error::FlightError; +use futures::{ + channel::oneshot::{Receiver, Sender}, + FutureExt, Stream, StreamExt, +}; +use std::pin::Pin; +use std::task::{ready, Poll}; + +/// Wrapper around a fallible stream (one that returns errors) that makes it infallible. +/// +/// Any errors encountered in the stream are ignored are sent to the provided +/// oneshot sender. +/// +/// This can be used to accept a stream of `Result<_>` from a client API and send +/// them to the remote server that wants only the successful results. 
+pub(crate) struct FallibleRequestStream { + /// sender to notify error + sender: Option>, + /// fallible stream + fallible_stream: Pin> + Send + 'static>>, +} + +impl FallibleRequestStream { + pub(crate) fn new( + sender: Sender, + fallible_stream: Pin> + Send + 'static>>, + ) -> Self { + Self { + sender: Some(sender), + fallible_stream, + } + } +} + +impl Stream for FallibleRequestStream { + type Item = T; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pinned = self.get_mut(); + let mut request_streams = pinned.fallible_stream.as_mut(); + match ready!(request_streams.poll_next_unpin(cx)) { + Some(Ok(data)) => Poll::Ready(Some(data)), + Some(Err(e)) => { + // in theory this should only ever be called once + // as this stream should not be polled again after returning + // None, however we still check for None to be safe + if let Some(sender) = pinned.sender.take() { + // an error means the other end of the channel is not around + // to receive the error, so ignore it + let _ = sender.send(e); + } + Poll::Ready(None) + } + None => Poll::Ready(None), + } + } +} + +/// Wrapper for a tonic response stream that maps errors to `FlightError` and +/// returns errors from a oneshot channel into the stream. +/// +/// The user of this stream can inject an error into the response stream using +/// the one shot receiver. This is used to propagate errors in +/// [`FlightClient::do_put`] and [`FlightClient::do_exchange`] from the client +/// provided input stream to the response stream. +/// +/// # Error Priority +/// Error from the receiver are prioritised over the response stream. 
+/// +/// [`FlightClient::do_put`]: crate::FlightClient::do_put +/// [`FlightClient::do_exchange`]: crate::FlightClient::do_exchange +pub(crate) struct FallibleTonicResponseStream { + /// Receiver for FlightError + receiver: Receiver, + /// Tonic response stream + response_stream: Pin> + Send + 'static>>, +} + +impl FallibleTonicResponseStream { + pub(crate) fn new( + receiver: Receiver, + response_stream: Pin> + Send + 'static>>, + ) -> Self { + Self { + receiver, + response_stream, + } + } +} + +impl Stream for FallibleTonicResponseStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let pinned = self.get_mut(); + let receiver = &mut pinned.receiver; + // Prioritise sending the error that's been notified over + // polling the response_stream + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + }; + + match ready!(pinned.response_stream.poll_next_unpin(cx)) { + Some(Ok(res)) => Poll::Ready(Some(Ok(res))), + Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))), + None => Poll::Ready(None), + } + } +} diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs new file mode 100644 index 000000000000..73136379d69f --- /dev/null +++ b/arrow-flight/src/trailers.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use futures::{ready, FutureExt, Stream, StreamExt}; +use tonic::{metadata::MetadataMap, Status, Streaming}; + +/// Extract [`LazyTrailers`] from [`Streaming`] [tonic] response. +/// +/// Note that [`LazyTrailers`] has inner mutability and will only hold actual data after [`ExtractTrailersStream`] is +/// fully consumed (dropping it is not required though). +pub fn extract_lazy_trailers(s: Streaming) -> (ExtractTrailersStream, LazyTrailers) { + let trailers: SharedTrailers = Default::default(); + let stream = ExtractTrailersStream { + inner: s, + trailers: Arc::clone(&trailers), + }; + let lazy_trailers = LazyTrailers { trailers }; + (stream, lazy_trailers) +} + +type SharedTrailers = Arc>>; + +/// [Stream] that stores the gRPC trailers into [`LazyTrailers`]. +/// +/// See [`extract_lazy_trailers`] for construction. 
+#[derive(Debug)] +pub struct ExtractTrailersStream { + inner: Streaming, + trailers: SharedTrailers, +} + +impl Stream for ExtractTrailersStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let res = ready!(self.inner.poll_next_unpin(cx)); + + if res.is_none() { + // stream exhausted => trailers should available + if let Some(trailers) = self + .inner + .trailers() + .now_or_never() + .and_then(|res| res.ok()) + .flatten() + { + *self.trailers.lock().expect("poisoned") = Some(trailers); + } + } + + Poll::Ready(res) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +/// gRPC trailers that are extracted by [`ExtractTrailersStream`]. +/// +/// See [`extract_lazy_trailers`] for construction. +#[derive(Debug)] +pub struct LazyTrailers { + trailers: SharedTrailers, +} + +impl LazyTrailers { + /// gRPC trailers that are known at the end of a stream. + pub fn get(&self) -> Option { + self.trailers.lock().expect("poisoned").clone() + } +} diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 21a5a8572246..37d7ff9e7293 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -18,24 +18,29 @@ //! 
Utilities to assist with reading and writing Arrow data as Flight messages use crate::{FlightData, IpcMessage, SchemaAsIpc, SchemaResult}; +use bytes::Bytes; use std::collections::HashMap; +use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow::buffer::Buffer; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::{ArrowError, Result}; -use arrow::ipc::{reader, writer, writer::IpcWriteOptions}; -use arrow::record_batch::RecordBatch; -use std::convert::TryInto; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_buffer::Buffer; +use arrow_ipc::convert::fb_to_schema; +use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions}; +use arrow_schema::{ArrowError, Schema, SchemaRef}; /// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries /// and a `FlightData` representing the bytes of the batch's values +#[deprecated( + since = "30.0.0", + note = "Use IpcDataGenerator directly with DictionaryTracker to avoid re-sending dictionaries" +)] pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, ) -> (Vec, FlightData) { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false); + let mut dictionary_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, options) @@ -47,16 +52,38 @@ pub fn flight_data_from_arrow_batch( (flight_dictionaries, flight_batch) } +/// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es +pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result, ArrowError> { + let schema = flight_data.first().ok_or_else(|| { + ArrowError::CastError("Need at least one FlightData for schema".to_string()) + })?; + let message = root_as_message(&schema.data_header[..]) + .map_err(|_| ArrowError::CastError("Cannot get root as 
message".to_string()))?; + + let ipc_schema: arrow_ipc::Schema = message + .header_as_schema() + .ok_or_else(|| ArrowError::CastError("Cannot get header as Schema".to_string()))?; + let schema = fb_to_schema(ipc_schema); + let schema = Arc::new(schema); + + let mut batches = vec![]; + let dictionaries_by_id = HashMap::new(); + for datum in flight_data[1..].iter() { + let batch = flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?; + batches.push(batch); + } + Ok(batches) +} + /// Convert `FlightData` (with supplied schema and dictionaries) to an arrow `RecordBatch`. pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, dictionaries_by_id: &HashMap, -) -> Result { +) -> Result { // check that the data_header is a record batch message - let message = arrow::ipc::root_as_message(&data.data_header[..]).map_err(|err| { - ArrowError::ParseError(format!("Unable to get root as message: {:?}", err)) - })?; + let message = arrow_ipc::root_as_message(&data.data_header[..]) + .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; message .header_as_record_batch() @@ -67,7 +94,7 @@ pub fn flight_data_to_arrow_batch( }) .map(|batch| { reader::read_record_batch( - &Buffer::from(&data.data_body), + &Buffer::from(data.data_body.as_ref()), batch, schema, dictionaries_by_id, @@ -80,13 +107,13 @@ pub fn flight_data_to_arrow_batch( /// Convert a `Schema` to `SchemaResult` by converting to an IPC message #[deprecated( since = "4.4.0", - note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).into()" + note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).try_into()" )] pub fn flight_schema_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, -) -> SchemaResult { - SchemaAsIpc::new(schema, options).into() +) -> Result { + SchemaAsIpc::new(schema, options).try_into() } /// Convert a `Schema` to `FlightData` by converting to an IPC message @@ -94,10 +121,7 @@ pub fn 
flight_schema_from_arrow_schema( since = "4.4.0", note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).into()" )] -pub fn flight_data_from_arrow_schema( - schema: &Schema, - options: &IpcWriteOptions, -) -> FlightData { +pub fn flight_data_from_arrow_schema(schema: &Schema, options: &IpcWriteOptions) -> FlightData { SchemaAsIpc::new(schema, options).into() } @@ -109,8 +133,36 @@ pub fn flight_data_from_arrow_schema( pub fn ipc_message_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, -) -> Result> { +) -> Result { let message = SchemaAsIpc::new(schema, options).try_into()?; let IpcMessage(vals) = message; Ok(vals) } + +/// Convert `RecordBatch`es to wire protocol `FlightData`s +pub fn batches_to_flight_data( + schema: &Schema, + batches: Vec, +) -> Result, ArrowError> { + let options = IpcWriteOptions::default(); + let schema_flight_data: FlightData = SchemaAsIpc::new(schema, &options).into(); + let mut dictionaries = vec![]; + let mut flight_data = vec![]; + + let data_gen = writer::IpcDataGenerator::default(); + let mut dictionary_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + + for batch in batches.iter() { + let (encoded_dictionaries, encoded_batch) = + data_gen.encoded_batch(batch, &mut dictionary_tracker, &options)?; + + dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into)); + flight_data.push(encoded_batch.into()); + } + let mut stream = vec![schema_flight_data]; + stream.extend(dictionaries); + stream.extend(flight_data); + let flight_data: Vec<_> = stream.into_iter().collect(); + Ok(flight_data) +} diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs new file mode 100644 index 000000000000..25dad0e77a3e --- /dev/null +++ b/arrow-flight/tests/client.rs @@ -0,0 +1,1151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration test for "mid level" Client + +mod common; + +use crate::common::fixture::TestFixture; +use arrow_array::{RecordBatch, UInt64Array}; +use arrow_flight::{ + decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError, Action, + ActionType, CancelFlightInfoRequest, CancelFlightInfoResult, CancelStatus, Criteria, Empty, + FlightClient, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, PollInfo, PutResult, RenewFlightEndpointRequest, Ticket, +}; +use arrow_schema::{DataType, Field, Schema}; +use bytes::Bytes; +use common::server::TestFlightServer; +use futures::{Future, StreamExt, TryStreamExt}; +use prost::Message; +use tonic::Status; + +use std::sync::Arc; + +#[tokio::test] +async fn test_handshake() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request_payload = Bytes::from("foo-request-payload"); + let response_payload = Bytes::from("bar-response-payload"); + + let request = HandshakeRequest { + payload: request_payload.clone(), + protocol_version: 0, + }; + + let response = HandshakeResponse { + payload: response_payload.clone(), + protocol_version: 0, + }; + + test_server.set_handshake_response(Ok(response)); 
+ let response = client.handshake(request_payload).await.unwrap(); + assert_eq!(response, response_payload); + assert_eq!(test_server.take_handshake_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_handshake_error() { + do_test(|test_server, mut client| async move { + let request_payload = "foo-request-payload".to_string().into_bytes(); + let e = Status::unauthenticated("DENIED"); + test_server.set_handshake_response(Err(e.clone())); + + let response = client.handshake(request_payload).await.unwrap_err(); + expect_status(response, e); + }) + .await; +} + +/// Verifies that all headers sent from the the client are in the request_metadata +fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) { + let client_metadata = client.metadata().clone().into_headers(); + assert!(!client_metadata.is_empty()); + let metadata = test_server + .take_last_request_metadata() + .expect("No headers in server") + .into_headers(); + + for (k, v) in &client_metadata { + assert_eq!( + metadata.get(k).as_ref(), + Some(&v), + "Missing / Mismatched metadata {k:?} sent {client_metadata:?} got {metadata:?}" + ); + } +} + +fn test_flight_info(request: &FlightDescriptor) -> FlightInfo { + FlightInfo { + schema: Bytes::new(), + endpoint: vec![], + flight_descriptor: Some(request.clone()), + total_bytes: 123, + total_records: 456, + ordered: false, + app_metadata: Bytes::new(), + } +} + +#[tokio::test] +async fn test_get_flight_info() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let expected_response = test_flight_info(&request); + test_server.set_get_flight_info_response(Ok(expected_response.clone())); + + let response = client.get_flight_info(request.clone()).await.unwrap(); + + assert_eq!(response, expected_response); + 
assert_eq!(test_server.take_get_flight_info_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_get_flight_info_error() { + do_test(|test_server, mut client| async move { + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let e = Status::unauthenticated("DENIED"); + test_server.set_get_flight_info_response(Err(e.clone())); + + let response = client.get_flight_info(request.clone()).await.unwrap_err(); + expect_status(response, e); + }) + .await; +} + +fn test_poll_info(request: &FlightDescriptor) -> PollInfo { + PollInfo { + info: Some(test_flight_info(request)), + flight_descriptor: None, + progress: Some(1.0), + expiration_time: None, + } +} + +#[tokio::test] +async fn test_poll_flight_info() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let expected_response = test_poll_info(&request); + test_server.set_poll_flight_info_response(Ok(expected_response.clone())); + + let response = client.poll_flight_info(request.clone()).await.unwrap(); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_poll_flight_info_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_poll_flight_info_error() { + do_test(|test_server, mut client| async move { + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let e = Status::unauthenticated("DENIED"); + test_server.set_poll_flight_info_response(Err(e.clone())); + + let response = client.poll_flight_info(request.clone()).await.unwrap_err(); + expect_status(response, e); + }) + .await; +} + +// TODO more negative tests (like if there are endpoints defined, etc) + +#[tokio::test] +async fn test_do_get() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", 
"bar-header-value").unwrap(); + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let response = vec![Ok(batch.clone())]; + test_server.set_do_get_response(response); + let mut response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making request"); + + assert_eq!( + response_stream + .headers() + .get("test-resp-header") + .expect("header exists") + .to_str() + .unwrap(), + "some_val", + ); + + // trailers are not available before stream exhaustion + assert!(response_stream.trailers().is_none()); + + let expected_response = vec![batch]; + let response: Vec<_> = (&mut response_stream) + .try_collect() + .await + .expect("Error streaming data"); + assert_eq!(response, expected_response); + + assert_eq!( + response_stream + .trailers() + .expect("stream exhausted") + .get("test-trailer") + .expect("trailer exists") + .to_str() + .unwrap(), + "trailer_val", + ); + + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_get_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let response = client.do_get(ticket.clone()).await.unwrap_err(); + + let e = Status::internal("No do_get response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_get_error_in_record_batch_stream() { + do_test(|test_server, mut client| async move { + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + 
"col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let expected_response = vec![Ok(batch), Err(e.clone())]; + + test_server.set_do_get_response(expected_response); + + let response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + }) + .await; +} + +#[tokio::test] +async fn test_do_put() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + // encode the batch as a stream of FlightData + let input_flight_data = test_flight_data().await; + + let expected_response = vec![ + PutResult { + app_metadata: Bytes::from("foo-metadata1"), + }, + PutResult { + app_metadata: Bytes::from("bar-metadata2"), + }, + ]; + + test_server.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect()); + + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + + let response = client.do_put(input_stream).await; + let response = match response { + 
Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No do_put response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_stream_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + let e = Status::invalid_argument("bad arg"); + + let response = vec![ + Ok(PutResult { + app_metadata: Bytes::from("foo-metadata"), + }), + Err(e.clone()), + ]; + + test_server.set_do_put_response(response); + + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_client() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::invalid_argument("bad arg: client"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e.clone(), + ))])); + + // server responds with one good message + let response = vec![Ok(PutResult { + app_metadata: Bytes::from("foo-metadata"), + })]; + 
test_server.set_do_put_response(response); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client + expect_status(response, e); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_client_and_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e_client = Status::invalid_argument("bad arg: client"); + let e_server = Status::invalid_argument("bad arg: server"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e_client.clone(), + ))])); + + // server responds with an error (e.g. 
because it got truncated data) + let response = vec![Err(e_server)]; + test_server.set_do_put_response(response); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client (not the server) + expect_status(response, e_client); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + // encode the batch as a stream of FlightData + let input_flight_data = test_flight_data().await; + let output_flight_data = test_flight_data2().await; + + test_server + .set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect()); + + let response_stream = client + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) + .await + .expect("error making request"); + + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + let expected_stream = futures::stream::iter(output_flight_data).map(Ok); + + let expected_batches: Vec<_> = + FlightRecordBatchStream::new_from_flight_data(expected_stream) + .try_collect() + .await + .unwrap(); + + assert_eq!(response, expected_batches); + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + 
let response = client + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) + .await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No do_exchange response configured"); + expect_status(response, e); + // server still got the request + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange_error_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + let e = Status::invalid_argument("the error"); + let response = test_flight_data2() + .await + .into_iter() + .enumerate() + .map(|(i, m)| { + if i == 0 { + Ok(m) + } else { + // make all messages after the first an error + Err(e.clone()) + } + }) + .collect(); + + test_server.set_do_exchange_response(response); + + let response_stream = client + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + expect_status(response, e); + // server still got the request + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange_error_stream_client() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::invalid_argument("bad arg: client"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + 
.chain(futures::stream::iter(vec![Err(FlightError::from( + e.clone(), + ))])); + + let output_flight_data = FlightData::new() + .with_descriptor(FlightDescriptor::new_cmd("Sample command")) + .with_data_body("body".as_bytes()) + .with_data_header("header".as_bytes()) + .with_app_metadata("metadata".as_bytes()); + + // server responds with one good message + let response = vec![Ok(output_flight_data)]; + test_server.set_do_exchange_response(response); + + let response_stream = client + .do_exchange(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client + expect_status(response, e); + // server still got the request messages until the client sent the error + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange_error_client_and_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e_client = Status::invalid_argument("bad arg: client"); + let e_server = Status::invalid_argument("bad arg: server"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e_client.clone(), + ))])); + + // server responds with an error (e.g. 
because it got truncated data) + let response = vec![Err(e_server)]; + test_server.set_do_exchange_response(response); + + let response_stream = client + .do_exchange(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client (not the server) + expect_status(response, e_client); + // server still got the request messages until the client sent the error + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_get_schema() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let schema = Schema::new(vec![Field::new("foo", DataType::Int64, true)]); + + let request = FlightDescriptor::new_cmd("my command"); + test_server.set_get_schema_response(Ok(schema.clone())); + + let response = client + .get_schema(request.clone()) + .await + .expect("error making request"); + + let expected_schema = schema; + let expected_request = request; + + assert_eq!(response, expected_schema); + assert_eq!( + test_server.take_get_schema_request(), + Some(expected_request) + ); + + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_get_schema_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request = FlightDescriptor::new_cmd("my command"); + + let e = Status::unauthenticated("DENIED"); + test_server.set_get_schema_response(Err(e.clone())); + + let response = client.get_schema(request).await.unwrap_err(); + expect_status(response, e); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights() { + do_test(|test_server, mut client| 
async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let infos = vec![ + test_flight_info(&FlightDescriptor::new_cmd("foo")), + test_flight_info(&FlightDescriptor::new_cmd("bar")), + ]; + + let response = infos.iter().map(|i| Ok(i.clone())).collect(); + test_server.set_list_flights_response(response); + + let response_stream = client + .list_flights("query") + .await + .expect("error making request"); + + let expected_response = infos; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + let expected_request = Some(Criteria { + expression: "query".into(), + }); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let response = client.list_flights("query").await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No list_flights response configured"); + expect_status(response, e); + // server still got the request + let expected_request = Some(Criteria { + expression: "query".into(), + }); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let response = vec![ + Ok(test_flight_info(&FlightDescriptor::new_cmd("foo"))), + Err(e.clone()), + ]; + test_server.set_list_flights_response(response); + + let response_stream = client + .list_flights("other query") + .await + .expect("error making request"); + + let 
response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + let expected_request = Some(Criteria { + expression: "other query".into(), + }); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let actions = vec![ + ActionType { + r#type: "type 1".into(), + description: "awesomeness".into(), + }, + ActionType { + r#type: "type 2".into(), + description: "more awesomeness".into(), + }, + ]; + + let response = actions.iter().map(|i| Ok(i.clone())).collect(); + test_server.set_list_actions_response(response); + + let response_stream = client.list_actions().await.expect("error making request"); + + let expected_response = actions; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let response = client.list_actions().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No list_actions response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions_error_in_stream() { + do_test(|test_server, mut client| async move { + 
client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let response = vec![ + Ok(ActionType { + r#type: "type 1".into(), + description: "awesomeness".into(), + }), + Err(e.clone()), + ]; + test_server.set_list_actions_response(response); + + let response_stream = client.list_actions().await.expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let bytes = vec![Bytes::from("foo"), Bytes::from("blarg")]; + + let response = bytes + .iter() + .cloned() + .map(arrow_flight::Result::new) + .map(Ok) + .collect(); + test_server.set_do_action_response(response); + + let request = Action::new("action type", "action body"); + + let response_stream = client + .do_action(request.clone()) + .await + .expect("error making request"); + + let expected_response = bytes; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let request = Action::new("action type", "action body"); + + let response = client.do_action(request.clone()).await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No do_action response 
configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let request = Action::new("action type", "action body"); + + let response = vec![Ok(arrow_flight::Result::new("foo")), Err(e.clone())]; + test_server.set_do_action_response(response); + + let response_stream = client + .do_action(request.clone()) + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_cancel_flight_info() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let expected_response = CancelFlightInfoResult::new(CancelStatus::Cancelled); + let response = expected_response.encode_to_vec(); + let response = Ok(arrow_flight::Result::new(response)); + test_server.set_do_action_response(vec![response]); + + let request = CancelFlightInfoRequest::new(FlightInfo::new()); + let actual_response = client + .cancel_flight_info(request.clone()) + .await + .expect("error making request"); + + let expected_request = Action::new("CancelFlightInfo", request.encode_to_vec()); + assert_eq!(actual_response, expected_response); + assert_eq!(test_server.take_do_action_request(), Some(expected_request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_cancel_flight_info_error_no_response() { + 
do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + test_server.set_do_action_response(vec![]); + + let request = CancelFlightInfoRequest::new(FlightInfo::new()); + let err = client + .cancel_flight_info(request.clone()) + .await + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Protocol error: Received no response for cancel_flight_info call" + ); + // server still got the request + let expected_request = Action::new("CancelFlightInfo", request.encode_to_vec()); + assert_eq!(test_server.take_do_action_request(), Some(expected_request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_renew_flight_endpoint() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let expected_response = FlightEndpoint::new().with_app_metadata(vec![1]); + let response = expected_response.encode_to_vec(); + let response = Ok(arrow_flight::Result::new(response)); + test_server.set_do_action_response(vec![response]); + + let request = + RenewFlightEndpointRequest::new(FlightEndpoint::new().with_app_metadata(vec![0])); + let actual_response = client + .renew_flight_endpoint(request.clone()) + .await + .expect("error making request"); + + let expected_request = Action::new("RenewFlightEndpoint", request.encode_to_vec()); + assert_eq!(actual_response, expected_response); + assert_eq!(test_server.take_do_action_request(), Some(expected_request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_renew_flight_endpoint_error_no_response() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + test_server.set_do_action_response(vec![]); + + let request = RenewFlightEndpointRequest::new(FlightEndpoint::new()); + let err = client + .renew_flight_endpoint(request.clone()) + .await + .unwrap_err(); + + assert_eq!( + 
err.to_string(), + "Protocol error: Received no response for renew_flight_endpoint call" + ); + // server still got the request + let expected_request = Action::new("RenewFlightEndpoint", request.encode_to_vec()); + assert_eq!(test_server.take_do_action_request(), Some(expected_request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +async fn test_flight_data() -> Vec { + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + // encode the batch as a stream of FlightData + FlightDataEncoderBuilder::new() + .build(futures::stream::iter(vec![Ok(batch)])) + .try_collect() + .await + .unwrap() +} + +async fn test_flight_data2() -> Vec { + let batch = RecordBatch::try_from_iter(vec![( + "col2", + Arc::new(UInt64Array::from_iter([10, 23, 33])) as _, + )]) + .unwrap(); + + // encode the batch as a stream of FlightData + FlightDataEncoderBuilder::new() + .build(futures::stream::iter(vec![Ok(batch)])) + .try_collect() + .await + .unwrap() +} + +/// Runs the future returned by the function, passing it a test server and client +async fn do_test(f: F) +where + F: Fn(TestFlightServer, FlightClient) -> Fut, + Fut: Future, +{ + let test_server = TestFlightServer::new(); + let fixture = TestFixture::new(test_server.service()).await; + let client = FlightClient::new(fixture.channel().await); + + // run the test function + f(test_server, client).await; + + // cleanly shutdown the test fixture + fixture.shutdown_and_wait().await +} + +fn expect_status(error: FlightError, expected: Status) { + let status = if let FlightError::Tonic(status) = error { + status + } else { + panic!("Expected FlightError::Tonic, got: {error:?}"); + }; + + assert_eq!( + status.code(), + expected.code(), + "Got {status:?} want {expected:?}" + ); + assert_eq!( + status.message(), + expected.message(), + "Got {status:?} want {expected:?}" + ); + assert_eq!( + status.details(), + expected.details(), + "Got {status:?} 
want {expected:?}" + ); +} diff --git a/arrow-flight/tests/common/fixture.rs b/arrow-flight/tests/common/fixture.rs new file mode 100644 index 000000000000..a666fa5d0d59 --- /dev/null +++ b/arrow-flight/tests/common/fixture.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::common::trailers_layer::TrailersLayer; +use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; +use http::Uri; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; +use tonic::transport::Channel; + +/// All tests must complete within this many seconds or else the test server is shutdown +const DEFAULT_TIMEOUT_SECONDS: u64 = 30; + +/// Creates and manages a running TestServer with a background task +pub struct TestFixture { + /// channel to send shutdown command + shutdown: Option>, + + /// Address the server is listening on + pub addr: SocketAddr, + + /// handle for the server task + handle: Option>>, +} + +impl TestFixture { + /// create a new test fixture from the server + #[allow(dead_code)] + pub async fn new(test_server: FlightServiceServer) -> Self { + // let OS choose a free port + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + println!("Listening on {addr}"); + + // prepare the shutdown channel + let (tx, rx) = tokio::sync::oneshot::channel(); + + let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); + + let shutdown_future = async move { + rx.await.ok(); + }; + + let serve_future = tonic::transport::Server::builder() + .timeout(server_timeout) + .layer(TrailersLayer) + .add_service(test_server) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_future, + ); + + // Run the server in its own background task + let handle = tokio::task::spawn(serve_future); + + Self { + shutdown: Some(tx), + addr, + handle: Some(handle), + } + } + + /// Return a [`Channel`] connected to the TestServer + #[allow(dead_code)] + pub async fn channel(&self) -> Channel { + let url = format!("http://{}", self.addr); + let uri: Uri = url.parse().expect("Valid URI"); + Channel::builder(uri) + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS)) + .connect() 
+ .await + .expect("error connecting to server") + } + + /// Stops the test server and waits for the server to shutdown + #[allow(dead_code)] + pub async fn shutdown_and_wait(mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).expect("server quit early"); + } + if let Some(handle) = self.handle.take() { + println!("Waiting on server to finish"); + handle + .await + .expect("task join error (panic?)") + .expect("Server Error found at shutdown"); + } + } +} + +impl Drop for TestFixture { + fn drop(&mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).ok(); + } + if self.handle.is_some() { + // tests should properly clean up TestFixture + println!("TestFixture::Drop called prior to `shutdown_and_wait`"); + } + } +} diff --git a/arrow-flight/tests/common/mod.rs b/arrow-flight/tests/common/mod.rs new file mode 100644 index 000000000000..c4ac027c5890 --- /dev/null +++ b/arrow-flight/tests/common/mod.rs @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +pub mod fixture; +pub mod server; +pub mod trailers_layer; +pub mod utils; diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs new file mode 100644 index 000000000000..a004ccb0737e --- /dev/null +++ b/arrow-flight/tests/common/server.rs @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::{Arc, Mutex}; + +use arrow_array::RecordBatch; +use arrow_schema::Schema; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; + +use arrow_flight::{ + encode::FlightDataEncoderBuilder, + flight_service_server::{FlightService, FlightServiceServer}, + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaAsIpc, SchemaResult, Ticket, +}; + +#[derive(Debug, Clone)] +/// Flight server for testing, with configurable responses +pub struct TestFlightServer { + /// Shared state to configure responses + state: Arc>, +} + +impl TestFlightServer { + /// Create a `TestFlightServer` + #[allow(dead_code)] + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(State::new())), + } + } + + /// Return an [`FlightServiceServer`] that can be used with a + /// [`Server`](tonic::transport::Server) + #[allow(dead_code)] + pub fn service(&self) -> FlightServiceServer { + // wrap up tonic goop + FlightServiceServer::new(self.clone()) + } + + /// Specify the response returned from the next call to handshake + #[allow(dead_code)] + pub fn set_handshake_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.handshake_response.replace(response); + } + + /// Take and return last handshake request sent to the server, + #[allow(dead_code)] + pub fn take_handshake_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .handshake_request + .take() + } + + /// Specify the response returned from the next call to get_flight_info + #[allow(dead_code)] + pub fn set_get_flight_info_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_flight_info_response.replace(response); + } + + /// Take and return last get_flight_info request sent to the server, + 
#[allow(dead_code)] + pub fn take_get_flight_info_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .get_flight_info_request + .take() + } + + /// Specify the response returned from the next call to poll_flight_info + #[allow(dead_code)] + pub fn set_poll_flight_info_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.poll_flight_info_response.replace(response); + } + + /// Take and return last poll_flight_info request sent to the server, + #[allow(dead_code)] + pub fn take_poll_flight_info_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .poll_flight_info_request + .take() + } + + /// Specify the response returned from the next call to `do_get` + #[allow(dead_code)] + pub fn set_do_get_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_get_response.replace(response); + } + + /// Take and return last do_get request send to the server, + #[allow(dead_code)] + pub fn take_do_get_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .do_get_request + .take() + } + + /// Specify the response returned from the next call to `do_put` + #[allow(dead_code)] + pub fn set_do_put_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_put_response.replace(response); + } + + /// Take and return last do_put request sent to the server, + #[allow(dead_code)] + pub fn take_do_put_request(&self) -> Option> { + self.state + .lock() + .expect("mutex not poisoned") + .do_put_request + .take() + } + + /// Specify the response returned from the next call to `do_exchange` + #[allow(dead_code)] + pub fn set_do_exchange_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_exchange_response.replace(response); + } + + /// Take and return last do_exchange request send 
to the server, + #[allow(dead_code)] + pub fn take_do_exchange_request(&self) -> Option> { + self.state + .lock() + .expect("mutex not poisoned") + .do_exchange_request + .take() + } + + /// Specify the response returned from the next call to `list_flights` + #[allow(dead_code)] + pub fn set_list_flights_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.list_flights_response.replace(response); + } + + /// Take and return last list_flights request sent to the server, + #[allow(dead_code)] + pub fn take_list_flights_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .list_flights_request + .take() + } + + /// Specify the response returned from the next call to `get_schema` + #[allow(dead_code)] + pub fn set_get_schema_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_schema_response.replace(response); + } + + /// Take and return last get_schema request sent to the server, + #[allow(dead_code)] + pub fn take_get_schema_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .get_schema_request + .take() + } + + /// Specify the response returned from the next call to `list_actions` + #[allow(dead_code)] + pub fn set_list_actions_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.list_actions_response.replace(response); + } + + /// Take and return last list_actions request sent to the server, + #[allow(dead_code)] + pub fn take_list_actions_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .list_actions_request + .take() + } + + /// Specify the response returned from the next call to `do_action` + #[allow(dead_code)] + pub fn set_do_action_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_action_response.replace(response); + } + + /// Take
and return last do_action request sent to the server, + #[allow(dead_code)] + pub fn take_do_action_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .do_action_request + .take() + } + + /// Returns the last metadata from a request received by the server + #[allow(dead_code)] + pub fn take_last_request_metadata(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .last_request_metadata + .take() + } + + /// Save the last request's metadata + fn save_metadata(&self, request: &Request) { + let metadata = request.metadata().clone(); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.last_request_metadata = Some(metadata); + } +} + +/// mutable state for the TestFlightServer, captures requests and provides responses +#[derive(Debug, Default)] +struct State { + /// The last handshake request that was received + pub handshake_request: Option, + /// The next response to return from `handshake()` + pub handshake_response: Option>, + /// The last `get_flight_info` request received + pub get_flight_info_request: Option, + /// The next response to return from `get_flight_info` + pub get_flight_info_response: Option>, + /// The last `poll_flight_info` request received + pub poll_flight_info_request: Option, + /// The next response to return from `poll_flight_info` + pub poll_flight_info_response: Option>, + /// The last do_get request received + pub do_get_request: Option, + /// The next response returned from `do_get` + pub do_get_response: Option>>, + /// The last do_put request received + pub do_put_request: Option>, + /// The next response returned from `do_put` + pub do_put_response: Option>>, + /// The last do_exchange request received + pub do_exchange_request: Option>, + /// The next response returned from `do_exchange` + pub do_exchange_response: Option>>, + /// The last list_flights request received + pub list_flights_request: Option, + /// The next response returned from
`list_flights` + pub list_flights_response: Option>>, + /// The last get_schema request received + pub get_schema_request: Option, + /// The next response returned from `get_schema` + pub get_schema_response: Option>, + /// The last list_actions request received + pub list_actions_request: Option, + /// The next response returned from `list_actions` + pub list_actions_response: Option>>, + /// The last do_action request received + pub do_action_request: Option, + /// The next response returned from `do_action` + pub do_action_response: Option>>, + /// The last request headers received + pub last_request_metadata: Option, +} + +impl State { + fn new() -> Self { + Default::default() + } +} + +/// Implement the FlightService trait +#[tonic::async_trait] +impl FlightService for TestFlightServer { + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; + + async fn handshake( + &self, + request: Request>, + ) -> Result, Status> { + self.save_metadata(&request); + let handshake_request = request.into_inner().message().await?.unwrap(); + + let mut state = self.state.lock().expect("mutex not poisoned"); + state.handshake_request = Some(handshake_request); + + let response = state + .handshake_response + .take() + .unwrap_or_else(|| Err(Status::internal("No handshake response configured")))?; + + // turn into a streaming response + let output = futures::stream::iter(std::iter::once(Ok(response))); + Ok(Response::new(output.boxed())) + } + + async fn list_flights( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.list_flights_request = 
Some(request.into_inner()); + + let flights: Vec<_> = state + .list_flights_response + .take() + .ok_or_else(|| Status::internal("No list_flights response configured"))?; + + let flights_stream = futures::stream::iter(flights); + + Ok(Response::new(flights_stream.boxed())) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_flight_info_request = Some(request.into_inner()); + let response = state + .get_flight_info_response + .take() + .unwrap_or_else(|| Err(Status::internal("No get_flight_info response configured")))?; + Ok(Response::new(response)) + } + + async fn poll_flight_info( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.poll_flight_info_request = Some(request.into_inner()); + let response = state + .poll_flight_info_response + .take() + .unwrap_or_else(|| Err(Status::internal("No poll_flight_info response configured")))?; + Ok(Response::new(response)) + } + + async fn get_schema( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_schema_request = Some(request.into_inner()); + let schema = state + .get_schema_response + .take() + .unwrap_or_else(|| Err(Status::internal("No get_schema response configured")))?; + + // encode the schema + let options = arrow_ipc::writer::IpcWriteOptions::default(); + let response: SchemaResult = SchemaAsIpc::new(&schema, &options) + .try_into() + .expect("Error encoding schema"); + + Ok(Response::new(response)) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_get_request = Some(request.into_inner()); + + let batches: Vec<_> 
= state + .do_get_response + .take() + .ok_or_else(|| Status::internal("No do_get response configured"))?; + + let batch_stream = futures::stream::iter(batches).map_err(Into::into); + + let stream = FlightDataEncoderBuilder::new() + .build(batch_stream) + .map_err(Into::into); + + let mut resp = Response::new(stream.boxed()); + resp.metadata_mut() + .insert("test-resp-header", "some_val".parse().unwrap()); + + Ok(resp) + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + self.save_metadata(&request); + let do_put_request: Vec<_> = request.into_inner().try_collect().await?; + + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_put_request = Some(do_put_request); + + let response = state + .do_put_response + .take() + .ok_or_else(|| Status::internal("No do_put response configured"))?; + + let stream = futures::stream::iter(response).map_err(Into::into); + + Ok(Response::new(stream.boxed())) + } + + async fn do_action( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_action_request = Some(request.into_inner()); + + let results: Vec<_> = state + .do_action_response + .take() + .ok_or_else(|| Status::internal("No do_action response configured"))?; + + let results_stream = futures::stream::iter(results); + + Ok(Response::new(results_stream.boxed())) + } + + async fn list_actions( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.list_actions_request = Some(request.into_inner()); + + let actions: Vec<_> = state + .list_actions_response + .take() + .ok_or_else(|| Status::internal("No list_actions response configured"))?; + + let action_stream = futures::stream::iter(actions); + + Ok(Response::new(action_stream.boxed())) + } + + async fn do_exchange( + &self, + request: Request>, + ) -> Result, 
Status> { + self.save_metadata(&request); + let do_exchange_request: Vec<_> = request.into_inner().try_collect().await?; + + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_exchange_request = Some(do_exchange_request); + + let response = state + .do_exchange_response + .take() + .ok_or_else(|| Status::internal("No do_exchange response configured"))?; + + let stream = futures::stream::iter(response).map_err(Into::into); + + Ok(Response::new(stream.boxed())) + } +} diff --git a/arrow-flight/tests/common/trailers_layer.rs b/arrow-flight/tests/common/trailers_layer.rs new file mode 100644 index 000000000000..0ccb7df86c74 --- /dev/null +++ b/arrow-flight/tests/common/trailers_layer.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::ready; +use http::{HeaderValue, Request, Response}; +use http_body::{Frame, SizeHint}; +use pin_project_lite::pin_project; +use tower::{Layer, Service}; + +#[derive(Debug, Copy, Clone, Default)] +pub struct TrailersLayer; + +impl Layer for TrailersLayer { + type Service = TrailersService; + + fn layer(&self, service: S) -> Self::Service { + TrailersService { service } + } +} + +#[derive(Debug, Clone)] +pub struct TrailersService { + service: S, +} + +impl Service> for TrailersService +where + S: Service, Response = Response>, + ResBody: http_body::Body, +{ + type Response = Response>; + type Error = S::Error; + type Future = WrappedFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + WrappedFuture { + inner: self.service.call(request), + } + } +} + +pin_project! { + #[derive(Debug)] + pub struct WrappedFuture { + #[pin] + inner: F, + } +} + +impl Future for WrappedFuture +where + F: Future, Error>>, + ResBody: http_body::Body, +{ + type Output = Result>, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result: Result, Error> = + ready!(self.as_mut().project().inner.poll(cx)); + + match result { + Ok(response) => Poll::Ready(Ok(response.map(|body| WrappedBody { inner: body }))), + Err(e) => Poll::Ready(Err(e)), + } + } +} + +pin_project! 
{ + #[derive(Debug)] + pub struct WrappedBody { + #[pin] + inner: B, + } +} + +impl http_body::Body for WrappedBody { + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let mut result = ready!(self.project().inner.poll_frame(cx)); + + if let Some(Ok(frame)) = &mut result { + if let Some(trailers) = frame.trailers_mut() { + trailers.insert("test-trailer", HeaderValue::from_static("trailer_val")); + } + } + + Poll::Ready(result) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +} diff --git a/arrow-flight/tests/common/utils.rs b/arrow-flight/tests/common/utils.rs new file mode 100644 index 000000000000..0f70e4b31021 --- /dev/null +++ b/arrow-flight/tests/common/utils.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Common utilities for testing flight clients and servers + +use std::sync::Arc; + +use arrow_array::{ + types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, + StringViewArray, UInt8Array, +}; +use arrow_schema::{DataType, Field, Schema}; + +/// Make a primitive batch for testing +/// +/// Example: +/// i: 0, 1, None, 3, 4 +/// f: 5.0, 4.0, None, 2.0, 1.0 +#[allow(dead_code)] +pub fn make_primitive_batch(num_rows: usize) -> RecordBatch { + let i: UInt8Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some(i.try_into().unwrap()) + } + }) + .collect(); + + let f: Float64Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some((num_rows - i) as f64) + } + }) + .collect(); + + RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap() +} + +/// Make a dictionary batch for testing +/// +/// Example: +/// a: value0, value1, value2, None, value1, value2 +#[allow(dead_code)] +pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch { + let values: Vec<_> = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + // repeat some values for low cardinality + let v = i / 3; + Some(format!("value{v}")) + } + }) + .collect(); + + let a: DictionaryArray = values + .iter() + .map(|s| s.as_ref().map(|s| s.as_str())) + .collect(); + + RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap() +} + +#[allow(dead_code)] +pub fn make_view_batches(num_rows: usize) -> RecordBatch { + const LONG_TEST_STRING: &str = + "This is a long string to make sure binary view array handles it"; + let schema = Schema::new(vec![ + Field::new("field1", DataType::BinaryView, true), + Field::new("field2", DataType::Utf8View, true), + ]); + + let string_view_values: Vec> = (0..num_rows) + .map(|i| match i % 3 { + 0 => None, + 1 => Some("foo"), + 2 => Some(LONG_TEST_STRING), + _ => unreachable!(), + }) + .collect(); + + let bin_view_values: Vec> = 
(0..num_rows) + .map(|i| match i % 3 { + 0 => None, + 1 => Some("bar".as_bytes()), + 2 => Some(LONG_TEST_STRING.as_bytes()), + _ => unreachable!(), + }) + .collect(); + + let binary_array = BinaryViewArray::from_iter(bin_view_values); + let utf8_array = StringViewArray::from_iter(string_view_values); + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(binary_array), Arc::new(utf8_array)], + ) + .unwrap() +} diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs new file mode 100644 index 000000000000..cbfae1825845 --- /dev/null +++ b/arrow-flight/tests/encode_decode.rs @@ -0,0 +1,503 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for round trip encoding / decoding + +use std::{collections::HashMap, sync::Arc}; + +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_cast::pretty::pretty_format_batches; +use arrow_flight::flight_descriptor::DescriptorType; +use arrow_flight::FlightDescriptor; +use arrow_flight::{ + decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream}, + encode::FlightDataEncoderBuilder, + error::FlightError, +}; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use bytes::Bytes; +use futures::{StreamExt, TryStreamExt}; + +mod common; +use common::utils::{make_dictionary_batch, make_primitive_batch, make_view_batches}; + +#[tokio::test] +async fn test_empty() { + roundtrip(vec![]).await; +} + +#[tokio::test] +async fn test_empty_batch() { + let batch = make_primitive_batch(5); + let empty = RecordBatch::new_empty(batch.schema()); + roundtrip(vec![empty]).await; +} + +#[tokio::test] +async fn test_error() { + let input_batch_stream = + futures::stream::iter(vec![Err(FlightError::NotYetImplemented("foo".into()))]); + + let encoder = FlightDataEncoderBuilder::default(); + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, _> = decode_stream.try_collect().await; + + let result = result.unwrap_err(); + assert_eq!(result.to_string(), "Not yet implemented: foo"); +} + +#[tokio::test] +async fn test_primitive_one() { + roundtrip(vec![make_primitive_batch(5)]).await; +} + +#[tokio::test] +async fn test_schema_metadata() { + let batch = make_primitive_batch(5); + let metadata = HashMap::from([("some_key".to_owned(), "some_value".to_owned())]); + + // create a batch that has schema level metadata + let schema = Arc::new(batch.schema().as_ref().clone().with_metadata(metadata)); + let batch = RecordBatch::try_new(schema, batch.columns().to_vec()).unwrap(); + + roundtrip(vec![batch]).await; +} + +#[tokio::test] +async fn test_primitive_many() 
{ + roundtrip(vec![ + make_primitive_batch(1), + make_primitive_batch(7), + make_primitive_batch(32), + ]) + .await; +} + +#[tokio::test] +async fn test_primitive_empty() { + let batch = make_primitive_batch(5); + let empty = RecordBatch::new_empty(batch.schema()); + + roundtrip(vec![batch, empty]).await; +} + +#[tokio::test] +async fn test_dictionary_one() { + roundtrip_dictionary(vec![make_dictionary_batch(5)]).await; +} + +#[tokio::test] +async fn test_dictionary_many() { + roundtrip_dictionary(vec![ + make_dictionary_batch(5), + make_dictionary_batch(9), + make_dictionary_batch(5), + make_dictionary_batch(5), + ]) + .await; +} + +#[tokio::test] +async fn test_view_types_one() { + roundtrip(vec![make_view_batches(5)]).await; +} + +#[tokio::test] +async fn test_view_types_many() { + roundtrip(vec![ + make_view_batches(5), + make_view_batches(9), + make_view_batches(5), + make_view_batches(5), + ]) + .await; +} + +#[tokio::test] +async fn test_zero_batches_no_schema() { + let stream = FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![])); + + let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream); + assert!(decoder.schema().is_none()); + // No batches come out + assert!(decoder.next().await.is_none()); + // schema has not been received + assert!(decoder.schema().is_none()); +} + +#[tokio::test] +async fn test_zero_batches_schema_specified() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let stream = FlightDataEncoderBuilder::default() + .with_schema(schema.clone()) + .build(futures::stream::iter(vec![])); + + let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream); + assert!(decoder.schema().is_none()); + // No batches come out + assert!(decoder.next().await.is_none()); + // But schema has been received correctly + assert_eq!(decoder.schema(), Some(&schema)); +} + +#[tokio::test] +async fn test_with_flight_descriptor() { + let stream = 
futures::stream::iter(vec![Ok(make_dictionary_batch(5))]); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + + let descriptor = Some(FlightDescriptor { + r#type: DescriptorType::Path.into(), + path: vec!["table_name".to_string()], + cmd: Bytes::default(), + }); + + let encoder = FlightDataEncoderBuilder::default() + .with_schema(schema.clone()) + .with_flight_descriptor(descriptor.clone()); + + let mut encoder = encoder.build(stream); + + // First batch should be the schema + let first_batch = encoder.next().await.unwrap().unwrap(); + + assert_eq!(first_batch.flight_descriptor, descriptor); +} + +#[tokio::test] +async fn test_zero_batches_dictionary_schema_specified() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new_dictionary("b", DataType::Int32, DataType::Utf8, false), + ])); + + // Expect dictionary to be hydrated in output (#3389) + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + ])); + let stream = FlightDataEncoderBuilder::default() + .with_schema(schema.clone()) + .build(futures::stream::iter(vec![])); + + let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream); + assert!(decoder.schema().is_none()); + // No batches come out + assert!(decoder.next().await.is_none()); + // But schema has been received correctly + assert_eq!(decoder.schema(), Some(&expected_schema)); +} + +#[tokio::test] +async fn test_app_metadata() { + let input_batch_stream = futures::stream::iter(vec![Ok(make_primitive_batch(78))]); + + let app_metadata = Bytes::from("My Metadata"); + let encoder = FlightDataEncoderBuilder::default().with_metadata(app_metadata.clone()); + + let encode_stream = encoder.build(input_batch_stream); + + // use lower level stream to get access to app metadata + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + + let mut 
messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); + + println!("{messages:#?}"); + + // expect that the app metadata made it through on the schema message + assert_eq!(messages.len(), 2); + let message2 = messages.pop().unwrap(); + let message1 = messages.pop().unwrap(); + + assert_eq!(message1.app_metadata(), app_metadata); + assert!(matches!(message1.payload, DecodedPayload::Schema(_))); + + // but not on the data + assert_eq!(message2.app_metadata(), Bytes::new()); + assert!(matches!(message2.payload, DecodedPayload::RecordBatch(_))); +} + +#[tokio::test] +async fn test_max_message_size() { + let input_batch_stream = futures::stream::iter(vec![Ok(make_primitive_batch(5))]); + + // 5 input rows, with a very small limit should result in 5 batch messages + let encoder = FlightDataEncoderBuilder::default().with_max_flight_data_size(1); + + let encode_stream = encoder.build(input_batch_stream); + + // use lower level stream to get access to app metadata + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + + let messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); + + println!("{messages:#?}"); + + assert_eq!(messages.len(), 6); + assert!(matches!(messages[0].payload, DecodedPayload::Schema(_))); + for message in messages.iter().skip(1) { + assert!(matches!(message.payload, DecodedPayload::RecordBatch(_))); + } +} + +#[tokio::test] +async fn test_max_message_size_fuzz() { + // send through batches of varying sizes with various max + // batch sizes and ensure the data gets through ok + let input = vec![ + make_primitive_batch(123), + make_primitive_batch(17), + make_primitive_batch(201), + make_primitive_batch(2), + make_primitive_batch(1), + make_primitive_batch(11), + make_primitive_batch(127), + ]; + + for max_message_size_bytes in [10, 1024, 2048, 6400, 3211212] { + let encoder = + FlightDataEncoderBuilder::default().with_max_flight_data_size(max_message_size_bytes); + 
+ let input_batch_stream = futures::stream::iter(input.clone()).map(Ok); + + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let output: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); + + for b in &output { + assert_eq!(b.schema(), input[0].schema()); + } + + let a = pretty_format_batches(&input).unwrap().to_string(); + let b = pretty_format_batches(&output).unwrap().to_string(); + assert_eq!(a, b); + } +} + +#[tokio::test] +async fn test_mismatched_record_batch_schema() { + // send 2 batches with different schemas + let input_batch_stream = futures::stream::iter(vec![ + Ok(make_primitive_batch(5)), + Ok(make_dictionary_batch(3)), + ]); + + let encoder = FlightDataEncoderBuilder::default(); + let encode_stream = encoder.build(input_batch_stream); + + let result: Result, FlightError> = encode_stream.try_collect().await; + let err = result.unwrap_err(); + assert_eq!( + err.to_string(), + "Arrow error: Invalid argument error: number of columns(1) must match number of fields(2) in schema" + ); +} + +#[tokio::test] +async fn test_chained_streams_batch_decoder() { + let batch1 = make_primitive_batch(5); + let batch2 = make_dictionary_batch(3); + + // Model sending two flight streams back to back, with different schemas + let encode_stream1 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())])); + + // append the two streams (so they will have two different schema messages) + let encode_stream = encode_stream1.chain(encode_stream2); + + // FlightRecordBatchStream errors if the schema changes + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, FlightError> = decode_stream.try_collect().await; + + let err = result.unwrap_err(); + assert_eq!( + 
err.to_string(), + "Protocol error: Unexpectedly saw multiple Schema messages in FlightData stream" + ); +} + +#[tokio::test] +async fn test_chained_streams_data_decoder() { + let batch1 = make_primitive_batch(5); + let batch2 = make_dictionary_batch(3); + + // Model sending two flight streams back to back, with different schemas + let encode_stream1 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())])); + + // append the two streams (so they will have two different schema messages) + let encode_stream = encode_stream1.chain(encode_stream2); + + // lower level decode stream can handle multiple schema messages + let decode_stream = FlightDataDecoder::new(encode_stream); + + let decoded_data: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); + + println!("decoded data: {decoded_data:#?}"); + + // expect two schema messages with the data + assert_eq!(decoded_data.len(), 4); + assert!(matches!(decoded_data[0].payload, DecodedPayload::Schema(_))); + assert!(matches!( + decoded_data[1].payload, + DecodedPayload::RecordBatch(_) + )); + assert!(matches!(decoded_data[2].payload, DecodedPayload::Schema(_))); + assert!(matches!( + decoded_data[3].payload, + DecodedPayload::RecordBatch(_) + )); +} + +#[tokio::test] +async fn test_mismatched_schema_message() { + // Model sending schema that is mismatched with the data + // and expect an error + async fn do_test(batch1: RecordBatch, batch2: RecordBatch, expected: &str) { + let encode_stream1 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch1.clone())])) + // take only schema message from first stream + .take(1); + let encode_stream2 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch2.clone())])) + // take only data message from second + .skip(1); + + // append the two streams + let 
encode_stream = encode_stream1.chain(encode_stream2); + + // FlightRecordBatchStream errors if the schema changes + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, FlightError> = decode_stream.try_collect().await; + + let err = result.unwrap_err().to_string(); + assert!( + err.contains(expected), + "could not find '{expected}' in '{err}'" + ); + } + + // primitive batch first (has more columns) + do_test( + make_primitive_batch(5), + make_dictionary_batch(3), + "Error decoding ipc RecordBatch: Schema error: Invalid data for schema", + ) + .await; + + // dictionary batch first + do_test( + make_dictionary_batch(3), + make_primitive_batch(5), + "Error decoding ipc RecordBatch: Invalid argument error", + ) + .await; +} + +/// Encodes input as a FlightData stream, and then decodes it using +/// FlightRecordBatchStream and validates the decoded record batches +/// match the input. +async fn roundtrip(input: Vec) { + let expected_output = input.clone(); + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await +} + +/// Encodes input as a FlightData stream, and then decodes it using +/// FlightRecordBatchStream and validates the decoded record batches +/// match the expected input. 
+/// +/// When <https://github.com/apache/arrow-rs/issues/3389> is resolved, +/// it should be possible to use `roundtrip` +async fn roundtrip_dictionary(input: Vec) { + let schema = Arc::new(prepare_schema_for_flight(input[0].schema_ref())); + let expected_output: Vec<_> = input + .iter() + .map(|batch| prepare_batch_for_flight(batch, schema.clone()).unwrap()) + .collect(); + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await +} + +async fn roundtrip_with_encoder( + encoder: FlightDataEncoderBuilder, + input_batches: Vec, + expected_batches: Vec, +) { + println!("Round tripping with encoder:\n{encoder:#?}"); + + let input_batch_stream = futures::stream::iter(input_batches.clone()).map(Ok); + + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let output_batches: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); + + // remove any empty batches from input as they are not transmitted + let expected_batches: Vec<_> = expected_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect(); + + assert_eq!(expected_batches, output_batches); +} + +/// Workaround for https://github.com/apache/arrow-rs/issues/1206 +fn prepare_schema_for_flight(schema: &Schema) -> Schema { + let fields: Fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + _ => field.as_ref().clone(), + }) + .collect(); + + Schema::new(fields) +} + +/// Workaround for https://github.com/apache/arrow-rs/issues/1206 +fn prepare_batch_for_flight( + batch: &RecordBatch, + schema: SchemaRef, +) -> Result { + let columns = batch + .columns() + .iter() + .map(hydrate_dictionary) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(schema, columns)?)
+} + +fn hydrate_dictionary(array: &ArrayRef) -> Result { + let arr = if let DataType::Dictionary(_, value) = array.data_type() { + arrow_cast::cast(array, value)? + } else { + Arc::clone(array) + }; + Ok(arr) +} diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs new file mode 100644 index 000000000000..349da062a82d --- /dev/null +++ b/arrow-flight/tests/flight_sql_client.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod common; + +use crate::common::fixture::TestFixture; +use crate::common::utils::make_primitive_batch; + +use arrow_array::RecordBatch; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::error::FlightError; +use arrow_flight::flight_service_server::FlightServiceServer; +use arrow_flight::sql::client::FlightSqlServiceClient; +use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; +use arrow_flight::sql::{ + ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionEndTransactionRequest, + CommandStatementIngest, EndTransaction, SqlInfo, TableDefinitionOptions, TableExistsOption, + TableNotExistOption, +}; +use arrow_flight::Action; +use futures::{StreamExt, TryStreamExt}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; +use tonic::{Request, Status}; +use uuid::Uuid; + +#[tokio::test] +pub async fn test_begin_end_transaction() { + let test_server = FlightSqlServiceImpl::new(); + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + + // begin commit + let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); + flight_sql_client + .end_transaction(transaction_id, EndTransaction::Commit) + .await + .unwrap(); + + // begin rollback + let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); + flight_sql_client + .end_transaction(transaction_id, EndTransaction::Rollback) + .await + .unwrap(); + + // unknown transaction id + let transaction_id = "UnknownTransactionId".to_string().into(); + assert!(flight_sql_client + .end_transaction(transaction_id, EndTransaction::Commit) + .await + .is_err()); +} + +#[tokio::test] +pub async fn test_execute_ingest() { + let test_server = FlightSqlServiceImpl::new(); + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = 
FlightSqlServiceClient::new(channel); + let cmd = make_ingest_command(); + let expected_rows = 10; + let batches = vec![ + make_primitive_batch(5), + make_primitive_batch(3), + make_primitive_batch(2), + ]; + let actual_rows = flight_sql_client + .execute_ingest(cmd, futures::stream::iter(batches.clone()).map(Ok)) + .await + .expect("ingest should succeed"); + assert_eq!(actual_rows, expected_rows); + // make sure the batches made it through to the server + let ingested_batches = test_server.ingested_batches.lock().await.clone(); + assert_eq!(ingested_batches, batches); +} + +#[tokio::test] +pub async fn test_execute_ingest_error() { + let test_server = FlightSqlServiceImpl::new(); + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + let cmd = make_ingest_command(); + // send an error from the client + let batches = vec![ + Ok(make_primitive_batch(5)), + Err(FlightError::NotYetImplemented( + "Client error message".to_string(), + )), + ]; + // make sure the client returns the error from the client + let err = flight_sql_client + .execute_ingest(cmd, futures::stream::iter(batches)) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "External error: Not yet implemented: Client error message" + ); +} + +fn make_ingest_command() -> CommandStatementIngest { + CommandStatementIngest { + table_definition_options: Some(TableDefinitionOptions { + if_not_exist: TableNotExistOption::Create.into(), + if_exists: TableExistsOption::Fail.into(), + }), + table: String::from("test"), + schema: None, + catalog: None, + temporary: true, + transaction_id: None, + options: HashMap::default(), + } +} + +#[derive(Clone)] +pub struct FlightSqlServiceImpl { + transactions: Arc>>, + ingested_batches: Arc>>, +} + +impl FlightSqlServiceImpl { + pub fn new() -> Self { + Self { + transactions: Arc::new(Mutex::new(HashMap::new())), + ingested_batches: 
Arc::new(Mutex::new(Vec::new())), + } + } + + /// Return an [`FlightServiceServer`] that can be used with a + /// [`Server`](tonic::transport::Server) + pub fn service(&self) -> FlightServiceServer { + // wrap up tonic goop + FlightServiceServer::new(self.clone()) + } +} + +impl Default for FlightSqlServiceImpl { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl FlightSqlService for FlightSqlServiceImpl { + type FlightService = FlightSqlServiceImpl; + + async fn do_action_begin_transaction( + &self, + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + let transaction_id = Uuid::new_v4().to_string(); + self.transactions + .lock() + .await + .insert(transaction_id.clone(), ()); + Ok(ActionBeginTransactionResult { + transaction_id: transaction_id.as_bytes().to_vec().into(), + }) + } + + async fn do_action_end_transaction( + &self, + query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + let transaction_id = String::from_utf8(query.transaction_id.to_vec()) + .map_err(|_| Status::invalid_argument("Invalid transaction id"))?; + if self + .transactions + .lock() + .await + .remove(&transaction_id) + .is_none() + { + return Err(Status::invalid_argument("Transaction id not found")); + } + Ok(()) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} + + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + request: Request, + ) -> Result { + let batches: Vec = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect() + .await?; + let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum(); + *self.ingested_batches.lock().await.as_mut() = batches; + Ok(affected_rows) + } +} diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs new file mode 100644 index 000000000000..6e1f6142c8b6 --- /dev/null +++ 
b/arrow-flight/tests/flight_sql_client_cli.rs @@ -0,0 +1,757 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +use std::{pin::Pin, sync::Arc}; + +use crate::common::fixture::TestFixture; +use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; +use arrow_flight::{ + decode::FlightRecordBatchStream, + encode::FlightDataEncoderBuilder, + flight_service_server::{FlightService, FlightServiceServer}, + sql::{ + server::{FlightSqlService, PeekableFlightDataStream}, + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, + CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables, + CommandPreparedStatementQuery, CommandStatementQuery, DoPutPreparedStatementResult, + ProstMessageExt, SqlInfo, + }, + utils::batches_to_flight_data, + Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, +}; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_schema::{ArrowError, DataType, Field, Schema}; +use assert_cmd::Command; +use bytes::Bytes; +use futures::{Stream, TryStreamExt}; +use prost::Message; +use tonic::{Request, Response, Status, Streaming}; + +const QUERY: &str = 
"SELECT * FROM table;"; + +#[tokio::test] +async fn test_simple() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("statement-query") + .arg(QUERY) + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+-----------+\ + \n| field_string | field_int |\ + \n+--------------+-----------+\ + \n| Hello | 42 |\ + \n| lovely | |\ + \n| FlightSQL! | 1337 |\ + \n+--------------+-----------+", + ); +} + +#[tokio::test] +async fn test_get_catalogs() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("catalogs") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+\ + \n| catalog_name |\ + \n+--------------+\ + \n| catalog_a |\ + \n| catalog_b |\ + \n+--------------+", + ); +} + +#[tokio::test] +async fn test_get_db_schemas() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + 
Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("db-schemas") + .arg("catalog_a") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+----------------+\ + \n| catalog_name | db_schema_name |\ + \n+--------------+----------------+\ + \n| catalog_a | schema_1 |\ + \n| catalog_a | schema_2 |\ + \n+--------------+----------------+", + ); +} + +#[tokio::test] +async fn test_get_tables() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("tables") + .arg("catalog_a") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+----------------+------------+------------+\ + \n| catalog_name | db_schema_name | table_name | table_type |\ + \n+--------------+----------------+------------+------------+\ + \n| catalog_a | schema_1 | table_1 | TABLE |\ + \n| catalog_a | schema_2 | table_2 | VIEW |\ + \n+--------------+----------------+------------+------------+", + ); +} +#[tokio::test] +async fn test_get_tables_db_filter() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + 
Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("tables") + .arg("catalog_a") + .arg("--db-schema-filter") + .arg("schema_2") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+----------------+------------+------------+\ + \n| catalog_name | db_schema_name | table_name | table_type |\ + \n+--------------+----------------+------------+------------+\ + \n| catalog_a | schema_2 | table_2 | VIEW |\ + \n+--------------+----------------+------------+------------+", + ); +} + +#[tokio::test] +async fn test_get_tables_types() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("table-types") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+\ + \n| table_type |\ + \n+--------------+\ + \n| SYSTEM_TABLE |\ + \n| TABLE |\ + \n| VIEW |\ + \n+--------------+", + ); +} + +const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; +const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; +const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle"; + +async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { + let fixture = 
TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("prepared-statement-query") + .arg(PREPARED_QUERY) + .args(["-p", "$1=string"]) + .args(["-p", "$2=64"]) + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+-----------+\ + \n| field_string | field_int |\ + \n+--------------+-----------+\ + \n| Hello | 42 |\ + \n| lovely | |\ + \n| FlightSQL! | 1337 |\ + \n+--------------+-----------+", + ); +} + +#[tokio::test] +pub async fn test_do_put_prepared_statement_stateless() { + test_do_put_prepared_statement(FlightSqlServiceImpl { + stateless_prepared_statements: true, + }) + .await +} + +#[tokio::test] +pub async fn test_do_put_prepared_statement_stateful() { + test_do_put_prepared_statement(FlightSqlServiceImpl { + stateless_prepared_statements: false, + }) + .await +} + +#[derive(Clone)] +pub struct FlightSqlServiceImpl { + /// Whether to emulate stateless (true) or stateful (false) behavior for + /// prepared statements. 
stateful servers will not return an updated + /// handle after executing `DoPut(CommandPreparedStatementQuery)` + stateless_prepared_statements: bool, +} + +impl Default for FlightSqlServiceImpl { + fn default() -> Self { + Self { + stateless_prepared_statements: true, + } + } +} + +impl FlightSqlServiceImpl { + /// Return an [`FlightServiceServer`] that can be used with a + /// [`Server`](tonic::transport::Server) + pub fn service(&self) -> FlightServiceServer { + // wrap up tonic goop + FlightServiceServer::new(self.clone()) + } + + fn fake_result() -> Result { + let schema = Schema::new(vec![ + Field::new("field_string", DataType::Utf8, false), + Field::new("field_int", DataType::Int64, true), + ]); + + let string_array = StringArray::from(vec!["Hello", "lovely", "FlightSQL!"]); + let int_array = Int64Array::from(vec![Some(42), None, Some(1337)]); + + let cols = vec![ + Arc::new(string_array) as ArrayRef, + Arc::new(int_array) as ArrayRef, + ]; + RecordBatch::try_new(Arc::new(schema), cols) + } + + fn create_fake_prepared_stmt() -> Result { + let handle = PREPARED_STATEMENT_HANDLE.to_string(); + let schema = Schema::new(vec![ + Field::new("field_string", DataType::Utf8, false), + Field::new("field_int", DataType::Int64, true), + ]); + + let parameter_schema = Schema::new(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]); + + Ok(ActionCreatePreparedStatementResult { + prepared_statement_handle: handle.into(), + dataset_schema: serialize_schema(&schema)?, + parameter_schema: serialize_schema(¶meter_schema)?, + }) + } + + fn fake_flight_info(&self) -> Result { + let batch = Self::fake_result()?; + + Ok(FlightInfo::new() + .try_with_schema(batch.schema_ref()) + .expect("encoding schema") + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + FetchResults { + handle: String::from("part_1"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + 
FetchResults { + handle: String::from("part_2"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_total_records(batch.num_rows() as i64) + .with_total_bytes(batch.get_array_memory_size() as i64) + .with_ordered(false)) + } +} + +fn serialize_schema(schema: &Schema) -> Result { + Ok(IpcMessage::try_from(SchemaAsIpc::new(schema, &IpcWriteOptions::default()))?.0) +} + +#[tonic::async_trait] +impl FlightSqlService for FlightSqlServiceImpl { + type FlightService = FlightSqlServiceImpl; + + async fn do_handshake( + &self, + _request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + Err(Status::unimplemented("do_handshake not implemented")) + } + + async fn do_get_fallback( + &self, + _request: Request, + message: Any, + ) -> Result::DoGetStream>, Status> { + let part = message.unpack::().unwrap().unwrap().handle; + let batch = Self::fake_result().unwrap(); + let batch = match part.as_str() { + "part_1" => batch.slice(0, 2), + "part_2" => batch.slice(2, 1), + ticket => panic!("Invalid ticket: {ticket:?}"), + }; + let schema = batch.schema_ref(); + let batches = vec![batch.clone()]; + let flight_data = batches_to_flight_data(schema, batches) + .unwrap() + .into_iter() + .map(Ok); + + let stream: Pin> + Send>> = + Box::pin(futures::stream::iter(flight_data)); + let resp = Response::new(stream); + Ok(resp) + } + + async fn get_flight_info_catalogs( + &self, + query: CommandGetCatalogs, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } + async fn get_flight_info_schemas( + &self, + query: CommandGetDbSchemas, + request: Request, + ) -> Result, Status> { + let 
flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } + + async fn get_flight_info_tables( + &self, + query: CommandGetTables, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } + + async fn get_flight_info_table_types( + &self, + query: CommandGetTableTypes, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } + async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + _request: Request, + ) -> Result, Status> { + assert_eq!(query.query, QUERY); + + let resp = Response::new(self.fake_flight_info().unwrap()); + Ok(resp) + } + + async fn get_flight_info_prepared_statement( + &self, + cmd: CommandPreparedStatementQuery, + _request: Request, + ) -> Result, Status> { + if self.stateless_prepared_statements { + assert_eq!( + cmd.prepared_statement_handle, + UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes() + ); + } else { + assert_eq!( + cmd.prepared_statement_handle, 
+ PREPARED_STATEMENT_HANDLE.as_bytes() + ); + } + let resp = Response::new(self.fake_flight_info().unwrap()); + Ok(resp) + } + + async fn do_get_catalogs( + &self, + query: CommandGetCatalogs, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for catalog_name in ["catalog_a", "catalog_b"] { + builder.append(catalog_name); + } + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + + async fn do_get_schemas( + &self, + query: CommandGetDbSchemas, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for (catalog_name, schema_name) in [ + ("catalog_a", "schema_1"), + ("catalog_a", "schema_2"), + ("catalog_b", "schema_3"), + ] { + builder.append(catalog_name, schema_name); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + + async fn do_get_tables( + &self, + query: CommandGetTables, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for (catalog_name, schema_name, table_name, table_type, schema) in [ + ( + "catalog_a", + "schema_1", + "table_1", + "TABLE", + Arc::new(Schema::empty()), + ), + ( + "catalog_a", + "schema_2", + "table_2", + "VIEW", + Arc::new(Schema::empty()), + ), + ( + "catalog_b", + "schema_3", + "table_3", + "TABLE", + Arc::new(Schema::empty()), + ), + ] { + builder + .append(catalog_name, schema_name, table_name, table_type, &schema) + .unwrap(); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + 
.build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + + async fn do_get_table_types( + &self, + query: CommandGetTableTypes, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for table_type in ["TABLE", "VIEW", "SYSTEM_TABLE"] { + builder.append(table_type); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + + async fn do_put_prepared_statement_query( + &self, + _query: CommandPreparedStatementQuery, + request: Request, + ) -> Result { + // just make sure decoding the parameters works + let parameters = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect::>() + .await?; + + for (left, right) in parameters[0].schema().flattened_fields().iter().zip(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]) { + if left.name() != right.name() || left.data_type() != right.data_type() { + return Err(Status::invalid_argument(format!( + "Parameters did not match parameter schema\ngot {}", + parameters[0].schema(), + ))); + } + } + let handle = if self.stateless_prepared_statements { + UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into() + } else { + PREPARED_STATEMENT_HANDLE.to_string().into() + }; + let result = DoPutPreparedStatementResult { + prepared_statement_handle: Some(handle), + }; + Ok(result) + } + + async fn do_action_create_prepared_statement( + &self, + _query: ActionCreatePreparedStatementRequest, + _request: Request, + ) -> Result { + Self::create_fake_prepared_stmt() + .map_err(|e| Status::internal(format!("Unable to serialize schema: {e}"))) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} +} + +#[derive(Clone, 
PartialEq, ::prost::Message)] +pub struct FetchResults { + #[prost(string, tag = "1")] + pub handle: ::prost::alloc::string::String, +} + +impl ProstMessageExt for FetchResults { + fn type_url() -> &'static str { + "type.googleapis.com/arrow.flight.protocol.sql.FetchResults" + } + + fn as_any(&self) -> Any { + Any { + type_url: FetchResults::type_url().to_string(), + value: ::prost::Message::encode_to_vec(self).into(), + } + } +} diff --git a/arrow-integration-test/Cargo.toml b/arrow-integration-test/Cargo.toml new file mode 100644 index 000000000000..8afbfacff7c3 --- /dev/null +++ b/arrow-integration-test/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "arrow-integration-test" +version = { workspace = true } +description = "Support for the Apache Arrow JSON test data format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_integration_test" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow = { workspace = true } +arrow-buffer = { workspace = true } +hex = { version = "0.4", default-features = false, features = ["std"] } +serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } +serde_json = { version = "1.0", default-features = false, features = ["std"] } +num = { version = "0.4", default-features = false, features = ["std"] } + +[build-dependencies] diff --git a/integration-testing/data/integration.json b/arrow-integration-test/data/integration.json similarity index 100% rename from integration-testing/data/integration.json rename to arrow-integration-test/data/integration.json diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs new file mode 100644 index 000000000000..e45e94c24e07 --- /dev/null +++ b/arrow-integration-test/src/datatype.rs @@ -0,0 +1,373 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode}; +use arrow::error::{ArrowError, Result}; +use std::sync::Arc; + +/// Parse a data type from a JSON representation. +pub fn data_type_from_json(json: &serde_json::Value) -> Result { + use serde_json::Value; + let default_field = Arc::new(Field::new("", DataType::Boolean, true)); + match *json { + Value::Object(ref map) => match map.get("name") { + Some(s) if s == "null" => Ok(DataType::Null), + Some(s) if s == "bool" => Ok(DataType::Boolean), + Some(s) if s == "binary" => Ok(DataType::Binary), + Some(s) if s == "largebinary" => Ok(DataType::LargeBinary), + Some(s) if s == "utf8" => Ok(DataType::Utf8), + Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8), + Some(s) if s == "fixedsizebinary" => { + // return a list with any type as its child isn't defined in the map + if let Some(Value::Number(size)) = map.get("byteWidth") { + Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32)) + } else { + Err(ArrowError::ParseError( + "Expecting a byteWidth for fixedsizebinary".to_string(), + )) + } + } + Some(s) if s == "decimal" => { + // return a list with any type as its child isn't defined in the map + let precision = match map.get("precision") { + Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()), + None => Err(ArrowError::ParseError( + "Expecting a precision for decimal".to_string(), + )), + }?; + let scale = match map.get("scale") { + Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()), + _ => Err(ArrowError::ParseError( + "Expecting a scale for 
decimal".to_string(), + )), + }?; + let bit_width: usize = match map.get("bitWidth") { + Some(b) => b.as_u64().unwrap() as usize, + _ => 128, // Default bit width + }; + + if bit_width == 128 { + Ok(DataType::Decimal128(precision, scale)) + } else if bit_width == 256 { + Ok(DataType::Decimal256(precision, scale)) + } else { + Err(ArrowError::ParseError( + "Decimal bit_width invalid".to_string(), + )) + } + } + Some(s) if s == "floatingpoint" => match map.get("precision") { + Some(p) if p == "HALF" => Ok(DataType::Float16), + Some(p) if p == "SINGLE" => Ok(DataType::Float32), + Some(p) if p == "DOUBLE" => Ok(DataType::Float64), + _ => Err(ArrowError::ParseError( + "floatingpoint precision missing or invalid".to_string(), + )), + }, + Some(s) if s == "timestamp" => { + let unit = match map.get("unit") { + Some(p) if p == "SECOND" => Ok(TimeUnit::Second), + Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond), + Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond), + Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond), + _ => Err(ArrowError::ParseError( + "timestamp unit missing or invalid".to_string(), + )), + }; + let tz = match map.get("timezone") { + None => Ok(None), + Some(Value::String(tz)) => Ok(Some(tz.as_str().into())), + _ => Err(ArrowError::ParseError( + "timezone must be a string".to_string(), + )), + }; + Ok(DataType::Timestamp(unit?, tz?)) + } + Some(s) if s == "date" => match map.get("unit") { + Some(p) if p == "DAY" => Ok(DataType::Date32), + Some(p) if p == "MILLISECOND" => Ok(DataType::Date64), + _ => Err(ArrowError::ParseError( + "date unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "time" => { + let unit = match map.get("unit") { + Some(p) if p == "SECOND" => Ok(TimeUnit::Second), + Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond), + Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond), + Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond), + _ => Err(ArrowError::ParseError( + "time unit 
missing or invalid".to_string(), + )), + }; + match map.get("bitWidth") { + Some(p) if p == 32 => Ok(DataType::Time32(unit?)), + Some(p) if p == 64 => Ok(DataType::Time64(unit?)), + _ => Err(ArrowError::ParseError( + "time bitWidth missing or invalid".to_string(), + )), + } + } + Some(s) if s == "duration" => match map.get("unit") { + Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)), + Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)), + Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)), + Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)), + _ => Err(ArrowError::ParseError( + "time unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "interval" => match map.get("unit") { + Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)), + Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)), + Some(p) if p == "MONTH_DAY_NANO" => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => Err(ArrowError::ParseError( + "interval unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "int" => match map.get("isSigned") { + Some(&Value::Bool(true)) => match map.get("bitWidth") { + Some(Value::Number(n)) => match n.as_u64() { + Some(8) => Ok(DataType::Int8), + Some(16) => Ok(DataType::Int16), + Some(32) => Ok(DataType::Int32), + Some(64) => Ok(DataType::Int64), + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + Some(&Value::Bool(false)) => match map.get("bitWidth") { + Some(Value::Number(n)) => match n.as_u64() { + Some(8) => Ok(DataType::UInt8), + Some(16) => Ok(DataType::UInt16), + Some(32) => Ok(DataType::UInt32), + Some(64) => Ok(DataType::UInt64), + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => 
Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int signed missing or invalid".to_string(), + )), + }, + Some(s) if s == "list" => { + // return a list with any type as its child isn't defined in the map + Ok(DataType::List(default_field)) + } + Some(s) if s == "largelist" => { + // return a largelist with any type as its child isn't defined in the map + Ok(DataType::LargeList(default_field)) + } + Some(s) if s == "fixedsizelist" => { + // return a list with any type as its child isn't defined in the map + if let Some(Value::Number(size)) = map.get("listSize") { + Ok(DataType::FixedSizeList( + default_field, + size.as_i64().unwrap() as i32, + )) + } else { + Err(ArrowError::ParseError( + "Expecting a listSize for fixedsizelist".to_string(), + )) + } + } + Some(s) if s == "struct" => { + // return an empty `struct` type as its children aren't defined in the map + Ok(DataType::Struct(Fields::empty())) + } + Some(s) if s == "map" => { + if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") { + // Return a map with an empty type as its children aren't defined in the map + Ok(DataType::Map(default_field, *keys_sorted)) + } else { + Err(ArrowError::ParseError( + "Expecting a keysSorted for map".to_string(), + )) + } + } + Some(s) if s == "union" => { + if let Some(Value::String(mode)) = map.get("mode") { + let union_mode = if mode == "SPARSE" { + UnionMode::Sparse + } else if mode == "DENSE" { + UnionMode::Dense + } else { + return Err(ArrowError::ParseError(format!( + "Unknown union mode {mode:?} for union" + ))); + }; + if let Some(values) = map.get("typeIds") { + let values = values.as_array().unwrap(); + let fields = values + .iter() + .map(|t| (t.as_i64().unwrap() as i8, default_field.clone())) + .collect(); + + Ok(DataType::Union(fields, union_mode)) + } else { + Err(ArrowError::ParseError( + "Expecting a typeIds for union ".to_string(), + )) + } + } else { + 
Err(ArrowError::ParseError( + "Expecting a mode for union".to_string(), + )) + } + } + Some(other) => Err(ArrowError::ParseError(format!( + "invalid or unsupported type name: {other} in {json:?}" + ))), + None => Err(ArrowError::ParseError("type name missing".to_string())), + }, + _ => Err(ArrowError::ParseError( + "invalid json value type".to_string(), + )), + } +} + +/// Generate a JSON representation of the data type. +pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value { + use serde_json::json; + match data_type { + DataType::Null => json!({"name": "null"}), + DataType::Boolean => json!({"name": "bool"}), + DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}), + DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}), + DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}), + DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}), + DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}), + DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}), + DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}), + DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}), + DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}), + DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}), + DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}), + DataType::Utf8 => json!({"name": "utf8"}), + DataType::LargeUtf8 => json!({"name": "largeutf8"}), + DataType::Binary => json!({"name": "binary"}), + DataType::LargeBinary => json!({"name": "largebinary"}), + DataType::BinaryView | DataType::Utf8View => { + unimplemented!("BinaryView/Utf8View not implemented") + } + DataType::FixedSizeBinary(byte_width) => { + json!({"name": "fixedsizebinary", "byteWidth": byte_width}) + } + DataType::Struct(_) => json!({"name": "struct"}), + 
DataType::Union(_, _) => json!({"name": "union"}), + DataType::List(_) => json!({ "name": "list"}), + DataType::LargeList(_) => json!({ "name": "largelist"}), + DataType::ListView(_) | DataType::LargeListView(_) => { + unimplemented!("ListView/LargeListView not implemented") + } + DataType::FixedSizeList(_, length) => { + json!({"name":"fixedsizelist", "listSize": length}) + } + DataType::Time32(unit) => { + json!({"name": "time", "bitWidth": 32, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Time64(unit) => { + json!({"name": "time", "bitWidth": 64, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Date32 => { + json!({"name": "date", "unit": "DAY"}) + } + DataType::Date64 => { + json!({"name": "date", "unit": "MILLISECOND"}) + } + DataType::Timestamp(unit, None) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Timestamp(unit, Some(tz)) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }, "timezone": tz}) + } + DataType::Interval(unit) => json!({"name": "interval", "unit": match unit { + IntervalUnit::YearMonth => "YEAR_MONTH", + IntervalUnit::DayTime => "DAY_TIME", + IntervalUnit::MonthDayNano => "MONTH_DAY_NANO", + }}), + DataType::Duration(unit) => json!({"name": "duration", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + 
TimeUnit::Nanosecond => "NANOSECOND", + }}), + DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), + DataType::Decimal128(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128}) + } + DataType::Decimal256(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256}) + } + DataType::Map(_, keys_sorted) => { + json!({"name": "map", "keysSorted": keys_sorted}) + } + DataType::RunEndEncoded(_, _) => todo!(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::Value; + + #[test] + fn parse_utf8_from_json() { + let json = "{\"name\":\"utf8\"}"; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = data_type_from_json(&value).unwrap(); + assert_eq!(DataType::Utf8, dt); + } + + #[test] + fn parse_int32_from_json() { + let json = "{\"name\": \"int\", \"isSigned\": true, \"bitWidth\": 32}"; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = data_type_from_json(&value).unwrap(); + assert_eq!(DataType::Int32, dt); + } +} diff --git a/arrow-integration-test/src/field.rs b/arrow-integration-test/src/field.rs new file mode 100644 index 000000000000..32edc4165938 --- /dev/null +++ b/arrow-integration-test/src/field.rs @@ -0,0 +1,568 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{data_type_from_json, data_type_to_json}; +use arrow::datatypes::{DataType, Field}; +use arrow::error::{ArrowError, Result}; +use std::collections::HashMap; +use std::sync::Arc; + +/// Parse a `Field` definition from a JSON representation. +pub fn field_from_json(json: &serde_json::Value) -> Result { + use serde_json::Value; + match *json { + Value::Object(ref map) => { + let name = match map.get("name") { + Some(Value::String(name)) => name.to_string(), + _ => { + return Err(ArrowError::ParseError( + "Field missing 'name' attribute".to_string(), + )); + } + }; + let nullable = match map.get("nullable") { + Some(&Value::Bool(b)) => b, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'nullable' attribute".to_string(), + )); + } + }; + let data_type = match map.get("type") { + Some(t) => data_type_from_json(t)?, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'type' attribute".to_string(), + )); + } + }; + + // Referenced example file: testing/data/arrow-ipc-stream/integration/1.0.0-littleendian/generated_custom_metadata.json.gz + let metadata = match map.get("metadata") { + Some(Value::Array(values)) => { + let mut res: HashMap = HashMap::default(); + for value in values { + match value.as_object() { + Some(map) => { + if map.len() != 2 { + return Err(ArrowError::ParseError( + "Field 'metadata' must have exact two entries for each key-value map".to_string(), + )); + } + if let (Some(k), Some(v)) = (map.get("key"), map.get("value")) { + if let (Some(k_str), Some(v_str)) = (k.as_str(), v.as_str()) { + 
res.insert( + k_str.to_string().clone(), + v_str.to_string().clone(), + ); + } else { + return Err(ArrowError::ParseError( + "Field 'metadata' must have map value of string type" + .to_string(), + )); + } + } else { + return Err(ArrowError::ParseError("Field 'metadata' lacks map keys named \"key\" or \"value\"".to_string())); + } + } + _ => { + return Err(ArrowError::ParseError( + "Field 'metadata' contains non-object key-value pair" + .to_string(), + )); + } + } + } + res + } + // We also support map format, because Schema's metadata supports this. + // See https://github.com/apache/arrow/pull/5907 + Some(Value::Object(values)) => { + let mut res: HashMap = HashMap::default(); + for (k, v) in values { + if let Some(str_value) = v.as_str() { + res.insert(k.clone(), str_value.to_string().clone()); + } else { + return Err(ArrowError::ParseError(format!( + "Field 'metadata' contains non-string value for key {k}" + ))); + } + } + res + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field `metadata` is not json array".to_string(), + )); + } + _ => HashMap::default(), + }; + + // if data_type is a struct or list, get its children + let data_type = match data_type { + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { + match map.get("children") { + Some(Value::Array(values)) => { + if values.len() != 1 { + return Err(ArrowError::ParseError( + "Field 'children' must have one element for a list data type" + .to_string(), + )); + } + match data_type { + DataType::List(_) => { + DataType::List(Arc::new(field_from_json(&values[0])?)) + } + DataType::LargeList(_) => { + DataType::LargeList(Arc::new(field_from_json(&values[0])?)) + } + DataType::FixedSizeList(_, int) => DataType::FixedSizeList( + Arc::new(field_from_json(&values[0])?), + int, + ), + _ => unreachable!( + "Data type should be a list, largelist or fixedsizelist" + ), + } + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), 
+ )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + } + } + DataType::Struct(_) => match map.get("children") { + Some(Value::Array(values)) => { + DataType::Struct(values.iter().map(field_from_json).collect::>()?) + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + }, + DataType::Map(_, keys_sorted) => { + match map.get("children") { + Some(Value::Array(values)) if values.len() == 1 => { + let child = field_from_json(&values[0])?; + // child must be a struct + match child.data_type() { + DataType::Struct(map_fields) if map_fields.len() == 2 => { + DataType::Map(Arc::new(child), keys_sorted) + } + t => { + return Err(ArrowError::ParseError(format!( + "Map children should be a struct with 2 fields, found {t:?}" + ))) + } + } + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array with 1 element".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + } + } + DataType::Union(fields, mode) => match map.get("children") { + Some(Value::Array(values)) => { + let fields = fields + .iter() + .zip(values) + .map(|((id, _), value)| Ok((id, Arc::new(field_from_json(value)?)))) + .collect::>()?; + + DataType::Union(fields, mode) + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + }, + _ => data_type, + }; + + let mut dict_id = 0; + let mut dict_is_ordered = false; + + let data_type = match map.get("dictionary") { + Some(dictionary) => { + let index_type = match dictionary.get("indexType") { + Some(t) => data_type_from_json(t)?, + _ => { + return 
Err(ArrowError::ParseError( + "Field missing 'indexType' attribute".to_string(), + )); + } + }; + dict_id = match dictionary.get("id") { + Some(Value::Number(n)) => n.as_i64().unwrap(), + _ => { + return Err(ArrowError::ParseError( + "Field missing 'id' attribute".to_string(), + )); + } + }; + dict_is_ordered = match dictionary.get("isOrdered") { + Some(&Value::Bool(n)) => n, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'isOrdered' attribute".to_string(), + )); + } + }; + DataType::Dictionary(Box::new(index_type), Box::new(data_type)) + } + _ => data_type, + }; + + let mut field = Field::new_dict(name, data_type, nullable, dict_id, dict_is_ordered); + field.set_metadata(metadata); + Ok(field) + } + _ => Err(ArrowError::ParseError( + "Invalid json value type for field".to_string(), + )), + } +} + +/// Generate a JSON representation of the `Field`. +pub fn field_to_json(field: &Field) -> serde_json::Value { + let children: Vec = match field.data_type() { + DataType::Struct(fields) => fields.iter().map(|x| field_to_json(x.as_ref())).collect(), + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) + | DataType::Map(field, _) => vec![field_to_json(field)], + _ => vec![], + }; + + match field.data_type() { + DataType::Dictionary(ref index_type, ref value_type) => serde_json::json!({ + "name": field.name(), + "nullable": field.is_nullable(), + "type": data_type_to_json(value_type), + "children": children, + "dictionary": { + "id": field.dict_id().unwrap(), + "indexType": data_type_to_json(index_type), + "isOrdered": field.dict_is_ordered().unwrap(), + } + }), + _ => serde_json::json!({ + "name": field.name(), + "nullable": field.is_nullable(), + "type": data_type_to_json(field.data_type()), + "children": children + }), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::UnionMode; + use serde_json::Value; + + #[test] + fn struct_field_to_json() { + let f = Field::new_struct( + "address", + 
vec![ + Field::new("street", DataType::Utf8, false), + Field::new("zip", DataType::UInt16, false), + ], + false, + ); + let value: Value = serde_json::from_str( + r#"{ + "name": "address", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "street", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "zip", + "nullable": false, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + }"#, + ) + .unwrap(); + assert_eq!(value, field_to_json(&f)); + } + + #[test] + fn map_field_to_json() { + let f = Field::new_map( + "my_map", + "my_entries", + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + true, + false, + ); + let value: Value = serde_json::from_str( + r#"{ + "name": "my_map", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + }"#, + ) + .unwrap(); + assert_eq!(value, field_to_json(&f)); + } + + #[test] + fn primitive_field_to_json() { + let f = Field::new("first_name", DataType::Utf8, false); + let value: Value = serde_json::from_str( + r#"{ + "name": "first_name", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }"#, + ) + .unwrap(); + assert_eq!(value, field_to_json(&f)); + } + #[test] + fn parse_struct_from_json() { + let json = r#" + { + "name": "address", + "type": { + "name": "struct" + }, + "nullable": false, + "children": [ + { + "name": "street", + "type": { + "name": "utf8" + }, + "nullable": false, + "children": [] + }, + { + "name": "zip", + "type": { + "name": "int", + 
"isSigned": false, + "bitWidth": 16 + }, + "nullable": false, + "children": [] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = field_from_json(&value).unwrap(); + + let expected = Field::new_struct( + "address", + vec![ + Field::new("street", DataType::Utf8, false), + Field::new("zip", DataType::UInt16, false), + ], + false, + ); + + assert_eq!(expected, dt); + } + + #[test] + fn parse_map_from_json() { + let json = r#" + { + "name": "my_map", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = field_from_json(&value).unwrap(); + + let expected = Field::new_map( + "my_map", + "my_entries", + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + true, + false, + ); + + assert_eq!(expected, dt); + } + + #[test] + fn parse_union_from_json() { + let json = r#" + { + "name": "my_union", + "nullable": false, + "type": { + "name": "union", + "mode": "SPARSE", + "typeIds": [ + 5, + 7 + ] + }, + "children": [ + { + "name": "f1", + "type": { + "name": "int", + "isSigned": true, + "bitWidth": 32 + }, + "nullable": true, + "children": [] + }, + { + "name": "f2", + "type": { + "name": "utf8" + }, + "nullable": true, + "children": [] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = field_from_json(&value).unwrap(); + + let expected = Field::new_union( + "my_union", + vec![5, 7], + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + 
UnionMode::Sparse, + ); + + assert_eq!(expected, dt); + } +} diff --git a/integration-testing/src/util.rs b/arrow-integration-test/src/lib.rs similarity index 73% rename from integration-testing/src/util.rs rename to arrow-integration-test/src/lib.rs index e098c4e1491a..d1486fd5a153 100644 --- a/integration-testing/src/util.rs +++ b/arrow-integration-test/src/lib.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! Utils for JSON integration testing +//! Support for the [Apache Arrow JSON test data format](https://github.com/apache/arrow/blob/master/docs/source/format/Integration.rst#json-test-data-format) //! //! These utilities define structs that read the integration JSON format for integration testing purposes. +//! +//! This is not a canonical format, but provides a human-readable way of verifying language implementations +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use hex::decode; use num::BigInt; use num::Signed; @@ -29,14 +32,21 @@ use std::sync::Arc; use arrow::array::*; use arrow::buffer::{Buffer, MutableBuffer}; -use arrow::compute; use arrow::datatypes::*; use arrow::error::{ArrowError, Result}; -use arrow::record_batch::{RecordBatch, RecordBatchReader}; use arrow::util::bit_util; -use arrow::util::decimal::Decimal256; + +mod datatype; +mod field; +mod schema; + +pub use datatype::*; +pub use field::*; +pub use schema::*; /// A struct that represents an Arrow file with a schema and record batches +/// +/// See #[derive(Deserialize, Serialize, Debug)] pub struct ArrowJson { pub schema: ArrowJsonSchema, @@ -69,12 +79,18 @@ pub struct ArrowJsonField { pub metadata: Option, } +impl From<&FieldRef> for ArrowJsonField { + fn from(value: &FieldRef) -> Self { + Self::from(value.as_ref()) + } +} + impl From<&Field> for ArrowJsonField { fn from(field: &Field) -> Self { - let metadata_value = match field.metadata() { - Some(kv_list) => { + let metadata_value = match 
field.metadata().is_empty() { + false => { let mut array = Vec::new(); - for (k, v) in kv_list { + for (k, v) in field.metadata() { let mut kv_map = SJMap::new(); kv_map.insert(k.clone(), Value::String(v.clone())); array.push(Value::Object(kv_map)); @@ -90,7 +106,7 @@ impl From<&Field> for ArrowJsonField { Self { name: field.name().to_string(), - field_type: field.data_type().to_json(), + field_type: data_type_to_json(field.data_type()), nullable: field.is_nullable(), children: vec![], dictionary: None, // TODO: not enough info @@ -160,12 +176,13 @@ impl ArrowJson { match batch { Some(Ok(batch)) => { if json_batch != batch { - println!("json: {:?}", json_batch); - println!("batch: {:?}", batch); + println!("json: {json_batch:?}"); + println!("batch: {batch:?}"); return Ok(false); } } - _ => return Ok(false), + Some(Err(e)) => return Err(e), + None => return Ok(false), } } @@ -242,10 +259,7 @@ impl ArrowJsonField { true } Err(e) => { - eprintln!( - "Encountered error while converting JSON field to Arrow field: {:?}", - e - ); + eprintln!("Encountered error while converting JSON field to Arrow field: {e:?}"); false } } @@ -255,8 +269,9 @@ impl ArrowJsonField { /// TODO: convert to use an Into fn to_arrow_field(&self) -> Result { // a bit regressive, but we have to convert the field to JSON in order to convert it - let field = serde_json::to_value(self)?; - Field::from(&field) + let field = + serde_json::to_value(self).map_err(|error| ArrowError::JsonError(error.to_string()))?; + field_from_json(&field) } } @@ -310,10 +325,7 @@ pub fn array_from_json( { match is_valid { 1 => b.append_value(value.as_i64().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to get {:?} as int64", - value - )) + ArrowError::JsonError(format!("Unable to get {value:?} as int64")) })? 
as i8), _ => b.append_null(), }; @@ -336,10 +348,7 @@ pub fn array_from_json( } Ok(Arc::new(b.finish())) } - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { let mut b = Int32Builder::with_capacity(json_col.count); for (is_valid, value) in json_col .validity @@ -354,14 +363,29 @@ pub fn array_from_json( }; } let array = Arc::new(b.finish()) as ArrayRef; - compute::cast(&array, field.data_type()) + arrow::compute::cast(&array, field.data_type()) + } + DataType::Interval(IntervalUnit::YearMonth) => { + let mut b = IntervalYearMonthBuilder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i32), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) } DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _) - | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { + | DataType::Duration(_) => { let mut b = Int64Builder::with_capacity(json_col.count); for (is_valid, value) in json_col .validity @@ -373,12 +397,28 @@ pub fn array_from_json( match is_valid { 1 => b.append_value(match value { Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } + Value::String(s) => s.parse().expect("Unable to parse string as i64"), + _ => panic!("Unable to parse {value:?} as number"), + }), + _ => b.append_null(), + }; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::Interval(IntervalUnit::DayTime) => { + let mut b = IntervalDayTimeBuilder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => 
b.append_value(match value { Value::Object(ref map) - if map.contains_key("days") - && map.contains_key("milliseconds") => + if map.contains_key("days") && map.contains_key("milliseconds") => { match field.data_type() { DataType::Interval(IntervalUnit::DayTime) => { @@ -387,35 +427,24 @@ pub fn array_from_json( match (days, milliseconds) { (Value::Number(d), Value::Number(m)) => { - let mut bytes = [0_u8; 8]; - let m = (m.as_i64().unwrap() as i32) - .to_le_bytes(); - let d = (d.as_i64().unwrap() as i32) - .to_le_bytes(); - - let c = [d, m].concat(); - bytes.copy_from_slice(c.as_slice()); - i64::from_le_bytes(bytes) + let days = d.as_i64().unwrap() as _; + let millis = m.as_i64().unwrap() as _; + IntervalDayTime::new(days, millis) + } + _ => { + panic!("Unable to parse {value:?} as interval daytime") } - _ => panic!( - "Unable to parse {:?} as interval daytime", - value - ), } } - _ => panic!( - "Unable to parse {:?} as interval daytime", - value - ), + _ => panic!("Unable to parse {value:?} as interval daytime"), } } - _ => panic!("Unable to parse {:?} as number", value), + _ => panic!("Unable to parse {value:?} as number"), }), _ => b.append_null(), }; } - let array = Arc::new(b.finish()) as ArrayRef; - compute::cast(&array, field.data_type()) + Ok(Arc::new(b.finish())) } DataType::UInt8 => { let mut b = UInt8Builder::with_capacity(json_col.count); @@ -485,11 +514,9 @@ pub fn array_from_json( .expect("Unable to parse string as u64"), ) } else if value.is_number() { - b.append_value( - value.as_u64().expect("Unable to read number as u64"), - ) + b.append_value(value.as_u64().expect("Unable to read number as u64")) } else { - panic!("Unable to parse value {:?} as u64", value) + panic!("Unable to parse value {value:?} as u64") } } _ => b.append_null(), @@ -521,19 +548,14 @@ pub fn array_from_json( let months = months.as_i64().unwrap() as i32; let days = days.as_i64().unwrap() as i32; let nanoseconds = nanoseconds.as_i64().unwrap(); - let months_days_ns: i128 = 
((nanoseconds as i128) - & 0xFFFFFFFFFFFFFFFF) - << 64 - | ((days as i128) & 0xFFFFFFFF) << 32 - | ((months as i128) & 0xFFFFFFFF); - months_days_ns + IntervalMonthDayNano::new(months, days, nanoseconds) } (_, _, _) => { - panic!("Unable to parse {:?} as MonthDayNano", v) + panic!("Unable to parse {v:?} as MonthDayNano") } } } - _ => panic!("Unable to parse {:?} as MonthDayNano", value), + _ => panic!("Unable to parse {value:?} as MonthDayNano"), }), _ => b.append_null(), }; @@ -664,11 +686,7 @@ pub fn array_from_json( DataType::List(child_field) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = array_from_json(child_field, children[0].clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -678,7 +696,7 @@ pub fn array_from_json( let list_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_buffer(Buffer::from(offsets.to_byte_slice())) .add_child_data(child_array.into_data()) .null_bit_buffer(Some(null_buf)) .build() @@ -688,11 +706,7 @@ pub fn array_from_json( DataType::LargeList(child_field) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = array_from_json(child_field, children[0].clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -706,7 +720,7 @@ pub fn array_from_json( let list_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_buffer(Buffer::from(offsets.to_byte_slice())) .add_child_data(child_array.into_data()) .null_bit_buffer(Some(null_buf)) .build() @@ -715,11 +729,7 @@ pub fn 
array_from_json( } DataType::FixedSizeList(child_field, _) => { let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = array_from_json(child_field, children[0].clone(), dictionaries)?; let null_buf = create_null_buf(&json_col); let list_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) @@ -746,17 +756,13 @@ pub fn array_from_json( } DataType::Dictionary(key_type, value_type) => { let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find dict_id for field {:?}", - field - )) + ArrowError::JsonError(format!("Unable to find dict_id for field {field:?}")) })?; // find dictionary let dictionary = dictionaries .ok_or_else(|| { ArrowError::JsonError(format!( - "Unable to find any dictionaries for field {:?}", - field + "Unable to find any dictionaries for field {field:?}" )) })? .get(&dict_id); @@ -770,18 +776,12 @@ pub fn array_from_json( dictionaries, ), None => Err(ArrowError::JsonError(format!( - "Unable to find dictionary for field {:?}", - field + "Unable to find dictionary for field {field:?}" ))), } } DataType::Decimal128(precision, scale) => { - let mut b = - Decimal128Builder::with_capacity(json_col.count, *precision, *scale); - // C++ interop tests involve incompatible decimal values - unsafe { - b.disable_value_validation(); - } + let mut b = Decimal128Builder::with_capacity(json_col.count); for (is_valid, value) in json_col .validity .as_ref() @@ -790,21 +790,16 @@ pub fn array_from_json( .zip(json_col.data.unwrap()) { match is_valid { - 1 => { - b.append_value(value.as_str().unwrap().parse::().unwrap())? 
- } + 1 => b.append_value(value.as_str().unwrap().parse::().unwrap()), _ => b.append_null(), }; } - Ok(Arc::new(b.finish())) + Ok(Arc::new( + b.finish().with_precision_and_scale(*precision, *scale)?, + )) } DataType::Decimal256(precision, scale) => { - let mut b = - Decimal256Builder::with_capacity(json_col.count, *precision, *scale); - // C++ interop tests involve incompatible decimal values - unsafe { - b.disable_value_validation(); - } + let mut b = Decimal256Builder::with_capacity(json_col.count); for (is_valid, value) in json_col .validity .as_ref() @@ -822,26 +817,20 @@ pub fn array_from_json( } else { [255_u8; 32] }; - bytes[0..integer_bytes.len()] - .copy_from_slice(integer_bytes.as_slice()); - let decimal = - Decimal256::try_new_from_bytes(*precision, *scale, &bytes) - .unwrap(); - b.append_value(&decimal)?; + bytes[0..integer_bytes.len()].copy_from_slice(integer_bytes.as_slice()); + b.append_value(i256::from_le_bytes(bytes)); } _ => b.append_null(), } } - Ok(Arc::new(b.finish())) + Ok(Arc::new( + b.finish().with_precision_and_scale(*precision, *scale)?, + )) } DataType::Map(child_field, _) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = array_from_json(child_field, children[0].clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -850,7 +839,7 @@ pub fn array_from_json( .collect(); let array_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_buffer(Buffer::from(offsets.to_byte_slice())) .add_child_data(child_array.into_data()) .null_bit_buffer(Some(null_buf)) .build() @@ -859,7 +848,7 @@ pub fn array_from_json( let array = MapArray::from(array_data); Ok(Arc::new(array)) } - DataType::Union(fields, field_type_ids, _) => { + DataType::Union(fields, _) => { let type_ids = if let 
Some(type_id) = json_col.type_id { type_id } else { @@ -868,30 +857,22 @@ pub fn array_from_json( )); }; - let offset: Option = json_col.offset.map(|offsets| { - let offsets: Vec = - offsets.iter().map(|v| v.as_i64().unwrap() as i32).collect(); - Buffer::from(&offsets.to_byte_slice()) - }); + let offset: Option> = json_col + .offset + .map(|offsets| offsets.iter().map(|v| v.as_i64().unwrap() as i32).collect()); - let mut children: Vec<(Field, Arc)> = vec![]; - for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + let mut children = Vec::with_capacity(fields.len()); + for ((_, field), col) in fields.iter().zip(json_col.children.unwrap()) { let array = array_from_json(field, col, dictionaries)?; - children.push((field.clone(), array)); + children.push(array); } - let array = UnionArray::try_new( - field_type_ids, - Buffer::from(&type_ids.to_byte_slice()), - offset, - children, - ) - .unwrap(); + let array = + UnionArray::try_new(fields.clone(), type_ids.into(), offset, children).unwrap(); Ok(Arc::new(array)) } t => Err(ArrowError::JsonError(format!( - "data type {:?} not supported", - t + "data type {t:?} not supported" ))), } } @@ -939,16 +920,14 @@ pub fn dictionary_array_from_json( // convert key and value to dictionary data let dict_data = ArrayData::builder(field.data_type().clone()) .len(keys.len()) - .add_buffer(keys.data().buffers()[0].clone()) + .add_buffer(keys.to_data().buffers()[0].clone()) .null_bit_buffer(Some(null_buf)) .add_child_data(values.into_data()) .build() .unwrap(); let array = match dict_key { - DataType::Int8 => { - Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef - } + DataType::Int8 => Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef, DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), @@ -961,13 +940,12 @@ pub fn dictionary_array_from_json( Ok(array) } _ 
=> Err(ArrowError::JsonError(format!( - "Dictionary key type {:?} not supported", - dict_key + "Dictionary key type {dict_key:?} not supported" ))), } } -/// A helper to create a null buffer from a Vec +/// A helper to create a null buffer from a `Vec` fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { let num_bytes = bit_util::ceil(json_col.count, 8); let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); @@ -1045,9 +1023,6 @@ mod tests { use std::fs::File; use std::io::Read; - use std::sync::Arc; - - use arrow::buffer::Buffer; #[test] fn test_schema_equality() { @@ -1100,11 +1075,7 @@ mod tests { Field::new("c3", DataType::Utf8, true), Field::new( "c4", - DataType::List(Box::new(Field::new( - "custom_item", - DataType::Int32, - false, - ))), + DataType::List(Arc::new(Field::new("custom_item", DataType::Int32, false))), true, ), ]); @@ -1113,100 +1084,95 @@ mod tests { #[test] fn test_arrow_data_equality() { - let secs_tz = Some("Europe/Budapest".to_string()); - let millis_tz = Some("America/New_York".to_string()); - let micros_tz = Some("UTC".to_string()); - let nanos_tz = Some("Africa/Johannesburg".to_string()); + let secs_tz = Some("Europe/Budapest".into()); + let millis_tz = Some("America/New_York".into()); + let micros_tz = Some("UTC".into()); + let nanos_tz = Some("Africa/Johannesburg".into()); - let schema = - Schema::new(vec![ - Field::new("bools-with-metadata-map", DataType::Boolean, true) - .with_metadata(Some( - [("k".to_string(), "v".to_string())] - .iter() - .cloned() - .collect(), - )), - Field::new("bools-with-metadata-vec", DataType::Boolean, true) - .with_metadata(Some( - [("k2".to_string(), "v2".to_string())] - .iter() - .cloned() - .collect(), - )), - Field::new("bools", DataType::Boolean, true), - Field::new("int8s", DataType::Int8, true), - Field::new("int16s", DataType::Int16, true), - Field::new("int32s", DataType::Int32, true), - Field::new("int64s", DataType::Int64, true), - Field::new("uint8s", 
DataType::UInt8, true), - Field::new("uint16s", DataType::UInt16, true), - Field::new("uint32s", DataType::UInt32, true), - Field::new("uint64s", DataType::UInt64, true), - Field::new("float32s", DataType::Float32, true), - Field::new("float64s", DataType::Float64, true), - Field::new("date_days", DataType::Date32, true), - Field::new("date_millis", DataType::Date64, true), - Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), - Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), - Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), - Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), - Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), - Field::new( - "ts_millis", - DataType::Timestamp(TimeUnit::Millisecond, None), - true, - ), - Field::new( - "ts_micros", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - Field::new( - "ts_nanos", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - Field::new( - "ts_secs_tz", - DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), - true, - ), - Field::new( - "ts_millis_tz", - DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), - true, - ), - Field::new( - "ts_micros_tz", - DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), - true, - ), - Field::new( - "ts_nanos_tz", - DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), - true, - ), - Field::new("utf8s", DataType::Utf8, true), - Field::new( - "lists", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - ), - Field::new( - "structs", - DataType::Struct(vec![ - Field::new("int32s", DataType::Int32, true), - Field::new("utf8s", DataType::Utf8, true), - ]), - true, - ), - ]); + let schema = Schema::new(vec![ + Field::new("bools-with-metadata-map", DataType::Boolean, true).with_metadata( + [("k".to_string(), "v".to_string())] + .iter() + .cloned() + .collect(), + ), + 
Field::new("bools-with-metadata-vec", DataType::Boolean, true).with_metadata( + [("k2".to_string(), "v2".to_string())] + .iter() + .cloned() + .collect(), + ), + Field::new("bools", DataType::Boolean, true), + Field::new("int8s", DataType::Int8, true), + Field::new("int16s", DataType::Int16, true), + Field::new("int32s", DataType::Int32, true), + Field::new("int64s", DataType::Int64, true), + Field::new("uint8s", DataType::UInt8, true), + Field::new("uint16s", DataType::UInt16, true), + Field::new("uint32s", DataType::UInt32, true), + Field::new("uint64s", DataType::UInt64, true), + Field::new("float32s", DataType::Float32, true), + Field::new("float64s", DataType::Float64, true), + Field::new("date_days", DataType::Date32, true), + Field::new("date_millis", DataType::Date64, true), + Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), + Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), + Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), + Field::new( + "ts_millis", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micros", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "ts_nanos", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "ts_secs_tz", + DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), + true, + ), + Field::new( + "ts_millis_tz", + DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), + true, + ), + Field::new( + "ts_micros_tz", + DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), + true, + ), + Field::new( + "ts_nanos_tz", + DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), + true, + ), + Field::new("utf8s", DataType::Utf8, true), + Field::new( + "lists", + DataType::List(Arc::new(Field::new("item", 
DataType::Int32, true))), + true, + ), + Field::new( + "structs", + DataType::Struct(Fields::from(vec![ + Field::new("int32s", DataType::Int32, true), + Field::new("utf8s", DataType::Utf8, true), + ])), + true, + ), + ]); - let bools_with_metadata_map = - BooleanArray::from(vec![Some(true), None, Some(false)]); - let bools_with_metadata_vec = - BooleanArray::from(vec![Some(true), None, Some(false)]); + let bools_with_metadata_map = BooleanArray::from(vec![Some(true), None, Some(false)]); + let bools_with_metadata_vec = BooleanArray::from(vec![Some(true), None, Some(false)]); let bools = BooleanArray::from(vec![Some(true), None, Some(false)]); let int8s = Int8Array::from(vec![Some(1), None, Some(3)]); let int16s = Int16Array::from(vec![Some(1), None, Some(3)]); @@ -1224,54 +1190,32 @@ mod tests { Some(29923997007884), Some(30612271819236), ]); - let time_secs = - Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); - let time_millis = Time32MillisecondArray::from(vec![ - Some(6613125), - Some(74667230), - Some(52260079), - ]); - let time_micros = - Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); - let time_nanos = Time64NanosecondArray::from(vec![ - Some(73380123595985), - None, - Some(16584393546415), - ]); - let ts_secs = TimestampSecondArray::from_opt_vec( - vec![None, Some(193438817552), None], - None, - ); - let ts_millis = TimestampMillisecondArray::from_opt_vec( - vec![None, Some(38606916383008), Some(58113709376587)], - None, - ); - let ts_micros = - TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], None); - let ts_nanos = TimestampNanosecondArray::from_opt_vec( - vec![None, None, Some(-6473623571954960143)], - None, - ); - let ts_secs_tz = TimestampSecondArray::from_opt_vec( - vec![None, Some(193438817552), None], - secs_tz, - ); - let ts_millis_tz = TimestampMillisecondArray::from_opt_vec( - vec![None, Some(38606916383008), Some(58113709376587)], - millis_tz, - ); + let time_secs = 
Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); + let time_millis = + Time32MillisecondArray::from(vec![Some(6613125), Some(74667230), Some(52260079)]); + let time_micros = Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); + let time_nanos = + Time64NanosecondArray::from(vec![Some(73380123595985), None, Some(16584393546415)]); + let ts_secs = TimestampSecondArray::from(vec![None, Some(193438817552), None]); + let ts_millis = + TimestampMillisecondArray::from(vec![None, Some(38606916383008), Some(58113709376587)]); + let ts_micros = TimestampMicrosecondArray::from(vec![None, None, None]); + let ts_nanos = TimestampNanosecondArray::from(vec![None, None, Some(-6473623571954960143)]); + let ts_secs_tz = TimestampSecondArray::from(vec![None, Some(193438817552), None]) + .with_timezone_opt(secs_tz); + let ts_millis_tz = + TimestampMillisecondArray::from(vec![None, Some(38606916383008), Some(58113709376587)]) + .with_timezone_opt(millis_tz); let ts_micros_tz = - TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], micros_tz); - let ts_nanos_tz = TimestampNanosecondArray::from_opt_vec( - vec![None, None, Some(-6473623571954960143)], - nanos_tz, - ); + TimestampMicrosecondArray::from(vec![None, None, None]).with_timezone_opt(micros_tz); + let ts_nanos_tz = + TimestampNanosecondArray::from(vec![None, None, Some(-6473623571954960143)]) + .with_timezone_opt(nanos_tz); let utf8s = StringArray::from(vec![Some("aa"), None, Some("bbb")]); let value_data = Int32Array::from(vec![None, Some(2), None, None]); - let value_offsets = Buffer::from_slice_ref(&[0, 3, 4, 4]); - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let value_offsets = Buffer::from_slice_ref([0, 3, 4, 4]); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -1283,14 +1227,14 @@ mod tests { let 
structs_int32s = Int32Array::from(vec![None, Some(-2), None]); let structs_utf8s = StringArray::from(vec![None, None, Some("aaaaaa")]); - let struct_data_type = DataType::Struct(vec![ + let struct_data_type = DataType::Struct(Fields::from(vec![ Field::new("int32s", DataType::Int32, true), Field::new("utf8s", DataType::Utf8, true), - ]); + ])); let struct_data = ArrayData::builder(struct_data_type) .len(3) - .add_child_data(structs_int32s.data().clone()) - .add_child_data(structs_utf8s.data().clone()) + .add_child_data(structs_int32s.into_data()) + .add_child_data(structs_utf8s.into_data()) .null_bit_buffer(Some(Buffer::from([0b00000011]))) .build() .unwrap(); diff --git a/arrow-integration-test/src/schema.rs b/arrow-integration-test/src/schema.rs new file mode 100644 index 000000000000..541a1ec746ac --- /dev/null +++ b/arrow-integration-test/src/schema.rs @@ -0,0 +1,728 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{field_from_json, field_to_json}; +use arrow::datatypes::{Fields, Schema}; +use arrow::error::{ArrowError, Result}; +use std::collections::HashMap; + +/// Generate a JSON representation of the `Schema`. 
+pub fn schema_to_json(schema: &Schema) -> serde_json::Value { + serde_json::json!({ + "fields": schema.fields().iter().map(|f| field_to_json(f.as_ref())).collect::>(), + "metadata": serde_json::to_value(schema.metadata()).unwrap() + }) +} + +/// Parse a `Schema` definition from a JSON representation. +pub fn schema_from_json(json: &serde_json::Value) -> Result { + use serde_json::Value; + match *json { + Value::Object(ref schema) => { + let fields: Fields = match schema.get("fields") { + Some(Value::Array(fields)) => { + fields.iter().map(field_from_json).collect::>()? + } + _ => { + return Err(ArrowError::ParseError( + "Schema fields should be an array".to_string(), + )) + } + }; + + let metadata = if let Some(value) = schema.get("metadata") { + from_metadata(value)? + } else { + HashMap::default() + }; + + Ok(Schema::new_with_metadata(fields, metadata)) + } + _ => Err(ArrowError::ParseError( + "Invalid json value type for schema".to_string(), + )), + } +} + +/// Parse a `metadata` definition from a JSON representation. +/// The JSON can either be an Object or an Array of Objects. 
+fn from_metadata(json: &serde_json::Value) -> Result> { + use serde_json::Value; + match json { + Value::Array(_) => { + let mut hashmap = HashMap::new(); + let values: Vec = + serde_json::from_value(json.clone()).map_err(|_| { + ArrowError::JsonError("Unable to parse object into key-value pair".to_string()) + })?; + for meta in values { + hashmap.insert(meta.key.clone(), meta.value); + } + Ok(hashmap) + } + Value::Object(md) => md + .iter() + .map(|(k, v)| { + if let Value::String(v) = v { + Ok((k.to_string(), v.to_string())) + } else { + Err(ArrowError::ParseError( + "metadata `value` field must be a string".to_string(), + )) + } + }) + .collect::>(), + _ => Err(ArrowError::ParseError( + "`metadata` field must be an object".to_string(), + )), + } +} + +#[derive(serde::Deserialize)] +struct MetadataKeyValue { + key: String, + value: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; + use serde_json::Value; + use std::sync::Arc; + + #[test] + fn schema_json() { + // Add some custom metadata + let metadata: HashMap = [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Binary, false), + Field::new("c3", DataType::FixedSizeBinary(3), false), + Field::new("c4", DataType::Boolean, false), + Field::new("c5", DataType::Date32, false), + Field::new("c6", DataType::Date64, false), + Field::new("c7", DataType::Time32(TimeUnit::Second), false), + Field::new("c8", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("c9", DataType::Time32(TimeUnit::Microsecond), false), + Field::new("c10", DataType::Time32(TimeUnit::Nanosecond), false), + Field::new("c11", DataType::Time64(TimeUnit::Second), false), + Field::new("c12", DataType::Time64(TimeUnit::Millisecond), false), + Field::new("c13", DataType::Time64(TimeUnit::Microsecond), false), + 
Field::new("c14", DataType::Time64(TimeUnit::Nanosecond), false), + Field::new("c15", DataType::Timestamp(TimeUnit::Second, None), false), + Field::new( + "c16", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + false, + ), + Field::new( + "c17", + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), + false, + ), + Field::new( + "c18", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("c19", DataType::Interval(IntervalUnit::DayTime), false), + Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), + Field::new("c21", DataType::Interval(IntervalUnit::MonthDayNano), false), + Field::new( + "c22", + DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), + false, + ), + Field::new( + "c23", + DataType::FixedSizeList( + Arc::new(Field::new("bools", DataType::Boolean, false)), + 5, + ), + false, + ), + Field::new( + "c24", + DataType::List(Arc::new(Field::new( + "inner_list", + DataType::List(Arc::new(Field::new( + "struct", + DataType::Struct(Fields::empty()), + true, + ))), + false, + ))), + true, + ), + Field::new( + "c25", + DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::UInt16, false), + ])), + false, + ), + Field::new("c26", DataType::Interval(IntervalUnit::YearMonth), true), + Field::new("c27", DataType::Interval(IntervalUnit::DayTime), true), + Field::new("c28", DataType::Interval(IntervalUnit::MonthDayNano), true), + Field::new("c29", DataType::Duration(TimeUnit::Second), false), + Field::new("c30", DataType::Duration(TimeUnit::Millisecond), false), + Field::new("c31", DataType::Duration(TimeUnit::Microsecond), false), + Field::new("c32", DataType::Duration(TimeUnit::Nanosecond), false), + Field::new_dict( + "c33", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + 123, + true, + ), + Field::new("c34", DataType::LargeBinary, true), + Field::new("c35", 
DataType::LargeUtf8, true), + Field::new( + "c36", + DataType::LargeList(Arc::new(Field::new( + "inner_large_list", + DataType::LargeList(Arc::new(Field::new( + "struct", + DataType::Struct(Fields::empty()), + false, + ))), + true, + ))), + true, + ), + Field::new( + "c37", + DataType::Map( + Arc::new(Field::new( + "my_entries", + DataType::Struct(Fields::from(vec![ + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + ])), + false, + )), + true, + ), + false, + ), + ], + metadata, + ); + + let expected = schema_to_json(&schema); + let json = r#"{ + "fields": [ + { + "name": "c1", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "c2", + "nullable": false, + "type": { + "name": "binary" + }, + "children": [] + }, + { + "name": "c3", + "nullable": false, + "type": { + "name": "fixedsizebinary", + "byteWidth": 3 + }, + "children": [] + }, + { + "name": "c4", + "nullable": false, + "type": { + "name": "bool" + }, + "children": [] + }, + { + "name": "c5", + "nullable": false, + "type": { + "name": "date", + "unit": "DAY" + }, + "children": [] + }, + { + "name": "c6", + "nullable": false, + "type": { + "name": "date", + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c7", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c8", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c9", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "MICROSECOND" + }, + "children": [] + }, + { + "name": "c10", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c11", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c12", + "nullable": 
false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c13", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "MICROSECOND" + }, + "children": [] + }, + { + "name": "c14", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c15", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c16", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "MILLISECOND", + "timezone": "UTC" + }, + "children": [] + }, + { + "name": "c17", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "MICROSECOND", + "timezone": "Africa/Johannesburg" + }, + "children": [] + }, + { + "name": "c18", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c19", + "nullable": false, + "type": { + "name": "interval", + "unit": "DAY_TIME" + }, + "children": [] + }, + { + "name": "c20", + "nullable": false, + "type": { + "name": "interval", + "unit": "YEAR_MONTH" + }, + "children": [] + }, + { + "name": "c21", + "nullable": false, + "type": { + "name": "interval", + "unit": "MONTH_DAY_NANO" + }, + "children": [] + }, + { + "name": "c22", + "nullable": false, + "type": { + "name": "list" + }, + "children": [ + { + "name": "item", + "nullable": true, + "type": { + "name": "bool" + }, + "children": [] + } + ] + }, + { + "name": "c23", + "nullable": false, + "type": { + "name": "fixedsizelist", + "listSize": 5 + }, + "children": [ + { + "name": "bools", + "nullable": false, + "type": { + "name": "bool" + }, + "children": [] + } + ] + }, + { + "name": "c24", + "nullable": true, + "type": { + "name": "list" + }, + "children": [ + { + "name": "inner_list", + "nullable": false, + "type": { + "name": "list" + }, + "children": [ + { + "name": "struct", + "nullable": true, + "type": 
{ + "name": "struct" + }, + "children": [] + } + ] + } + ] + }, + { + "name": "c25", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "a", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "b", + "nullable": false, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + }, + { + "name": "c26", + "nullable": true, + "type": { + "name": "interval", + "unit": "YEAR_MONTH" + }, + "children": [] + }, + { + "name": "c27", + "nullable": true, + "type": { + "name": "interval", + "unit": "DAY_TIME" + }, + "children": [] + }, + { + "name": "c28", + "nullable": true, + "type": { + "name": "interval", + "unit": "MONTH_DAY_NANO" + }, + "children": [] + }, + { + "name": "c29", + "nullable": false, + "type": { + "name": "duration", + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c30", + "nullable": false, + "type": { + "name": "duration", + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c31", + "nullable": false, + "type": { + "name": "duration", + "unit": "MICROSECOND" + }, + "children": [] + }, + { + "name": "c32", + "nullable": false, + "type": { + "name": "duration", + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c33", + "nullable": true, + "children": [], + "type": { + "name": "utf8" + }, + "dictionary": { + "id": 123, + "indexType": { + "name": "int", + "bitWidth": 32, + "isSigned": true + }, + "isOrdered": true + } + }, + { + "name": "c34", + "nullable": true, + "type": { + "name": "largebinary" + }, + "children": [] + }, + { + "name": "c35", + "nullable": true, + "type": { + "name": "largeutf8" + }, + "children": [] + }, + { + "name": "c36", + "nullable": true, + "type": { + "name": "largelist" + }, + "children": [ + { + "name": "inner_large_list", + "nullable": true, + "type": { + "name": "largelist" + }, + "children": [ + { + "name": "struct", + "nullable": false, + "type": { + "name": "struct" + }, + 
"children": [] + } + ] + } + ] + }, + { + "name": "c37", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + } + ], + "metadata" : { + "Key": "Value" + } + }"#; + let value: Value = serde_json::from_str(json).unwrap(); + assert_eq!(expected, value); + + // convert back to a schema + let value: Value = serde_json::from_str(json).unwrap(); + let schema2 = schema_from_json(&value).unwrap(); + + assert_eq!(schema, schema2); + + // Check that empty metadata produces empty value in JSON and can be parsed + let json = r#"{ + "fields": [ + { + "name": "c1", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + } + ], + "metadata": {} + }"#; + let value: Value = serde_json::from_str(json).unwrap(); + let schema = schema_from_json(&value).unwrap(); + assert!(schema.metadata.is_empty()); + + // Check that metadata field is not required in the JSON. 
+ let json = r#"{ + "fields": [ + { + "name": "c1", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + } + ] + }"#; + let value: Value = serde_json::from_str(json).unwrap(); + let schema = schema_from_json(&value).unwrap(); + assert!(schema.metadata.is_empty()); + } +} diff --git a/integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml similarity index 69% rename from integration-testing/Cargo.toml rename to arrow-integration-testing/Cargo.toml index b9f6cf81855e..7be56d919852 100644 --- a/integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -17,31 +17,36 @@ [package] name = "arrow-integration-testing" -description = "Binaries used in the Arrow integration tests" -version = "22.0.0" -homepage = "https://github.com/apache/arrow-rs" -repository = "https://github.com/apache/arrow-rs" -authors = ["Apache Arrow "] -license = "Apache-2.0" -edition = "2021" +description = "Binaries used in the Arrow integration tests (NOT PUBLISHED TO crates.io)" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +edition = { workspace = true } publish = false -rust-version = "1.62" +rust-version = { workspace = true } + +[lib] +crate-type = ["lib", "cdylib"] [features] logging = ["tracing-subscriber"] [dependencies] -arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json"] } +arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json", "ffi"] } arrow-flight = { path = "../arrow-flight", default-features = false } +arrow-buffer = { path = "../arrow-buffer", default-features = false } +arrow-integration-test = { path = "../arrow-integration-test", default-features = false } async-trait = { version = "0.1.41", default-features = false } -clap = { version = "3", default-features = false, features = ["std", 
"derive"] } +clap = { version = "4", default-features = false, features = ["std", "derive", "help", "error-context", "usage"] } futures = { version = "0.3", default-features = false } hex = { version = "0.4", default-features = false, features = ["std"] } -prost = { version = "0.11", default-features = false } +prost = { version = "0.13", default-features = false } serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.0", default-features = false } -tonic = { version = "0.8", default-features = false } +tonic = { version = "0.12", default-features = false } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true } num = { version = "0.4", default-features = false, features = ["std"] } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } diff --git a/integration-testing/README.md b/arrow-integration-testing/README.md similarity index 99% rename from integration-testing/README.md rename to arrow-integration-testing/README.md index e82591e6b139..dcf39c27fbc5 100644 --- a/integration-testing/README.md +++ b/arrow-integration-testing/README.md @@ -48,7 +48,7 @@ ln -s arrow/rust ```shell cd arrow -pip install -e dev/archery[docker] +pip install -e dev/archery[integration] ``` ### Build the C++ binaries: diff --git a/integration-testing/src/bin/arrow-file-to-stream.rs b/arrow-integration-testing/src/bin/arrow-file-to-stream.rs similarity index 97% rename from integration-testing/src/bin/arrow-file-to-stream.rs rename to arrow-integration-testing/src/bin/arrow-file-to-stream.rs index e939fe4f0bf7..3e027faef91f 100644 --- a/integration-testing/src/bin/arrow-file-to-stream.rs +++ b/arrow-integration-testing/src/bin/arrow-file-to-stream.rs @@ -30,7 +30,7 @@ struct Args { fn main() -> Result<()> { let args = Args::parse(); - let f = File::open(&args.file_name)?; + let f = 
File::open(args.file_name)?; let reader = BufReader::new(f); let mut reader = FileReader::try_new(reader, None)?; let schema = reader.schema(); diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs similarity index 66% rename from integration-testing/src/bin/arrow-json-integration-test.rs rename to arrow-integration-testing/src/bin/arrow-json-integration-test.rs index a7d7cf6ee7cb..cc3dd2110e36 100644 --- a/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::Schema; -use arrow::datatypes::{DataType, Field}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; -use arrow_integration_testing::{read_json_file, util::*}; +use arrow_integration_test::*; +use arrow_integration_testing::{canonicalize_schema, open_json_file}; use clap::Parser; use std::fs::File; -#[derive(clap::ArgEnum, Debug, Clone)] +#[derive(clap::ValueEnum, Debug, Clone)] #[clap(rename_all = "SCREAMING_SNAKE_CASE")] enum Mode { ArrowToJson, @@ -41,7 +40,13 @@ struct Args { arrow: String, #[clap(short, long, help("Path to JSON file"))] json: String, - #[clap(arg_enum, short, long, default_value_t = Mode::Validate, help="Mode of integration testing tool")] + #[clap( + value_enum, + short, + long, + default_value = "VALIDATE", + help = "Mode of integration testing tool" + )] mode: Mode, #[clap(short, long)] verbose: bool, @@ -61,15 +66,15 @@ fn main() -> Result<()> { fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> { if verbose { - eprintln!("Converting {} to {}", json_name, arrow_name); + eprintln!("Converting {json_name} to {arrow_name}"); } - let json_file = read_json_file(json_name)?; + let json_file = open_json_file(json_name)?; 
let arrow_file = File::create(arrow_name)?; let mut writer = FileWriter::try_new(arrow_file, &json_file.schema)?; - for b in json_file.batches { + for b in json_file.read_batches()? { writer.write(&b)?; } @@ -80,7 +85,7 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { - eprintln!("Converting {} to {}", arrow_name, json_name); + eprintln!("Converting {arrow_name} to {json_name}"); } let arrow_file = File::open(arrow_name)?; @@ -111,54 +116,13 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> Ok(()) } -fn canonicalize_schema(schema: &Schema) -> Schema { - let fields = schema - .fields() - .iter() - .map(|field| match field.data_type() { - DataType::Map(child_field, sorted) => match child_field.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { - let first_field = fields.get(0).unwrap(); - let key_field = Field::new( - "key", - first_field.data_type().clone(), - first_field.is_nullable(), - ); - let second_field = fields.get(1).unwrap(); - let value_field = Field::new( - "value", - second_field.data_type().clone(), - second_field.is_nullable(), - ); - - let struct_type = DataType::Struct(vec![key_field, value_field]); - let child_field = - Field::new("entries", struct_type, child_field.is_nullable()); - - Field::new( - field.name().as_str(), - DataType::Map(Box::new(child_field), *sorted), - field.is_nullable(), - ) - } - _ => panic!( - "The child field of Map type should be Struct type with 2 fields." 
- ), - }, - _ => field.clone(), - }) - .collect::>(); - - Schema::new(fields).with_metadata(schema.metadata().clone()) -} - fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { - eprintln!("Validating {} and {}", arrow_name, json_name); + eprintln!("Validating {arrow_name} and {json_name}"); } // open JSON file - let json_file = read_json_file(json_name)?; + let json_file = open_json_file(json_name)?; // open Arrow file let arrow_file = File::open(arrow_name)?; @@ -173,7 +137,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { ))); } - let json_batches = &json_file.batches; + let json_batches = json_file.read_batches()?; // compare number of batches assert!( @@ -197,8 +161,8 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { for i in 0..num_columns { assert_eq!( - arrow_batch.column(i).data(), - json_batch.column(i).data(), + arrow_batch.column(i).as_ref(), + json_batch.column(i).as_ref(), "Arrow and JSON batch columns not the same" ); } diff --git a/integration-testing/src/bin/arrow-stream-to-file.rs b/arrow-integration-testing/src/bin/arrow-stream-to-file.rs similarity index 100% rename from integration-testing/src/bin/arrow-stream-to-file.rs rename to arrow-integration-testing/src/bin/arrow-stream-to-file.rs diff --git a/integration-testing/src/bin/flight-test-integration-client.rs b/arrow-integration-testing/src/bin/flight-test-integration-client.rs similarity index 95% rename from integration-testing/src/bin/flight-test-integration-client.rs rename to arrow-integration-testing/src/bin/flight-test-integration-client.rs index fa99b424e378..b8bbb952837b 100644 --- a/integration-testing/src/bin/flight-test-integration-client.rs +++ b/arrow-integration-testing/src/bin/flight-test-integration-client.rs @@ -20,7 +20,7 @@ use clap::Parser; type Error = Box; type Result = std::result::Result; -#[derive(clap::ArgEnum, Debug, Clone)] +#[derive(clap::ValueEnum, Debug, 
Clone)] enum Scenario { Middleware, #[clap(name = "auth:basic_proto")] @@ -40,7 +40,7 @@ struct Args { help = "path to the descriptor file, only used when scenario is not provided. See https://arrow.apache.org/docs/format/Integration.html#json-test-data-format" )] path: Option, - #[clap(long, arg_enum)] + #[clap(long, value_enum)] scenario: Option, } @@ -62,8 +62,7 @@ async fn main() -> Result { } None => { let path = args.path.expect("No path is given"); - flight_client_scenarios::integration_test::run_scenario(&host, port, &path) - .await?; + flight_client_scenarios::integration_test::run_scenario(&host, port, &path).await?; } } diff --git a/integration-testing/src/bin/flight-test-integration-server.rs b/arrow-integration-testing/src/bin/flight-test-integration-server.rs similarity index 96% rename from integration-testing/src/bin/flight-test-integration-server.rs rename to arrow-integration-testing/src/bin/flight-test-integration-server.rs index 6ed22ad81d90..5310d07d4f8e 100644 --- a/integration-testing/src/bin/flight-test-integration-server.rs +++ b/arrow-integration-testing/src/bin/flight-test-integration-server.rs @@ -21,7 +21,7 @@ use clap::Parser; type Error = Box; type Result = std::result::Result; -#[derive(clap::ArgEnum, Debug, Clone)] +#[derive(clap::ValueEnum, Debug, Clone)] enum Scenario { Middleware, #[clap(name = "auth:basic_proto")] @@ -33,7 +33,7 @@ enum Scenario { struct Args { #[clap(long)] port: u16, - #[clap(long, arg_enum)] + #[clap(long, value_enum)] scenario: Option, } diff --git a/integration-testing/src/flight_client_scenarios.rs b/arrow-integration-testing/src/flight_client_scenarios.rs similarity index 100% rename from integration-testing/src/flight_client_scenarios.rs rename to arrow-integration-testing/src/flight_client_scenarios.rs diff --git a/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs similarity index 84% rename from 
integration-testing/src/flight_client_scenarios/auth_basic_proto.rs rename to arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs index ab398d3d2e7b..376e31e15553 100644 --- a/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs @@ -17,9 +17,7 @@ use crate::{AUTH_PASSWORD, AUTH_USERNAME}; -use arrow_flight::{ - flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest, -}; +use arrow_flight::{flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest}; use futures::{stream, StreamExt}; use prost::Message; use tonic::{metadata::MetadataValue, Request, Status}; @@ -30,7 +28,7 @@ type Result = std::result::Result; type Client = FlightServiceClient; pub async fn run_scenario(host: &str, port: u16) -> Result { - let url = format!("http://{}:{}", host, port); + let url = format!("http://{host}:{port}"); let mut client = FlightServiceClient::connect(url).await?; let action = arrow_flight::Action::default(); @@ -41,15 +39,13 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { Err(e) => { if e.code() != tonic::Code::Unauthenticated { return Err(Box::new(Status::internal(format!( - "Expected UNAUTHENTICATED but got {:?}", - e + "Expected UNAUTHENTICATED but got {e:?}" )))); } } Ok(other) => { return Err(Box::new(Status::internal(format!( - "Expected UNAUTHENTICATED but got {:?}", - other + "Expected UNAUTHENTICATED but got {other:?}" )))); } } @@ -74,17 +70,13 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { .expect("No response received") .expect("Invalid response received"); - let body = String::from_utf8(r.body).unwrap(); + let body = std::str::from_utf8(&r.body).unwrap(); assert_eq!(body, AUTH_USERNAME); Ok(()) } -async fn authenticate( - client: &mut Client, - username: &str, - password: &str, -) -> Result { +async fn authenticate(client: &mut Client, username: &str, password: &str) -> Result { let 
auth = BasicAuth { username: username.into(), password: password.into(), @@ -94,7 +86,7 @@ async fn authenticate( let req = stream::once(async { HandshakeRequest { - payload, + payload: payload.into(), ..HandshakeRequest::default() } }); @@ -105,5 +97,5 @@ async fn authenticate( let r = rx.next().await.expect("must respond from handshake")?; assert!(rx.next().await.is_none(), "must not respond a second time"); - Ok(String::from_utf8(r.payload).unwrap()) + Ok(std::str::from_utf8(&r.payload).unwrap().into()) } diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs similarity index 80% rename from integration-testing/src/flight_client_scenarios/integration_test.rs rename to arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index c01baa09a1f7..1a6c4e28a76b 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{read_json_file, ArrowFile}; +use crate::open_json_file; use std::collections::HashMap; use arrow::{ @@ -27,8 +27,7 @@ use arrow::{ }; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, - SchemaAsIpc, Ticket, + utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, SchemaAsIpc, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; @@ -42,27 +41,20 @@ type Result = std::result::Result; type Client = FlightServiceClient; pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { - let url = format!("http://{}:{}", host, port); + let url = format!("http://{host}:{port}"); let client = FlightServiceClient::connect(url).await?; - let ArrowFile { - schema, batches, .. - } = read_json_file(path)?; + let json_file = open_json_file(path)?; - let schema = Arc::new(schema); + let batches = json_file.read_batches()?; + let schema = Arc::new(json_file.schema); let mut descriptor = FlightDescriptor::default(); descriptor.set_type(DescriptorType::Path); descriptor.path = vec![path.to_string()]; - upload_data( - client.clone(), - schema.clone(), - descriptor.clone(), - batches.clone(), - ) - .await?; + upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?; verify_data(client, descriptor, &batches).await?; Ok(()) @@ -130,15 +122,23 @@ async fn send_batch( batch: &RecordBatch, options: &writer::IpcWriteOptions, ) -> Result { - let (dictionary_flight_data, mut batch_flight_data) = - arrow_flight::utils::flight_data_from_arrow_batch(batch, options); + let data_gen = writer::IpcDataGenerator::default(); + let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, true); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, options) + 
.expect("DictionaryTracker configured above to not error on replacement"); + + let dictionary_flight_data: Vec = + encoded_dictionaries.into_iter().map(Into::into).collect(); + let mut batch_flight_data: FlightData = encoded_batch.into(); upload_tx .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) .await?; // Only the record batch's FlightData gets app_metadata - batch_flight_data.app_metadata = metadata.to_vec(); + batch_flight_data.app_metadata = metadata.to_vec().into(); upload_tx.send(batch_flight_data).await?; Ok(()) } @@ -195,19 +195,16 @@ async fn consume_flight_location( let mut dictionaries_by_id = HashMap::new(); for (counter, expected_batch) in expected_data.iter().enumerate() { - let data = receive_batch_flight_data( - &mut resp, - actual_schema.clone(), - &mut dictionaries_by_id, - ) - .await - .unwrap_or_else(|| { - panic!( - "Got fewer batches than expected, received so far: {} expected: {}", - counter, - expected_data.len(), - ) - }); + let data = + receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id) + .await + .unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + }); let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); @@ -224,10 +221,10 @@ async fn consume_flight_location( let field = schema.field(i); let field_name = field.name(); - let expected_data = expected_batch.column(i).data(); - let actual_data = actual_batch.column(i).data(); + let expected_data = expected_batch.column(i).as_ref(); + let actual_data = actual_batch.column(i).as_ref(); - assert_eq!(expected_data, actual_data, "Data for field {}", field_name); + assert_eq!(expected_data, actual_data, "Data for field {field_name}"); } } @@ -242,8 +239,8 @@ async fn consume_flight_location( async fn receive_schema_flight_data(resp: &mut Streaming) -> Option { let data = resp.next().await?.ok()?; - let message = 
arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing message"); + let message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message"); // message header is a Schema, so read it let ipc_schema: ipc::Schema = message @@ -260,12 +257,12 @@ async fn receive_batch_flight_data( dictionaries_by_id: &mut HashMap, ) -> Option { let mut data = resp.next().await?.ok()?; - let mut message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing first message"); + let mut message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message"); while message.header_type() == ipc::MessageHeader::DictionaryBatch { reader::read_dictionary( - &Buffer::from(&data.data_body), + &Buffer::from(data.data_body.as_ref()), message .header_as_dictionary_batch() .expect("Error parsing dictionary"), @@ -276,8 +273,8 @@ async fn receive_batch_flight_data( .expect("Error reading dictionary"); data = resp.next().await?.ok()?; - message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing message"); + message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message"); } Some(data) diff --git a/integration-testing/src/flight_client_scenarios/middleware.rs b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs similarity index 90% rename from integration-testing/src/flight_client_scenarios/middleware.rs rename to arrow-integration-testing/src/flight_client_scenarios/middleware.rs index db8c42cc081c..3b71edf446a3 100644 --- a/integration-testing/src/flight_client_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs @@ -16,22 +16,22 @@ // under the License. 
use arrow_flight::{ - flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - FlightDescriptor, + flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, FlightDescriptor, }; +use prost::bytes::Bytes; use tonic::{Request, Status}; type Error = Box; type Result = std::result::Result; pub async fn run_scenario(host: &str, port: u16) -> Result { - let url = format!("http://{}:{}", host, port); + let url = format!("http://{host}:{port}"); let conn = tonic::transport::Endpoint::new(url)?.connect().await?; let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor); let mut descriptor = FlightDescriptor::default(); descriptor.set_type(DescriptorType::Cmd); - descriptor.cmd = b"".to_vec(); + descriptor.cmd = Bytes::from_static(b""); // This call is expected to fail. match client @@ -47,8 +47,7 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { if value != "expected value" { let msg = format!( "On failing call: Expected to receive header 'x-middleware: expected value', \ - but instead got: '{}'", - value + but instead got: '{value}'" ); return Err(Box::new(Status::internal(msg))); } @@ -56,7 +55,7 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { } // This call should succeed - descriptor.cmd = b"success".to_vec(); + descriptor.cmd = Bytes::from_static(b"success"); let resp = client.get_flight_info(Request::new(descriptor)).await?; let headers = resp.metadata(); @@ -66,8 +65,7 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { if value != "expected value" { let msg = format!( "On success call: Expected to receive header 'x-middleware: expected value', \ - but instead got: '{}'", - value + but instead got: '{value}'" ); return Err(Box::new(Status::internal(msg))); } diff --git a/integration-testing/src/flight_server_scenarios.rs b/arrow-integration-testing/src/flight_server_scenarios.rs similarity index 89% rename from 
integration-testing/src/flight_server_scenarios.rs rename to arrow-integration-testing/src/flight_server_scenarios.rs index e56252f1dfbf..48d4e6045684 100644 --- a/integration-testing/src/flight_server_scenarios.rs +++ b/arrow-integration-testing/src/flight_server_scenarios.rs @@ -28,7 +28,7 @@ type Error = Box; type Result = std::result::Result; pub async fn listen_on(port: u16) -> Result { - let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?; + let addr: SocketAddr = format!("0.0.0.0:{port}").parse()?; let listener = TcpListener::bind(addr).await?; let addr = listener.local_addr()?; @@ -39,10 +39,12 @@ pub async fn listen_on(port: u16) -> Result { pub fn endpoint(ticket: &str, location_uri: impl Into) -> FlightEndpoint { FlightEndpoint { ticket: Some(Ticket { - ticket: ticket.as_bytes().to_vec(), + ticket: ticket.as_bytes().to_vec().into(), }), location: vec![Location { uri: location_uri.into(), }], + expiration_time: None, + app_metadata: vec![].into(), } } diff --git a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs similarity index 88% rename from integration-testing/src/flight_server_scenarios/auth_basic_proto.rs rename to arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs index 68a4a0d3b4ad..20d868953664 100644 --- a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -19,15 +19,13 @@ use std::pin::Pin; use std::sync::Arc; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, - FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + ActionType, BasicAuth, Criteria, Empty, FlightData, 
FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; use tokio::sync::Mutex; -use tonic::{ - metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming, -}; +use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming}; type TonicStream = Pin + Send + Sync + 'static>>; type Error = Box; @@ -63,10 +61,7 @@ pub struct AuthBasicProtoScenarioImpl { } impl AuthBasicProtoScenarioImpl { - async fn check_auth( - &self, - metadata: &MetadataMap, - ) -> Result { + async fn check_auth(&self, metadata: &MetadataMap) -> Result { let token = metadata .get_bin("auth-token-bin") .and_then(|v| v.to_bytes().ok()) @@ -74,10 +69,7 @@ impl AuthBasicProtoScenarioImpl { self.is_valid(token).await } - async fn is_valid( - &self, - token: Option, - ) -> Result { + async fn is_valid(&self, token: Option) -> Result { match token { Some(t) if t == *self.username => Ok(GrpcServerCallContext { peer_identity: self.username.to_string(), @@ -142,14 +134,12 @@ impl FlightService for AuthBasicProtoScenarioImpl { let req = req.expect("Error reading handshake request"); let HandshakeRequest { payload, .. 
} = req; - let auth = BasicAuth::decode(&*payload) - .expect("Error parsing handshake request"); + let auth = + BasicAuth::decode(&*payload).expect("Error parsing handshake request"); - let resp = if *auth.username == *username - && *auth.password == *password - { + let resp = if *auth.username == *username && *auth.password == *password { Ok(HandshakeResponse { - payload: username.as_bytes().to_vec(), + payload: username.as_bytes().to_vec().into(), ..HandshakeResponse::default() }) } else { @@ -188,6 +178,14 @@ impl FlightService for AuthBasicProtoScenarioImpl { Err(Status::unimplemented("Not yet implemented")) } + async fn poll_flight_info( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_put( &self, request: Request>, @@ -203,7 +201,7 @@ impl FlightService for AuthBasicProtoScenarioImpl { ) -> Result, Status> { let flight_context = self.check_auth(request.metadata()).await?; // Respond with the authenticated username. - let buf = flight_context.peer_identity().as_bytes().to_vec(); + let buf = flight_context.peer_identity().as_bytes().to_vec().into(); let result = arrow_flight::Result { body: buf }; let output = futures::stream::once(async { Ok(result) }); Ok(Response::new(Box::pin(output) as Self::DoActionStream)) diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs similarity index 81% rename from integration-testing/src/flight_server_scenarios/integration_test.rs rename to arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index dee2fda3be3d..76eb9d880199 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -16,7 +16,6 @@ // under the License. 
use std::collections::HashMap; -use std::convert::TryFrom; use std::pin::Pin; use std::sync::Arc; @@ -25,17 +24,16 @@ use arrow::{ buffer::Buffer, datatypes::Schema, datatypes::SchemaRef, - ipc::{self, reader}, + ipc::{self, reader, writer}, record_batch::RecordBatch, }; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, SchemaResult, Ticket, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, + PollInfo, PutResult, SchemaAsIpc, SchemaResult, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; -use std::convert::TryInto; use tokio::sync::Mutex; use tonic::{transport::Server, Request, Response, Status, Streaming}; @@ -48,7 +46,7 @@ pub async fn scenario_setup(port: u16) -> Result { let addr = super::listen_on(port).await?; let service = FlightServiceImpl { - server_location: format!("grpc+tcp://{}", addr), + server_location: format!("grpc+tcp://{addr}"), ..Default::default() }; let svc = FlightServiceServer::new(service); @@ -103,33 +101,39 @@ impl FlightService for FlightServiceImpl { let ticket = request.into_inner(); let key = String::from_utf8(ticket.ticket.to_vec()) - .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {:?}", e)))?; + .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {e:?}")))?; let uploaded_chunks = self.uploaded_chunks.lock().await; - let flight = uploaded_chunks.get(&key).ok_or_else(|| { - Status::not_found(format!("Could not find flight. {}", key)) - })?; + let flight = uploaded_chunks + .get(&key) + .ok_or_else(|| Status::not_found(format!("Could not find flight. 
{key}")))?; let options = arrow::ipc::writer::IpcWriteOptions::default(); - let schema = - std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into())); + let schema = std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into())); let batches = flight .chunks .iter() .enumerate() .flat_map(|(counter, batch)| { - let (dictionary_flight_data, mut batch_flight_data) = - arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + let data_gen = writer::IpcDataGenerator::default(); + let mut dictionary_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, true); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, &options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into); + let mut batch_flight_data: FlightData = encoded_batch.into(); // Only the record batch's FlightData gets app_metadata - let metadata = counter.to_string().into_bytes(); + let metadata = counter.to_string().into(); batch_flight_data.app_metadata = metadata; dictionary_flight_data - .into_iter() .chain(std::iter::once(batch_flight_data)) .map(Ok) }); @@ -173,8 +177,7 @@ impl FlightService for FlightServiceImpl { let endpoint = self.endpoint_from_path(&path[0]); - let total_records: usize = - flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); let options = arrow::ipc::writer::IpcWriteOptions::default(); let message = SchemaAsIpc::new(&flight.schema, &options) @@ -191,14 +194,23 @@ impl FlightService for FlightServiceImpl { endpoint: vec![endpoint], total_records: total_records as i64, total_bytes: -1, + ordered: false, + app_metadata: vec![].into(), }; Ok(Response::new(info)) } - other => Err(Status::unimplemented(format!("Request type: {}", other))), + other => Err(Status::unimplemented(format!("Request 
type: {other}"))), } } + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_put( &self, request: Request>, @@ -214,15 +226,14 @@ impl FlightService for FlightServiceImpl { .clone() .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?; - if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() - { + if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() { return Err(Status::invalid_argument("Must specify a path")); } let key = descriptor.path[0].clone(); let schema = Schema::try_from(&flight_data) - .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; + .map_err(|e| Status::invalid_argument(format!("Invalid schema: {e:?}")))?; let schema_ref = Arc::new(schema.clone()); let (response_tx, response_rx) = mpsc::channel(10); @@ -275,10 +286,10 @@ async fn send_app_metadata( app_metadata: &[u8], ) -> Result<(), Status> { tx.send(Ok(PutResult { - app_metadata: app_metadata.to_vec(), + app_metadata: app_metadata.to_vec().into(), })) .await - .map_err(|e| Status::internal(format!("Could not send PutResult: {:?}", e))) + .map_err(|e| Status::internal(format!("Could not send PutResult: {e:?}"))) } async fn record_batch_from_message( @@ -287,9 +298,9 @@ async fn record_batch_from_message( schema_ref: SchemaRef, dictionaries_by_id: &HashMap, ) -> Result { - let ipc_batch = message.header_as_record_batch().ok_or_else(|| { - Status::internal("Could not parse message header as record batch") - })?; + let ipc_batch = message + .header_as_record_batch() + .ok_or_else(|| Status::internal("Could not parse message header as record batch"))?; let arrow_batch_result = reader::read_record_batch( data_body, @@ -300,9 +311,8 @@ async fn record_batch_from_message( &message.version(), ); - arrow_batch_result.map_err(|e| { - Status::internal(format!("Could not convert to RecordBatch: {:?}", e)) - }) + 
arrow_batch_result + .map_err(|e| Status::internal(format!("Could not convert to RecordBatch: {e:?}"))) } async fn dictionary_from_message( @@ -311,9 +321,9 @@ async fn dictionary_from_message( schema_ref: SchemaRef, dictionaries_by_id: &mut HashMap, ) -> Result<(), Status> { - let ipc_batch = message.header_as_dictionary_batch().ok_or_else(|| { - Status::internal("Could not parse message header as dictionary batch") - })?; + let ipc_batch = message + .header_as_dictionary_batch() + .ok_or_else(|| Status::internal("Could not parse message header as dictionary batch"))?; let dictionary_batch_result = reader::read_dictionary( data_body, @@ -322,9 +332,8 @@ async fn dictionary_from_message( dictionaries_by_id, &message.version(), ); - dictionary_batch_result.map_err(|e| { - Status::internal(format!("Could not convert to Dictionary: {:?}", e)) - }) + dictionary_batch_result + .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {e:?}"))) } async fn save_uploaded_chunks( @@ -342,7 +351,7 @@ async fn save_uploaded_chunks( while let Some(Ok(data)) = input_stream.next().await { let message = arrow::ipc::root_as_message(&data.data_header[..]) - .map_err(|e| Status::internal(format!("Could not parse message: {:?}", e)))?; + .map_err(|e| Status::internal(format!("Could not parse message: {e:?}")))?; match message.header_type() { ipc::MessageHeader::Schema => { @@ -355,7 +364,7 @@ async fn save_uploaded_chunks( let batch = record_batch_from_message( message, - &Buffer::from(data.data_body), + &Buffer::from(data.data_body.as_ref()), schema_ref.clone(), &dictionaries_by_id, ) @@ -366,7 +375,7 @@ async fn save_uploaded_chunks( ipc::MessageHeader::DictionaryBatch => { dictionary_from_message( message, - &Buffer::from(data.data_body), + &Buffer::from(data.data_body.as_ref()), schema_ref.clone(), &mut dictionaries_by_id, ) @@ -375,8 +384,7 @@ async fn save_uploaded_chunks( t => { return Err(Status::internal(format!( "Reading types other than record batches not yet 
supported, \ - unable to read {:?}", - t + unable to read {t:?}" ))); } } diff --git a/integration-testing/src/flight_server_scenarios/middleware.rs b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs similarity index 92% rename from integration-testing/src/flight_server_scenarios/middleware.rs rename to arrow-integration-testing/src/flight_server_scenarios/middleware.rs index 5876ac9bfe6d..e8d9c521bb99 100644 --- a/integration-testing/src/flight_server_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs @@ -19,9 +19,9 @@ use std::pin::Pin; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, - PutResult, SchemaResult, Ticket, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PollInfo, PutResult, + SchemaResult, Ticket, }; use futures::Stream; use tonic::{transport::Server, Request, Response, Status, Streaming}; @@ -93,7 +93,7 @@ impl FlightService for MiddlewareScenarioImpl { let descriptor = request.into_inner(); - if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd == b"success" + if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd.as_ref() == b"success" { // Return a fake location - the test doesn't read it let endpoint = super::endpoint("foo", "grpc+tcp://localhost:10010"); @@ -120,6 +120,13 @@ impl FlightService for MiddlewareScenarioImpl { Err(status) } + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_put( &self, _request: Request>, diff --git a/arrow-integration-testing/src/lib.rs b/arrow-integration-testing/src/lib.rs new file mode 100644 index 
000000000000..4ce7b06a1888 --- /dev/null +++ b/arrow-integration-testing/src/lib.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common code used in the integration test binaries + +use serde_json::Value; + +use arrow::array::{Array, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema}; +use arrow::error::{ArrowError, Result}; +use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::record_batch::RecordBatch; +use arrow::util::test_util::arrow_test_data; +use arrow_integration_test::*; +use std::collections::HashMap; +use std::ffi::{c_int, CStr, CString}; +use std::fs::File; +use std::io::BufReader; +use std::iter::zip; +use std::ptr; +use std::sync::Arc; + +/// The expected username for the basic auth integration test. +pub const AUTH_USERNAME: &str = "arrow"; +/// The expected password for the basic auth integration test. 
+pub const AUTH_PASSWORD: &str = "flight"; + +pub mod flight_client_scenarios; +pub mod flight_server_scenarios; + +pub struct ArrowFile { + pub schema: Schema, + // we can evolve this into a concrete Arrow type + // this is temporarily not being read from + dictionaries: HashMap, + arrow_json: Value, +} + +impl ArrowFile { + pub fn read_batch(&self, batch_num: usize) -> Result { + let b = self.arrow_json["batches"].get(batch_num).unwrap(); + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + record_batch_from_json(&self.schema, json_batch, Some(&self.dictionaries)) + } + + pub fn read_batches(&self) -> Result> { + self.arrow_json["batches"] + .as_array() + .unwrap() + .iter() + .map(|b| { + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + record_batch_from_json(&self.schema, json_batch, Some(&self.dictionaries)) + }) + .collect() + } +} + +// Canonicalize the names of map fields in a schema +pub fn canonicalize_schema(schema: &Schema) -> Schema { + let fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Map(child_field, sorted) => match child_field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + let first_field = &fields[0]; + let key_field = + Arc::new(Field::new("key", first_field.data_type().clone(), false)); + let second_field = &fields[1]; + let value_field = Arc::new(Field::new( + "value", + second_field.data_type().clone(), + second_field.is_nullable(), + )); + + let fields = Fields::from([key_field, value_field]); + let struct_type = DataType::Struct(fields); + let child_field = Field::new("entries", struct_type, false); + + Arc::new(Field::new( + field.name().as_str(), + DataType::Map(Arc::new(child_field), *sorted), + field.is_nullable(), + )) + } + _ => panic!("The child field of Map type should be Struct type with 2 fields."), + }, + _ => field.clone(), + }) + .collect::(); + + 
Schema::new(fields).with_metadata(schema.metadata().clone()) +} + +pub fn open_json_file(json_name: &str) -> Result { + let json_file = File::open(json_name)?; + let reader = BufReader::new(json_file); + let arrow_json: Value = serde_json::from_reader(reader).unwrap(); + let schema = schema_from_json(&arrow_json["schema"])?; + // read dictionaries + let mut dictionaries = HashMap::new(); + if let Some(dicts) = arrow_json.get("dictionaries") { + for d in dicts + .as_array() + .expect("Unable to get dictionaries as array") + { + let json_dict: ArrowJsonDictionaryBatch = + serde_json::from_value(d.clone()).expect("Unable to get dictionary from JSON"); + // TODO: convert to a concrete Arrow type + dictionaries.insert(json_dict.id, json_dict); + } + } + Ok(ArrowFile { + schema, + dictionaries, + arrow_json, + }) +} + +/// Read gzipped JSON test file +/// +/// For example given the input: +/// version = `0.17.1` +/// path = `generated_union` +/// +/// Returns the contents of +/// `arrow-ipc-stream/integration/0.17.1/generated_union.json.gz` +pub fn read_gzip_json(version: &str, path: &str) -> ArrowJson { + use flate2::read::GzDecoder; + use std::io::Read; + + let testdata = arrow_test_data(); + let file = File::open(format!( + "{testdata}/arrow-ipc-stream/integration/{version}/{path}.json.gz" + )) + .unwrap(); + let mut gz = GzDecoder::new(&file); + let mut s = String::new(); + gz.read_to_string(&mut s).unwrap(); + // convert to Arrow JSON + let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); + arrow_json +} + +// +// C Data Integration entrypoints +// + +fn cdata_integration_export_schema_from_json( + c_json_name: *const i8, + out: *mut FFI_ArrowSchema, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let f = open_json_file(json_name.to_str()?)?; + let c_schema = FFI_ArrowSchema::try_from(&f.schema)?; + // Move exported schema into output struct + unsafe { ptr::write(out, c_schema) }; + Ok(()) +} + +fn 
cdata_integration_export_batch_from_json( + c_json_name: *const i8, + batch_num: c_int, + out: *mut FFI_ArrowArray, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let b = open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?; + let a = StructArray::from(b).into_data(); + let c_array = FFI_ArrowArray::new(&a); + // Move exported array into output struct + unsafe { ptr::write(out, c_array) }; + Ok(()) +} + +fn cdata_integration_import_schema_and_compare_to_json( + c_json_name: *const i8, + c_schema: *mut FFI_ArrowSchema, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let json_schema = open_json_file(json_name.to_str()?)?.schema; + + // The source ArrowSchema will be released when this is dropped + let imported_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema) }; + let imported_schema = Schema::try_from(&imported_schema)?; + + // compare schemas + if canonicalize_schema(&json_schema) != canonicalize_schema(&imported_schema) { + return Err(ArrowError::ComputeError(format!( + "Schemas do not match.\n- JSON: {:?}\n- Imported: {:?}", + json_schema, imported_schema + ))); + } + Ok(()) +} + +fn compare_batches(a: &RecordBatch, b: &RecordBatch) -> Result<()> { + if a.num_columns() != b.num_columns() { + return Err(ArrowError::InvalidArgumentError( + "batches do not have the same number of columns".to_string(), + )); + } + for (a_column, b_column) in zip(a.columns(), b.columns()) { + if a_column != b_column { + return Err(ArrowError::InvalidArgumentError( + "batch columns are not the same".to_string(), + )); + } + } + Ok(()) +} + +fn cdata_integration_import_batch_and_compare_to_json( + c_json_name: *const i8, + batch_num: c_int, + c_array: *mut FFI_ArrowArray, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let json_batch = + open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?; + let schema = json_batch.schema(); + + let 
data_type_for_import = DataType::Struct(schema.fields.clone()); + let imported_array = unsafe { FFI_ArrowArray::from_raw(c_array) }; + let imported_array = unsafe { from_ffi_and_data_type(imported_array, data_type_for_import) }?; + imported_array.validate_full()?; + let imported_batch = RecordBatch::from(StructArray::from(imported_array)); + + compare_batches(&json_batch, &imported_batch) +} + +// If Result is an error, then export a const char* to its string display, otherwise NULL +fn result_to_c_error(result: &std::result::Result) -> *mut i8 { + match result { + Ok(_) => ptr::null_mut(), + Err(e) => CString::new(format!("{}", e)).unwrap().into_raw(), + } +} + +/// Release a const char* exported by result_to_c_error() +/// +/// # Safety +/// +/// The pointer is assumed to have been obtained using CString::into_raw. +#[no_mangle] +pub unsafe extern "C" fn arrow_rs_free_error(c_error: *mut i8) { + if !c_error.is_null() { + drop(unsafe { CString::from_raw(c_error) }); + } +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_export_schema_from_json( + c_json_name: *const i8, + out: *mut FFI_ArrowSchema, +) -> *mut i8 { + let r = cdata_integration_export_schema_from_json(c_json_name, out); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_import_schema_and_compare_to_json( + c_json_name: *const i8, + c_schema: *mut FFI_ArrowSchema, +) -> *mut i8 { + let r = cdata_integration_import_schema_and_compare_to_json(c_json_name, c_schema); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_export_batch_from_json( + c_json_name: *const i8, + batch_num: c_int, + out: *mut FFI_ArrowArray, +) -> *mut i8 { + let r = cdata_integration_export_batch_from_json(c_json_name, batch_num, out); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_import_batch_and_compare_to_json( + c_json_name: *const i8, + batch_num: c_int, + c_array: *mut FFI_ArrowArray, +) -> *mut 
i8 { + let r = cdata_integration_import_batch_and_compare_to_json(c_json_name, batch_num, c_array); + result_to_c_error(&r) +} diff --git a/arrow-integration-testing/tests/ipc_reader.rs b/arrow-integration-testing/tests/ipc_reader.rs new file mode 100644 index 000000000000..a683075990c7 --- /dev/null +++ b/arrow-integration-testing/tests/ipc_reader.rs @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for reading the content of [`FileReader`] and [`StreamReader`] +//! 
in `testing/arrow-ipc-stream/integration/...` + +use arrow::error::ArrowError; +use arrow::ipc::reader::{FileReader, StreamDecoder, StreamReader}; +use arrow::util::test_util::arrow_test_data; +use arrow_buffer::Buffer; +use arrow_integration_testing::read_gzip_json; +use std::fs::File; +use std::io::Read; + +#[test] +fn read_0_1_4() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + let paths = [ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn read_0_1_7() { + let testdata = arrow_test_data(); + let version = "0.17.1"; + let paths = ["generated_union"]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn read_1_0_0_bigendian() { + let testdata = arrow_test_data(); + let paths = [ + "generated_decimal", + "generated_dictionary", + "generated_interval", + "generated_datetime", + "generated_map", + "generated_nested", + "generated_null_trivial", + "generated_null", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{testdata}/arrow-ipc-stream/integration/1.0.0-bigendian/{path}.arrow_file" + )) + .unwrap(); + + let reader = FileReader::try_new(file, None); + + assert!(reader.is_err()); + let err = reader.err().unwrap(); + assert!(matches!(err, ArrowError::IpcError(_))); + assert_eq!(err.to_string(), "Ipc error: the endianness of the source system does not match the endianness of the target system."); + }); +} + +#[test] +fn read_1_0_0_littleendian() { + let testdata = arrow_test_data(); + let 
version = "1.0.0-littleendian"; + let paths = vec![ + "generated_datetime", + "generated_custom_metadata", + "generated_decimal", + "generated_decimal256", + "generated_dictionary", + "generated_dictionary_unsigned", + "generated_duplicate_fieldnames", + "generated_extension", + "generated_interval", + "generated_map", + // https://github.com/apache/arrow-rs/issues/3460 + //"generated_map_non_canonical", + "generated_nested", + "generated_nested_dictionary", + "generated_nested_large_offsets", + "generated_null", + "generated_null_trivial", + "generated_primitive", + "generated_primitive_large_offsets", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_recursive_nested", + "generated_union", + ]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn read_2_0_0_compression() { + let testdata = arrow_test_data(); + let version = "2.0.0-compression"; + + // the test is repetitive, thus we can read all supported files at once + let paths = ["generated_lz4", "generated_zstd"]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +/// Verifies the arrow file format integration test +/// +/// Input file: +/// `arrow-ipc-stream/integration//.arrow_file +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn verify_arrow_file(testdata: &str, version: &str, path: &str) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); + println!("Verifying {filename}"); + + // Compare contents to the expected output format in JSON + { + println!(" verifying content"); + let file = File::open(&filename).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut 
reader).unwrap()); + } + + // Verify that projection works by selecting the first column + { + println!(" verifying projection"); + let file = File::open(&filename).unwrap(); + let reader = FileReader::try_new(file, Some(vec![0])).unwrap(); + let datatype_0 = reader.schema().fields()[0].data_type().clone(); + reader.for_each(|batch| { + let batch = batch.unwrap(); + assert_eq!(batch.columns().len(), 1); + assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone()); + }); + } +} + +/// Verifies the arrow stream integration test +/// +/// Input file: +/// `arrow-ipc-stream/integration//.stream +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn verify_arrow_stream(testdata: &str, version: &str, path: &str) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); + println!("Verifying {filename}"); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + + // Compare contents to the expected output format in JSON + { + println!(" verifying content"); + let file = File::open(&filename).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); + + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + // the next batch must be empty + assert!(reader.next().is_none()); + // the stream must indicate that it's finished + assert!(reader.is_finished()); + } + + // Test stream decoder + let expected = arrow_json.get_record_batches().unwrap(); + for chunk_sizes in [1, 2, 8, 123] { + let mut decoder = StreamDecoder::new(); + let stream = chunked_file(&filename, chunk_sizes); + let mut actual = Vec::with_capacity(expected.len()); + for mut x in stream { + while !x.is_empty() { + if let Some(x) = decoder.decode(&mut x).unwrap() { + actual.push(x); + } + } + } + decoder.finish().unwrap(); + assert_eq!(expected, actual); + } +} + +fn chunked_file(filename: &str, chunk_size: u64) -> impl Iterator { + let mut file = File::open(filename).unwrap(); + 
std::iter::from_fn(move || { + let mut buf = vec![]; + let read = (&mut file).take(chunk_size).read_to_end(&mut buf).unwrap(); + (read != 0).then(|| Buffer::from_vec(buf)) + }) +} diff --git a/arrow-integration-testing/tests/ipc_writer.rs b/arrow-integration-testing/tests/ipc_writer.rs new file mode 100644 index 000000000000..d780eb2ee0b5 --- /dev/null +++ b/arrow-integration-testing/tests/ipc_writer.rs @@ -0,0 +1,256 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::ipc; +use arrow::ipc::reader::{FileReader, StreamReader}; +use arrow::ipc::writer::{FileWriter, IpcWriteOptions, StreamWriter}; +use arrow::util::test_util::arrow_test_data; +use arrow_integration_testing::read_gzip_json; +use std::fs::File; +use std::io::Seek; + +#[test] +fn write_0_1_4() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + let paths = [ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + roundtrip_arrow_file(&testdata, version, path); + roundtrip_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn write_0_1_7() { + let testdata = arrow_test_data(); + let version = "0.17.1"; + let paths = ["generated_union"]; + paths.iter().for_each(|path| { + roundtrip_arrow_file(&testdata, version, path); + roundtrip_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn write_1_0_0_littleendian() { + let testdata = arrow_test_data(); + let version = "1.0.0-littleendian"; + let paths = [ + "generated_datetime", + "generated_custom_metadata", + "generated_decimal", + "generated_decimal256", + "generated_dictionary", + "generated_dictionary_unsigned", + "generated_duplicate_fieldnames", + "generated_extension", + "generated_interval", + "generated_map", + // https://github.com/apache/arrow-rs/issues/3460 + // "generated_map_non_canonical", + "generated_nested", + "generated_nested_dictionary", + "generated_nested_large_offsets", + "generated_null", + "generated_null_trivial", + "generated_primitive", + "generated_primitive_large_offsets", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_recursive_nested", + "generated_union", + ]; + paths.iter().for_each(|path| { + roundtrip_arrow_file(&testdata, version, path); + roundtrip_arrow_stream(&testdata, version, path); 
+ }); +} + +#[test] +fn write_2_0_0_compression() { + let testdata = arrow_test_data(); + let version = "2.0.0-compression"; + let paths = ["generated_lz4", "generated_zstd"]; + + // writer options for each compression type + let all_options = [ + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) + .unwrap(), + // write IPC version 5 with zstd + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::ZSTD)) + .unwrap(), + ]; + + paths.iter().for_each(|path| { + for options in &all_options { + println!("Using options {options:?}"); + roundtrip_arrow_file_with_options(&testdata, version, path, options.clone()); + roundtrip_arrow_stream_with_options(&testdata, version, path, options.clone()); + } + }); +} + +/// Verifies the arrow file writer by reading the contents of an +/// arrow_file, writing it to a file, and then ensuring the contents +/// match the expected json contents. It also verifies that +/// RecordBatches read from the new file matches the original. 
+/// +/// Input file: +/// `arrow-ipc-stream/integration//.arrow_file +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn roundtrip_arrow_file(testdata: &str, version: &str, path: &str) { + roundtrip_arrow_file_with_options(testdata, version, path, IpcWriteOptions::default()) +} + +fn roundtrip_arrow_file_with_options( + testdata: &str, + version: &str, + path: &str, + options: IpcWriteOptions, +) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); + println!("Verifying {filename}"); + + let mut tempfile = tempfile::tempfile().unwrap(); + + { + println!(" writing to tempfile {tempfile:?}"); + let file = File::open(&filename).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read and rewrite the file to a temp location + { + let mut writer = + FileWriter::try_new_with_options(&mut tempfile, &reader.schema(), options).unwrap(); + while let Some(Ok(batch)) = reader.next() { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + } + + { + println!(" checking rewrite to with json"); + tempfile.rewind().unwrap(); + let mut reader = FileReader::try_new(&tempfile, None).unwrap(); + + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + } + + { + println!(" checking rewrite with original"); + let file = File::open(&filename).unwrap(); + let reader = FileReader::try_new(file, None).unwrap(); + + tempfile.rewind().unwrap(); + let rewrite_reader = FileReader::try_new(&tempfile, None).unwrap(); + + // Compare to original reader + reader + .into_iter() + .zip(rewrite_reader) + .for_each(|(batch1, batch2)| { + assert_eq!(batch1.unwrap(), batch2.unwrap()); + }); + } +} + +/// Verifies the arrow file writer by reading the contents of an +/// arrow_file, writing it to a file, and then ensuring the contents +/// match the expected json contents. 
It also verifies that +/// RecordBatches read from the new file matches the original. +/// +/// Input file: +/// `arrow-ipc-stream/integration//.stream +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn roundtrip_arrow_stream(testdata: &str, version: &str, path: &str) { + roundtrip_arrow_stream_with_options(testdata, version, path, IpcWriteOptions::default()) +} + +fn roundtrip_arrow_stream_with_options( + testdata: &str, + version: &str, + path: &str, + options: IpcWriteOptions, +) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); + println!("Verifying {filename}"); + + let mut tempfile = tempfile::tempfile().unwrap(); + + { + println!(" writing to tempfile {tempfile:?}"); + let file = File::open(&filename).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read and rewrite the file to a temp location + { + let mut writer = + StreamWriter::try_new_with_options(&mut tempfile, &reader.schema(), options) + .unwrap(); + while let Some(Ok(batch)) = reader.next() { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + } + + { + println!(" checking rewrite to with json"); + tempfile.rewind().unwrap(); + let mut reader = StreamReader::try_new(&tempfile, None).unwrap(); + + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + } + + { + println!(" checking rewrite with original"); + let file = File::open(&filename).unwrap(); + let reader = StreamReader::try_new(file, None).unwrap(); + + tempfile.rewind().unwrap(); + let rewrite_reader = StreamReader::try_new(&tempfile, None).unwrap(); + + // Compare to original reader + reader + .into_iter() + .zip(rewrite_reader) + .for_each(|(batch1, batch2)| { + assert_eq!(batch1.unwrap(), batch2.unwrap()); + }); + } +} diff --git a/arrow-ipc/CONTRIBUTING.md b/arrow-ipc/CONTRIBUTING.md new file mode 100644 index 000000000000..5e14760f19df --- /dev/null +++ 
b/arrow-ipc/CONTRIBUTING.md @@ -0,0 +1,37 @@ + + +## Developer's guide + +# IPC + +The expected flatc version is 1.12.0+, built from [flatbuffers](https://github.com/google/flatbuffers) +master at fixed commit ID, by regen.sh. + +The IPC flatbuffer code was generated by running this command from the root of the project: + +```bash +./regen.sh +``` + +The above script will run the `flatc` compiler and perform some adjustments to the source code: + +- Replace `type__` with `type_` +- Remove `org::apache::arrow::flatbuffers` namespace +- Add includes to each generated file diff --git a/arrow-ipc/Cargo.toml b/arrow-ipc/Cargo.toml new file mode 100644 index 000000000000..94b89a55f2fb --- /dev/null +++ b/arrow-ipc/Cargo.toml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "arrow-ipc" +version = { workspace = true } +description = "Support for the Arrow IPC format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_ipc" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +flatbuffers = { version = "24.3.25", default-features = false } +lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } +zstd = { version = "0.13.0", default-features = false, optional = true } + +[features] +default = [] +lz4 = ["lz4_flex"] + +[dev-dependencies] +tempfile = "3.3" diff --git a/arrow/regen.sh b/arrow-ipc/regen.sh similarity index 83% rename from arrow/regen.sh rename to arrow-ipc/regen.sh index 9d384b6b63b6..8d8862ccc7f4 100755 --- a/arrow/regen.sh +++ b/arrow-ipc/regen.sh @@ -18,15 +18,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# Change to the toplevel Rust directory -pushd $DIR/../../ +# Change to the toplevel `arrow-rs` directory +pushd $DIR/../ echo "Build flatc from source ..." FB_URL="https://github.com/google/flatbuffers" -# https://github.com/google/flatbuffers/pull/6393 -FB_COMMIT="408cf5802415e1dea65fef7489a6c2f3740fb381" -FB_DIR="rust/arrow/.flatbuffers" +FB_DIR="arrow/.flatbuffers" FLATC="$FB_DIR/bazel-bin/flatc" if [ -z $(which bazel) ]; then @@ -44,28 +42,21 @@ else git -C $FB_DIR pull fi -echo "hard reset to $FB_COMMIT" -git -C $FB_DIR reset --hard $FB_COMMIT - pushd $FB_DIR echo "run: bazel build :flatc ..." 
bazel build :flatc popd -FB_PATCH="rust/arrow/format-0ed34c83.patch" -echo "Patch flatbuffer files with ${FB_PATCH} for cargo doc" -echo "NOTE: the patch MAY need update in case of changes in format/*.fbs" -git apply --check ${FB_PATCH} && git apply ${FB_PATCH} # Execute the code generation: -$FLATC --filename-suffix "" --rust -o rust/arrow/src/ipc/gen/ format/*.fbs +$FLATC --filename-suffix "" --rust -o arrow-ipc/src/gen/ format/*.fbs # Reset changes to format/ git checkout -- format # Now the files are wrongly named so we have to change that. popd -pushd $DIR/src/ipc/gen +pushd $DIR/src/gen PREFIX=$(cat <<'HEREDOC' // Licensed to the Apache Software Foundation (ASF) under one @@ -94,9 +85,9 @@ use flatbuffers::EndianScalar; HEREDOC ) -SCHEMA_IMPORT="\nuse crate::ipc::gen::Schema::*;" -SPARSE_TENSOR_IMPORT="\nuse crate::ipc::gen::SparseTensor::*;" -TENSOR_IMPORT="\nuse crate::ipc::gen::Tensor::*;" +SCHEMA_IMPORT="\nuse crate::gen::Schema::*;" +SPARSE_TENSOR_IMPORT="\nuse crate::gen::SparseTensor::*;" +TENSOR_IMPORT="\nuse crate::gen::Tensor::*;" # For flatbuffer(1.12.0+), remove: use crate::${name}::\*; names=("File" "Message" "Schema" "SparseTensor" "Tensor") @@ -119,8 +110,9 @@ for f in `ls *.rs`; do sed -i '' '/} \/\/ pub mod arrow/d' $f sed -i '' '/} \/\/ pub mod apache/d' $f sed -i '' '/} \/\/ pub mod org/d' $f - sed -i '' '/use std::mem;/d' $f - sed -i '' '/use std::cmp::Ordering;/d' $f + sed -i '' '/use core::mem;/d' $f + sed -i '' '/use core::cmp::Ordering;/d' $f + sed -i '' '/use self::flatbuffers::{EndianScalar, Follow};/d' $f # required by flatc 1.12.0+ sed -i '' "/\#\!\[allow(unused_imports, dead_code)\]/d" $f @@ -150,7 +142,7 @@ done # Return back to base directory popd -cargo +stable fmt -- src/ipc/gen/* +cargo +stable fmt -- src/gen/* echo "DONE!" 
echo "Please run 'cargo doc' and 'cargo test' with nightly and stable, " diff --git a/arrow/src/ipc/compression/codec.rs b/arrow-ipc/src/compression.rs similarity index 54% rename from arrow/src/ipc/compression/codec.rs rename to arrow-ipc/src/compression.rs index 58ba8cb86585..47ea7785cbec 100644 --- a/arrow/src/ipc/compression/codec.rs +++ b/arrow-ipc/src/compression.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::Buffer; -use crate::error::{ArrowError, Result}; -use crate::ipc::CompressionType; -use std::io::{Read, Write}; +use crate::CompressionType; +use arrow_buffer::Buffer; +use arrow_schema::ArrowError; const LENGTH_NO_COMPRESSED_DATA: i64 = -1; const LENGTH_OF_PREFIX_DATA: i64 = 8; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] /// Represents compressing a ipc stream using a particular compression algorithm +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CompressionCodec { Lz4Frame, Zstd, @@ -33,13 +32,12 @@ pub enum CompressionCodec { impl TryFrom for CompressionCodec { type Error = ArrowError; - fn try_from(compression_type: CompressionType) -> Result { + fn try_from(compression_type: CompressionType) -> Result { match compression_type { CompressionType::ZSTD => Ok(CompressionCodec::Zstd), CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame), other_type => Err(ArrowError::NotYetImplemented(format!( - "compression type {:?} not supported ", - other_type + "compression type {other_type:?} not supported " ))), } } @@ -60,7 +58,7 @@ impl CompressionCodec { &self, input: &[u8], output: &mut Vec, - ) -> Result { + ) -> Result { let uncompressed_data_len = input.len(); let original_output_len = output.len(); @@ -71,7 +69,7 @@ impl CompressionCodec { output.extend_from_slice(&uncompressed_data_len.to_le_bytes()); self.compress(input, output)?; - let compression_len = output.len(); + let compression_len = output.len() - original_output_len; if compression_len > 
uncompressed_data_len { // length of compressed data was larger than // uncompressed data, use the uncompressed data with @@ -92,73 +90,123 @@ impl CompressionCodec { /// [8 bytes]: uncompressed length /// [remaining bytes]: compressed data stream /// ``` - pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result { + pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result { // read the first 8 bytes to determine if the data is // compressed let decompressed_length = read_uncompressed_size(input); let buffer = if decompressed_length == 0 { - // emtpy + // empty Buffer::from([]) } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA { // no compression input.slice(LENGTH_OF_PREFIX_DATA as usize) - } else { + } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) { // decompress data using the codec - let mut uncompressed_buffer = - Vec::with_capacity(decompressed_length as usize); let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..]; - self.decompress(input_data, &mut uncompressed_buffer)?; - Buffer::from(uncompressed_buffer) + let v = self.decompress(input_data, decompressed_length as _)?; + Buffer::from_vec(v) + } else { + return Err(ArrowError::IpcError(format!( + "Invalid uncompressed length: {decompressed_length}" + ))); }; Ok(buffer) } /// Compress the data in input buffer and write to output buffer /// using the specified compression - fn compress(&self, input: &[u8], output: &mut Vec) -> Result<()> { + fn compress(&self, input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { match self { - CompressionCodec::Lz4Frame => { - let mut encoder = lz4::EncoderBuilder::new().build(output)?; - encoder.write_all(input)?; - match encoder.finish().1 { - Ok(_) => Ok(()), - Err(e) => Err(e.into()), - } - } - CompressionCodec::Zstd => { - let mut encoder = zstd::Encoder::new(output, 0)?; - encoder.write_all(input)?; - match encoder.finish() { - Ok(_) => Ok(()), - Err(e) => Err(e.into()), - } - } + 
CompressionCodec::Lz4Frame => compress_lz4(input, output), + CompressionCodec::Zstd => compress_zstd(input, output), } } /// Decompress the data in input buffer and write to output buffer /// using the specified compression - fn decompress(&self, input: &[u8], output: &mut Vec) -> Result { - let result: Result = match self { - CompressionCodec::Lz4Frame => { - let mut decoder = lz4::Decoder::new(input)?; - match decoder.read_to_end(output) { - Ok(size) => Ok(size), - Err(e) => Err(e.into()), - } - } - CompressionCodec::Zstd => { - let mut decoder = zstd::Decoder::new(input)?; - match decoder.read_to_end(output) { - Ok(size) => Ok(size), - Err(e) => Err(e.into()), - } - } + fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result, ArrowError> { + let ret = match self { + CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?, + CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?, }; - result + if ret.len() != decompressed_size { + return Err(ArrowError::IpcError(format!( + "Expected compressed length of {decompressed_size} got {}", + ret.len() + ))); + } + Ok(ret) } } +#[cfg(feature = "lz4")] +fn compress_lz4(input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { + use std::io::Write; + let mut encoder = lz4_flex::frame::FrameEncoder::new(output); + encoder.write_all(input)?; + encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(()) +} + +#[cfg(not(feature = "lz4"))] +#[allow(clippy::ptr_arg)] +fn compress_lz4(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> { + Err(ArrowError::InvalidArgumentError( + "lz4 IPC compression requires the lz4 feature".to_string(), + )) +} + +#[cfg(feature = "lz4")] +fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result, ArrowError> { + use std::io::Read; + let mut output = Vec::with_capacity(decompressed_size); + lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?; + Ok(output) +} + +#[cfg(not(feature = "lz4"))] 
+#[allow(clippy::ptr_arg)] +fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result, ArrowError> { + Err(ArrowError::InvalidArgumentError( + "lz4 IPC decompression requires the lz4 feature".to_string(), + )) +} + +#[cfg(feature = "zstd")] +fn compress_zstd(input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { + use std::io::Write; + let mut encoder = zstd::Encoder::new(output, 0)?; + encoder.write_all(input)?; + encoder.finish()?; + Ok(()) +} + +#[cfg(not(feature = "zstd"))] +#[allow(clippy::ptr_arg)] +fn compress_zstd(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> { + Err(ArrowError::InvalidArgumentError( + "zstd IPC compression requires the zstd feature".to_string(), + )) +} + +#[cfg(feature = "zstd")] +fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result, ArrowError> { + use std::io::Read; + let mut output = Vec::with_capacity(decompressed_size); + zstd::Decoder::with_buffer(input)?.read_to_end(&mut output)?; + Ok(output) +} + +#[cfg(not(feature = "zstd"))] +#[allow(clippy::ptr_arg)] +fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result, ArrowError> { + Err(ArrowError::InvalidArgumentError( + "zstd IPC decompression requires the zstd feature".to_string(), + )) +} + /// Get the uncompressed length /// Notes: /// LENGTH_NO_COMPRESSED_DATA: indicate that the data that follows is not compressed @@ -173,31 +221,29 @@ fn read_uncompressed_size(buffer: &[u8]) -> i64 { #[cfg(test)] mod tests { - use super::*; - #[test] + #[cfg(feature = "lz4")] fn test_lz4_compression() { - let input_bytes = "hello lz4".as_bytes(); - let codec: CompressionCodec = CompressionCodec::Lz4Frame; + let input_bytes = b"hello lz4"; + let codec = super::CompressionCodec::Lz4Frame; let mut output_bytes: Vec = Vec::new(); codec.compress(input_bytes, &mut output_bytes).unwrap(); - let mut result_output_bytes: Vec = Vec::new(); - codec - .decompress(output_bytes.as_slice(), &mut result_output_bytes) + let result = codec + 
.decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); - assert_eq!(input_bytes, result_output_bytes.as_slice()); + assert_eq!(input_bytes, result.as_slice()); } #[test] + #[cfg(feature = "zstd")] fn test_zstd_compression() { - let input_bytes = "hello zstd".as_bytes(); - let codec: CompressionCodec = CompressionCodec::Zstd; + let input_bytes = b"hello zstd"; + let codec = super::CompressionCodec::Zstd; let mut output_bytes: Vec = Vec::new(); codec.compress(input_bytes, &mut output_bytes).unwrap(); - let mut result_output_bytes: Vec = Vec::new(); - codec - .decompress(output_bytes.as_slice(), &mut result_output_bytes) + let result = codec + .decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); - assert_eq!(input_bytes, result_output_bytes.as_slice()); + assert_eq!(input_bytes, result.as_slice()); } } diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs new file mode 100644 index 000000000000..52c6a0d614d0 --- /dev/null +++ b/arrow-ipc/src/convert.rs @@ -0,0 +1,1221 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Utilities for converting between IPC types and native Arrow types + +use arrow_buffer::Buffer; +use arrow_schema::*; +use flatbuffers::{ + FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier, + VerifierOptions, WIPOffset, +}; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use crate::writer::DictionaryTracker; +use crate::{size_prefixed_root_as_message, KeyValue, Message, CONTINUATION_MARKER}; +use DataType::*; + +/// Low level Arrow [Schema] to IPC bytes converter +/// +/// See also [`fb_to_schema`] for the reverse operation +/// +/// # Example +/// ``` +/// # use arrow_ipc::convert::{fb_to_schema, IpcSchemaEncoder}; +/// # use arrow_ipc::root_as_schema; +/// # use arrow_ipc::writer::DictionaryTracker; +/// # use arrow_schema::{DataType, Field, Schema}; +/// // given an arrow schema to serialize +/// let schema = Schema::new(vec![ +/// Field::new("a", DataType::Int32, false), +/// ]); +/// +/// // Use a dictionary tracker to track dictionary id if needed +/// let mut dictionary_tracker = DictionaryTracker::new(true); +/// // create a FlatBuffersBuilder that contains the encoded bytes +/// let fb = IpcSchemaEncoder::new() +/// .with_dictionary_tracker(&mut dictionary_tracker) +/// .schema_to_fb(&schema); +/// +/// // the bytes are in `fb.finished_data()` +/// let ipc_bytes = fb.finished_data(); +/// +/// // convert the IPC bytes back to an Arrow schema +/// let ipc_schema = root_as_schema(ipc_bytes).unwrap(); +/// let schema2 = fb_to_schema(ipc_schema); +/// assert_eq!(schema, schema2); +/// ``` +#[derive(Debug)] +pub struct IpcSchemaEncoder<'a> { + dictionary_tracker: Option<&'a mut DictionaryTracker>, +} + +impl<'a> Default for IpcSchemaEncoder<'a> { + fn default() -> Self { + Self::new() + } +} + +impl<'a> IpcSchemaEncoder<'a> { + /// Create a new schema encoder + pub fn new() -> IpcSchemaEncoder<'a> { + IpcSchemaEncoder { + dictionary_tracker: None, + } + } + + /// Specify a dictionary 
tracker to use + pub fn with_dictionary_tracker( + mut self, + dictionary_tracker: &'a mut DictionaryTracker, + ) -> Self { + self.dictionary_tracker = Some(dictionary_tracker); + self + } + + /// Serialize a schema in IPC format, returning a completed [`FlatBufferBuilder`] + /// + /// Note: Call [`FlatBufferBuilder::finished_data`] to get the serialized bytes + pub fn schema_to_fb<'b>(&mut self, schema: &Schema) -> FlatBufferBuilder<'b> { + let mut fbb = FlatBufferBuilder::new(); + + let root = self.schema_to_fb_offset(&mut fbb, schema); + + fbb.finish(root, None); + + fbb + } + + /// Serialize a schema to an in progress [`FlatBufferBuilder`], returning the in progress offset. + pub fn schema_to_fb_offset<'b>( + &mut self, + fbb: &mut FlatBufferBuilder<'b>, + schema: &Schema, + ) -> WIPOffset> { + let fields = schema + .fields() + .iter() + .map(|field| build_field(fbb, &mut self.dictionary_tracker, field)) + .collect::>(); + let fb_field_list = fbb.create_vector(&fields); + + let fb_metadata_list = + (!schema.metadata().is_empty()).then(|| metadata_to_fb(fbb, schema.metadata())); + + let mut builder = crate::SchemaBuilder::new(fbb); + builder.add_fields(fb_field_list); + if let Some(fb_metadata_list) = fb_metadata_list { + builder.add_custom_metadata(fb_metadata_list); + } + builder.finish() + } +} + +/// Serialize a schema in IPC format +#[deprecated(since = "54.0.0", note = "Use `IpcSchemaConverter`.")] +pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder<'_> { + IpcSchemaEncoder::new().schema_to_fb(schema) +} + +pub fn metadata_to_fb<'a>( + fbb: &mut FlatBufferBuilder<'a>, + metadata: &HashMap, +) -> WIPOffset>>> { + let custom_metadata = metadata + .iter() + .map(|(k, v)| { + let fb_key_name = fbb.create_string(k); + let fb_val_name = fbb.create_string(v); + + let mut kv_builder = crate::KeyValueBuilder::new(fbb); + kv_builder.add_key(fb_key_name); + kv_builder.add_value(fb_val_name); + kv_builder.finish() + }) + .collect::>(); + 
fbb.create_vector(&custom_metadata) +} + +#[deprecated(since = "54.0.0", note = "Use `IpcSchemaConverter`.")] +pub fn schema_to_fb_offset<'a>( + fbb: &mut FlatBufferBuilder<'a>, + schema: &Schema, +) -> WIPOffset> { + IpcSchemaEncoder::new().schema_to_fb_offset(fbb, schema) +} + +/// Convert an IPC Field to Arrow Field +impl<'a> From> for Field { + fn from(field: crate::Field) -> Field { + let arrow_field = if let Some(dictionary) = field.dictionary() { + Field::new_dict( + field.name().unwrap(), + get_data_type(field, true), + field.nullable(), + dictionary.id(), + dictionary.isOrdered(), + ) + } else { + Field::new( + field.name().unwrap(), + get_data_type(field, true), + field.nullable(), + ) + }; + + let mut metadata_map = HashMap::default(); + if let Some(list) = field.custom_metadata() { + for kv in list { + if let (Some(k), Some(v)) = (kv.key(), kv.value()) { + metadata_map.insert(k.to_string(), v.to_string()); + } + } + } + + arrow_field.with_metadata(metadata_map) + } +} + +/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema]. 
+pub fn fb_to_schema(fb: crate::Schema) -> Schema { + let mut fields: Vec = vec![]; + let c_fields = fb.fields().unwrap(); + let len = c_fields.len(); + for i in 0..len { + let c_field: crate::Field = c_fields.get(i); + match c_field.type_type() { + crate::Type::Decimal if fb.endianness() == crate::Endianness::Big => { + unimplemented!("Big Endian is not supported for Decimal!") + } + _ => (), + }; + fields.push(c_field.into()); + } + + let mut metadata: HashMap = HashMap::default(); + if let Some(md_fields) = fb.custom_metadata() { + let len = md_fields.len(); + for i in 0..len { + let kv = md_fields.get(i); + let k_str = kv.key(); + let v_str = kv.value(); + if let Some(k) = k_str { + if let Some(v) = v_str { + metadata.insert(k.to_string(), v.to_string()); + } + } + } + } + Schema::new_with_metadata(fields, metadata) +} + +/// Try deserialize flat buffer format bytes into a schema +pub fn try_schema_from_flatbuffer_bytes(bytes: &[u8]) -> Result { + if let Ok(ipc) = crate::root_as_message(bytes) { + if let Some(schema) = ipc.header_as_schema().map(fb_to_schema) { + Ok(schema) + } else { + Err(ArrowError::ParseError( + "Unable to get head as schema".to_string(), + )) + } + } else { + Err(ArrowError::ParseError( + "Unable to get root as message".to_string(), + )) + } +} + +/// Try deserialize the IPC format bytes into a schema +pub fn try_schema_from_ipc_buffer(buffer: &[u8]) -> Result { + // There are two protocol types: https://issues.apache.org/jira/browse/ARROW-6313 + // The original protocol is: + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + // The latest version of protocol is: + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + if buffer.len() >= 4 { + // check continuation marker + let continuation_marker = &buffer[0..4]; + let begin_offset: usize = if 
continuation_marker.eq(&CONTINUATION_MARKER) { + // 4 bytes: CONTINUATION_MARKER + // 4 bytes: length + // buffer + 4 + } else { + // backward compatibility for buffer without the continuation marker + // 4 bytes: length + // buffer + 0 + }; + let msg = size_prefixed_root_as_message(&buffer[begin_offset..]).map_err(|err| { + ArrowError::ParseError(format!("Unable to convert flight info to a message: {err}")) + })?; + let ipc_schema = msg.header_as_schema().ok_or_else(|| { + ArrowError::ParseError("Unable to convert flight info to a schema".to_string()) + })?; + Ok(fb_to_schema(ipc_schema)) + } else { + Err(ArrowError::ParseError( + "The buffer length is less than 4 and missing the continuation marker or length of buffer".to_string() + )) + } +} + +/// Get the Arrow data type from the flatbuffer Field table +pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> DataType { + if let Some(dictionary) = field.dictionary() { + if may_be_dictionary { + let int = dictionary.indexType().unwrap(); + let index_type = match (int.bitWidth(), int.is_signed()) { + (8, true) => DataType::Int8, + (8, false) => DataType::UInt8, + (16, true) => DataType::Int16, + (16, false) => DataType::UInt16, + (32, true) => DataType::Int32, + (32, false) => DataType::UInt32, + (64, true) => DataType::Int64, + (64, false) => DataType::UInt64, + _ => panic!("Unexpected bitwidth and signed"), + }; + return DataType::Dictionary( + Box::new(index_type), + Box::new(get_data_type(field, false)), + ); + } + } + + match field.type_type() { + crate::Type::Null => DataType::Null, + crate::Type::Bool => DataType::Boolean, + crate::Type::Int => { + let int = field.type_as_int().unwrap(); + match (int.bitWidth(), int.is_signed()) { + (8, true) => DataType::Int8, + (8, false) => DataType::UInt8, + (16, true) => DataType::Int16, + (16, false) => DataType::UInt16, + (32, true) => DataType::Int32, + (32, false) => DataType::UInt32, + (64, true) => DataType::Int64, + (64, false) => 
DataType::UInt64, + z => panic!( + "Int type with bit width of {} and signed of {} not supported", + z.0, z.1 + ), + } + } + crate::Type::Binary => DataType::Binary, + crate::Type::BinaryView => DataType::BinaryView, + crate::Type::LargeBinary => DataType::LargeBinary, + crate::Type::Utf8 => DataType::Utf8, + crate::Type::Utf8View => DataType::Utf8View, + crate::Type::LargeUtf8 => DataType::LargeUtf8, + crate::Type::FixedSizeBinary => { + let fsb = field.type_as_fixed_size_binary().unwrap(); + DataType::FixedSizeBinary(fsb.byteWidth()) + } + crate::Type::FloatingPoint => { + let float = field.type_as_floating_point().unwrap(); + match float.precision() { + crate::Precision::HALF => DataType::Float16, + crate::Precision::SINGLE => DataType::Float32, + crate::Precision::DOUBLE => DataType::Float64, + z => panic!("FloatingPoint type with precision of {z:?} not supported"), + } + } + crate::Type::Date => { + let date = field.type_as_date().unwrap(); + match date.unit() { + crate::DateUnit::DAY => DataType::Date32, + crate::DateUnit::MILLISECOND => DataType::Date64, + z => panic!("Date type with unit of {z:?} not supported"), + } + } + crate::Type::Time => { + let time = field.type_as_time().unwrap(); + match (time.bitWidth(), time.unit()) { + (32, crate::TimeUnit::SECOND) => DataType::Time32(TimeUnit::Second), + (32, crate::TimeUnit::MILLISECOND) => DataType::Time32(TimeUnit::Millisecond), + (64, crate::TimeUnit::MICROSECOND) => DataType::Time64(TimeUnit::Microsecond), + (64, crate::TimeUnit::NANOSECOND) => DataType::Time64(TimeUnit::Nanosecond), + z => panic!( + "Time type with bit width of {} and unit of {:?} not supported", + z.0, z.1 + ), + } + } + crate::Type::Timestamp => { + let timestamp = field.type_as_timestamp().unwrap(); + let timezone: Option<_> = timestamp.timezone().map(|tz| tz.into()); + match timestamp.unit() { + crate::TimeUnit::SECOND => DataType::Timestamp(TimeUnit::Second, timezone), + crate::TimeUnit::MILLISECOND => { + 
DataType::Timestamp(TimeUnit::Millisecond, timezone) + } + crate::TimeUnit::MICROSECOND => { + DataType::Timestamp(TimeUnit::Microsecond, timezone) + } + crate::TimeUnit::NANOSECOND => DataType::Timestamp(TimeUnit::Nanosecond, timezone), + z => panic!("Timestamp type with unit of {z:?} not supported"), + } + } + crate::Type::Interval => { + let interval = field.type_as_interval().unwrap(); + match interval.unit() { + crate::IntervalUnit::YEAR_MONTH => DataType::Interval(IntervalUnit::YearMonth), + crate::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime), + crate::IntervalUnit::MONTH_DAY_NANO => { + DataType::Interval(IntervalUnit::MonthDayNano) + } + z => panic!("Interval type with unit of {z:?} unsupported"), + } + } + crate::Type::Duration => { + let duration = field.type_as_duration().unwrap(); + match duration.unit() { + crate::TimeUnit::SECOND => DataType::Duration(TimeUnit::Second), + crate::TimeUnit::MILLISECOND => DataType::Duration(TimeUnit::Millisecond), + crate::TimeUnit::MICROSECOND => DataType::Duration(TimeUnit::Microsecond), + crate::TimeUnit::NANOSECOND => DataType::Duration(TimeUnit::Nanosecond), + z => panic!("Duration type with unit of {z:?} unsupported"), + } + } + crate::Type::List => { + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a list to have one child") + } + DataType::List(Arc::new(children.get(0).into())) + } + crate::Type::LargeList => { + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a large list to have one child") + } + DataType::LargeList(Arc::new(children.get(0).into())) + } + crate::Type::FixedSizeList => { + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a list to have one child") + } + let fsl = field.type_as_fixed_size_list().unwrap(); + DataType::FixedSizeList(Arc::new(children.get(0).into()), fsl.listSize()) + } + crate::Type::Struct_ => { + let fields = match field.children() { + 
Some(children) => children.iter().map(Field::from).collect(), + None => Fields::empty(), + }; + DataType::Struct(fields) + } + crate::Type::RunEndEncoded => { + let children = field.children().unwrap(); + if children.len() != 2 { + panic!( + "RunEndEncoded type should have exactly two children. Found {}", + children.len() + ) + } + let run_ends_field = children.get(0).into(); + let values_field = children.get(1).into(); + DataType::RunEndEncoded(Arc::new(run_ends_field), Arc::new(values_field)) + } + crate::Type::Map => { + let map = field.type_as_map().unwrap(); + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a map to have one child") + } + DataType::Map(Arc::new(children.get(0).into()), map.keysSorted()) + } + crate::Type::Decimal => { + let fsb = field.type_as_decimal().unwrap(); + let bit_width = fsb.bitWidth(); + if bit_width == 128 { + DataType::Decimal128( + fsb.precision().try_into().unwrap(), + fsb.scale().try_into().unwrap(), + ) + } else if bit_width == 256 { + DataType::Decimal256( + fsb.precision().try_into().unwrap(), + fsb.scale().try_into().unwrap(), + ) + } else { + panic!("Unexpected decimal bit width {bit_width}") + } + } + crate::Type::Union => { + let union = field.type_as_union().unwrap(); + + let union_mode = match union.mode() { + crate::UnionMode::Dense => UnionMode::Dense, + crate::UnionMode::Sparse => UnionMode::Sparse, + mode => panic!("Unexpected union mode: {mode:?}"), + }; + + let mut fields = vec![]; + if let Some(children) = field.children() { + for i in 0..children.len() { + fields.push(Field::from(children.get(i))); + } + }; + + let fields = match union.typeIds() { + None => UnionFields::new(0_i8..fields.len() as i8, fields), + Some(ids) => UnionFields::new(ids.iter().map(|i| i as i8), fields), + }; + + DataType::Union(fields, union_mode) + } + t => unimplemented!("Type {:?} not supported", t), + } +} + +pub(crate) struct FBFieldType<'b> { + pub(crate) type_type: crate::Type, + pub(crate) 
type_: WIPOffset, + pub(crate) children: Option>>>>, +} + +/// Create an IPC Field from an Arrow Field +pub(crate) fn build_field<'a>( + fbb: &mut FlatBufferBuilder<'a>, + dictionary_tracker: &mut Option<&mut DictionaryTracker>, + field: &Field, +) -> WIPOffset> { + // Optional custom metadata. + let mut fb_metadata = None; + if !field.metadata().is_empty() { + fb_metadata = Some(metadata_to_fb(fbb, field.metadata())); + }; + + let fb_field_name = fbb.create_string(field.name().as_str()); + let field_type = get_fb_field_type(field.data_type(), dictionary_tracker, fbb); + + let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() { + match dictionary_tracker { + Some(tracker) => Some(get_fb_dictionary( + index_type, + tracker.set_dict_id(field), + field + .dict_is_ordered() + .expect("All Dictionary types have `dict_is_ordered`"), + fbb, + )), + None => Some(get_fb_dictionary( + index_type, + field + .dict_id() + .expect("Dictionary type must have a dictionary id"), + field + .dict_is_ordered() + .expect("All Dictionary types have `dict_is_ordered`"), + fbb, + )), + } + } else { + None + }; + + let mut field_builder = crate::FieldBuilder::new(fbb); + field_builder.add_name(fb_field_name); + if let Some(dictionary) = fb_dictionary { + field_builder.add_dictionary(dictionary) + } + field_builder.add_type_type(field_type.type_type); + field_builder.add_nullable(field.is_nullable()); + match field_type.children { + None => {} + Some(children) => field_builder.add_children(children), + }; + field_builder.add_type_(field_type.type_); + + if let Some(fb_metadata) = fb_metadata { + field_builder.add_custom_metadata(fb_metadata); + } + + field_builder.finish() +} + +/// Get the IPC type of a data type +pub(crate) fn get_fb_field_type<'a>( + data_type: &DataType, + dictionary_tracker: &mut Option<&mut DictionaryTracker>, + fbb: &mut FlatBufferBuilder<'a>, +) -> FBFieldType<'a> { + // some IPC implementations expect an empty list for child data, instead of a 
null value. + // An empty field list is thus returned for primitive types + let empty_fields: Vec> = vec![]; + match data_type { + Null => FBFieldType { + type_type: crate::Type::Null, + type_: crate::NullBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + Boolean => FBFieldType { + type_type: crate::Type::Bool, + type_: crate::BoolBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + UInt8 | UInt16 | UInt32 | UInt64 => { + let children = fbb.create_vector(&empty_fields[..]); + let mut builder = crate::IntBuilder::new(fbb); + builder.add_is_signed(false); + match data_type { + UInt8 => builder.add_bitWidth(8), + UInt16 => builder.add_bitWidth(16), + UInt32 => builder.add_bitWidth(32), + UInt64 => builder.add_bitWidth(64), + _ => {} + }; + FBFieldType { + type_type: crate::Type::Int, + type_: builder.finish().as_union_value(), + children: Some(children), + } + } + Int8 | Int16 | Int32 | Int64 => { + let children = fbb.create_vector(&empty_fields[..]); + let mut builder = crate::IntBuilder::new(fbb); + builder.add_is_signed(true); + match data_type { + Int8 => builder.add_bitWidth(8), + Int16 => builder.add_bitWidth(16), + Int32 => builder.add_bitWidth(32), + Int64 => builder.add_bitWidth(64), + _ => {} + }; + FBFieldType { + type_type: crate::Type::Int, + type_: builder.finish().as_union_value(), + children: Some(children), + } + } + Float16 | Float32 | Float64 => { + let children = fbb.create_vector(&empty_fields[..]); + let mut builder = crate::FloatingPointBuilder::new(fbb); + match data_type { + Float16 => builder.add_precision(crate::Precision::HALF), + Float32 => builder.add_precision(crate::Precision::SINGLE), + Float64 => builder.add_precision(crate::Precision::DOUBLE), + _ => {} + }; + FBFieldType { + type_type: crate::Type::FloatingPoint, + type_: builder.finish().as_union_value(), + children: Some(children), + } + } + Binary => FBFieldType { + 
type_type: crate::Type::Binary, + type_: crate::BinaryBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + LargeBinary => FBFieldType { + type_type: crate::Type::LargeBinary, + type_: crate::LargeBinaryBuilder::new(fbb) + .finish() + .as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + BinaryView => FBFieldType { + type_type: crate::Type::BinaryView, + type_: crate::BinaryViewBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + Utf8View => FBFieldType { + type_type: crate::Type::Utf8View, + type_: crate::Utf8ViewBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + Utf8 => FBFieldType { + type_type: crate::Type::Utf8, + type_: crate::Utf8Builder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + LargeUtf8 => FBFieldType { + type_type: crate::Type::LargeUtf8, + type_: crate::LargeUtf8Builder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + FixedSizeBinary(len) => { + let mut builder = crate::FixedSizeBinaryBuilder::new(fbb); + builder.add_byteWidth(*len); + FBFieldType { + type_type: crate::Type::FixedSizeBinary, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Date32 => { + let mut builder = crate::DateBuilder::new(fbb); + builder.add_unit(crate::DateUnit::DAY); + FBFieldType { + type_type: crate::Type::Date, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Date64 => { + let mut builder = crate::DateBuilder::new(fbb); + builder.add_unit(crate::DateUnit::MILLISECOND); + FBFieldType { + type_type: crate::Type::Date, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Time32(unit) | Time64(unit) => { + 
let mut builder = crate::TimeBuilder::new(fbb); + match unit { + TimeUnit::Second => { + builder.add_bitWidth(32); + builder.add_unit(crate::TimeUnit::SECOND); + } + TimeUnit::Millisecond => { + builder.add_bitWidth(32); + builder.add_unit(crate::TimeUnit::MILLISECOND); + } + TimeUnit::Microsecond => { + builder.add_bitWidth(64); + builder.add_unit(crate::TimeUnit::MICROSECOND); + } + TimeUnit::Nanosecond => { + builder.add_bitWidth(64); + builder.add_unit(crate::TimeUnit::NANOSECOND); + } + } + FBFieldType { + type_type: crate::Type::Time, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Timestamp(unit, tz) => { + let tz = tz.as_deref().unwrap_or_default(); + let tz_str = fbb.create_string(tz); + let mut builder = crate::TimestampBuilder::new(fbb); + let time_unit = match unit { + TimeUnit::Second => crate::TimeUnit::SECOND, + TimeUnit::Millisecond => crate::TimeUnit::MILLISECOND, + TimeUnit::Microsecond => crate::TimeUnit::MICROSECOND, + TimeUnit::Nanosecond => crate::TimeUnit::NANOSECOND, + }; + builder.add_unit(time_unit); + if !tz.is_empty() { + builder.add_timezone(tz_str); + } + FBFieldType { + type_type: crate::Type::Timestamp, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Interval(unit) => { + let mut builder = crate::IntervalBuilder::new(fbb); + let interval_unit = match unit { + IntervalUnit::YearMonth => crate::IntervalUnit::YEAR_MONTH, + IntervalUnit::DayTime => crate::IntervalUnit::DAY_TIME, + IntervalUnit::MonthDayNano => crate::IntervalUnit::MONTH_DAY_NANO, + }; + builder.add_unit(interval_unit); + FBFieldType { + type_type: crate::Type::Interval, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Duration(unit) => { + let mut builder = crate::DurationBuilder::new(fbb); + let time_unit = match unit { + TimeUnit::Second => crate::TimeUnit::SECOND, + 
TimeUnit::Millisecond => crate::TimeUnit::MILLISECOND, + TimeUnit::Microsecond => crate::TimeUnit::MICROSECOND, + TimeUnit::Nanosecond => crate::TimeUnit::NANOSECOND, + }; + builder.add_unit(time_unit); + FBFieldType { + type_type: crate::Type::Duration, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + List(ref list_type) => { + let child = build_field(fbb, dictionary_tracker, list_type); + FBFieldType { + type_type: crate::Type::List, + type_: crate::ListBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + ListView(_) | LargeListView(_) => unimplemented!("ListView/LargeListView not implemented"), + LargeList(ref list_type) => { + let child = build_field(fbb, dictionary_tracker, list_type); + FBFieldType { + type_type: crate::Type::LargeList, + type_: crate::LargeListBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + FixedSizeList(ref list_type, len) => { + let child = build_field(fbb, dictionary_tracker, list_type); + let mut builder = crate::FixedSizeListBuilder::new(fbb); + builder.add_listSize(*len); + FBFieldType { + type_type: crate::Type::FixedSizeList, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + Struct(fields) => { + // struct's fields are children + let mut children = vec![]; + for field in fields { + children.push(build_field(fbb, dictionary_tracker, field)); + } + FBFieldType { + type_type: crate::Type::Struct_, + type_: crate::Struct_Builder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&children[..])), + } + } + RunEndEncoded(run_ends, values) => { + let run_ends_field = build_field(fbb, dictionary_tracker, run_ends); + let values_field = build_field(fbb, dictionary_tracker, values); + let children = [run_ends_field, values_field]; + FBFieldType { + type_type: crate::Type::RunEndEncoded, + type_: 
crate::RunEndEncodedBuilder::new(fbb) + .finish() + .as_union_value(), + children: Some(fbb.create_vector(&children[..])), + } + } + Map(map_field, keys_sorted) => { + let child = build_field(fbb, dictionary_tracker, map_field); + let mut field_type = crate::MapBuilder::new(fbb); + field_type.add_keysSorted(*keys_sorted); + FBFieldType { + type_type: crate::Type::Map, + type_: field_type.finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + Dictionary(_, value_type) => { + // In this library, the dictionary "type" is a logical construct. Here we + // pass through to the value type, as we've already captured the index + // type in the DictionaryEncoding metadata in the parent field + get_fb_field_type(value_type, dictionary_tracker, fbb) + } + Decimal128(precision, scale) => { + let mut builder = crate::DecimalBuilder::new(fbb); + builder.add_precision(*precision as i32); + builder.add_scale(*scale as i32); + builder.add_bitWidth(128); + FBFieldType { + type_type: crate::Type::Decimal, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Decimal256(precision, scale) => { + let mut builder = crate::DecimalBuilder::new(fbb); + builder.add_precision(*precision as i32); + builder.add_scale(*scale as i32); + builder.add_bitWidth(256); + FBFieldType { + type_type: crate::Type::Decimal, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Union(fields, mode) => { + let mut children = vec![]; + for (_, field) in fields.iter() { + children.push(build_field(fbb, dictionary_tracker, field)); + } + + let union_mode = match mode { + UnionMode::Sparse => crate::UnionMode::Sparse, + UnionMode::Dense => crate::UnionMode::Dense, + }; + + let fbb_type_ids = + fbb.create_vector(&fields.iter().map(|(t, _)| t as i32).collect::>()); + let mut builder = crate::UnionBuilder::new(fbb); + builder.add_mode(union_mode); + 
builder.add_typeIds(fbb_type_ids); + + FBFieldType { + type_type: crate::Type::Union, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&children[..])), + } + } + } +} + +/// Create an IPC dictionary encoding +pub(crate) fn get_fb_dictionary<'a>( + index_type: &DataType, + dict_id: i64, + dict_is_ordered: bool, + fbb: &mut FlatBufferBuilder<'a>, +) -> WIPOffset> { + // We assume that the dictionary index type (as an integer) has already been + // validated elsewhere, and can safely assume we are dealing with integers + let mut index_builder = crate::IntBuilder::new(fbb); + + match *index_type { + Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), + UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false), + _ => {} + } + + match *index_type { + Int8 | UInt8 => index_builder.add_bitWidth(8), + Int16 | UInt16 => index_builder.add_bitWidth(16), + Int32 | UInt32 => index_builder.add_bitWidth(32), + Int64 | UInt64 => index_builder.add_bitWidth(64), + _ => {} + } + + let index_builder = index_builder.finish(); + + let mut builder = crate::DictionaryEncodingBuilder::new(fbb); + builder.add_id(dict_id); + builder.add_indexType(index_builder); + builder.add_isOrdered(dict_is_ordered); + + builder.finish() +} + +/// An owned container for a validated [`Message`] +/// +/// Safely decoding a flatbuffer requires validating the various embedded offsets, +/// see [`Verifier`]. This is a potentially expensive operation, and it is therefore desirable +/// to only do this once. [`crate::root_as_message`] performs this validation on construction, +/// however, it returns a [`Message`] borrowing the provided byte slice. This prevents +/// storing this [`Message`] in the same data structure that owns the buffer, as this +/// would require self-referential borrows. +/// +/// [`MessageBuffer`] solves this problem by providing a safe API for a [`Message`] +/// without a lifetime bound. 
+#[derive(Clone)] +pub struct MessageBuffer(Buffer); + +impl Debug for MessageBuffer { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.as_ref().fmt(f) + } +} + +impl MessageBuffer { + /// Try to create a [`MessageBuffer`] from the provided [`Buffer`] + pub fn try_new(buf: Buffer) -> Result { + let opts = VerifierOptions::default(); + let mut v = Verifier::new(&opts, &buf); + >::run_verifier(&mut v, 0).map_err(|err| { + ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) + })?; + Ok(Self(buf)) + } + + /// Return the [`Message`] + #[inline] + pub fn as_ref(&self) -> Message<'_> { + // SAFETY: Run verifier on construction + unsafe { crate::root_as_message_unchecked(&self.0) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn convert_schema_round_trip() { + let md: HashMap = [("Key".to_string(), "value".to_string())] + .iter() + .cloned() + .collect(); + let field_md: HashMap = [("k".to_string(), "v".to_string())] + .iter() + .cloned() + .collect(); + let schema = Schema::new_with_metadata( + vec![ + Field::new("uint8", DataType::UInt8, false).with_metadata(field_md), + Field::new("uint16", DataType::UInt16, true), + Field::new("uint32", DataType::UInt32, false), + Field::new("uint64", DataType::UInt64, true), + Field::new("int8", DataType::Int8, true), + Field::new("int16", DataType::Int16, false), + Field::new("int32", DataType::Int32, true), + Field::new("int64", DataType::Int64, false), + Field::new("float16", DataType::Float16, true), + Field::new("float32", DataType::Float32, false), + Field::new("float64", DataType::Float64, true), + Field::new("null", DataType::Null, false), + Field::new("bool", DataType::Boolean, false), + Field::new("date32", DataType::Date32, false), + Field::new("date64", DataType::Date64, true), + Field::new("time32[s]", DataType::Time32(TimeUnit::Second), true), + Field::new("time32[ms]", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("time64[us]", 
DataType::Time64(TimeUnit::Microsecond), false), + Field::new("time64[ns]", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new( + "timestamp[s]", + DataType::Timestamp(TimeUnit::Second, None), + false, + ), + Field::new( + "timestamp[ms]", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "timestamp[us]", + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), + false, + ), + Field::new( + "timestamp[ns]", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "interval[ym]", + DataType::Interval(IntervalUnit::YearMonth), + true, + ), + Field::new( + "interval[dt]", + DataType::Interval(IntervalUnit::DayTime), + true, + ), + Field::new( + "interval[mdn]", + DataType::Interval(IntervalUnit::MonthDayNano), + true, + ), + Field::new("utf8", DataType::Utf8, false), + Field::new("utf8_view", DataType::Utf8View, false), + Field::new("binary", DataType::Binary, false), + Field::new("binary_view", DataType::BinaryView, false), + Field::new_list("list[u8]", Field::new("item", DataType::UInt8, false), true), + Field::new_fixed_size_list( + "fixed_size_list[u8]", + Field::new("item", DataType::UInt8, false), + 2, + true, + ), + Field::new_list( + "list[struct]", + Field::new_struct( + "struct", + vec![ + Field::new("float32", UInt8, false), + Field::new("int32", Int32, true), + Field::new("bool", Boolean, true), + ], + true, + ), + false, + ), + Field::new_struct( + "struct>", + vec![Field::new( + "dictionary", + Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )], + false, + ), + Field::new_struct( + "struct]>]>", + vec![ + Field::new("int64", DataType::Int64, true), + Field::new_list( + "list[struct]>]", + Field::new_struct( + "struct", + vec![ + Field::new("date32", DataType::Date32, true), + Field::new_list( + "list[struct<>]", + Field::new( + "struct", + DataType::Struct(Fields::empty()), + false, + ), + false, + ), + ], + false, + ), + false, + ), + ], + 
false, + ), + Field::new_union( + "union]>]>", + vec![0, 1], + vec![ + Field::new("int64", DataType::Int64, true), + Field::new_list( + "list[union]>]", + Field::new_union( + "union]>", + vec![0, 1], + vec![ + Field::new("date32", DataType::Date32, true), + Field::new_list( + "list[union<>]", + Field::new( + "union", + DataType::Union( + UnionFields::empty(), + UnionMode::Sparse, + ), + false, + ), + false, + ), + ], + UnionMode::Dense, + ), + false, + ), + ], + UnionMode::Sparse, + ), + Field::new("struct<>", DataType::Struct(Fields::empty()), true), + Field::new( + "union<>", + DataType::Union(UnionFields::empty(), UnionMode::Dense), + true, + ), + Field::new( + "union<>", + DataType::Union(UnionFields::empty(), UnionMode::Sparse), + true, + ), + Field::new( + "union", + DataType::Union( + UnionFields::new( + vec![2, 3], // non-default type ids + vec![ + Field::new("int32", DataType::Int32, true), + Field::new("utf8", DataType::Utf8, true), + ], + ), + UnionMode::Dense, + ), + true, + ), + Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + 123, + true, + ), + Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), + true, + 123, + true, + ), + Field::new("decimal", DataType::Decimal128(10, 6), false), + ], + md, + ); + + let mut dictionary_tracker = DictionaryTracker::new(true); + let fb = IpcSchemaEncoder::new() + .with_dictionary_tracker(&mut dictionary_tracker) + .schema_to_fb(&schema); + + // read back fields + let ipc = crate::root_as_schema(fb.finished_data()).unwrap(); + let schema2 = fb_to_schema(ipc); + assert_eq!(schema, schema2); + } + + #[test] + fn schema_from_bytes() { + // Bytes of a schema generated via following python code, using pyarrow 10.0.1: + // + // import pyarrow as pa + // schema = pa.schema([pa.field('field1', pa.uint32(), nullable=False)]) + // sink = pa.BufferOutputStream() + // with pa.ipc.new_stream(sink, 
schema) as writer: + // pass + // # stripping continuation & length prefix & suffix bytes to get only schema bytes + // [x for x in sink.getvalue().to_pybytes()][8:-8] + let bytes: Vec = vec![ + 16, 0, 0, 0, 0, 0, 10, 0, 12, 0, 6, 0, 5, 0, 8, 0, 10, 0, 0, 0, 0, 1, 4, 0, 12, 0, 0, + 0, 8, 0, 8, 0, 0, 0, 4, 0, 8, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 20, 0, 0, 0, 16, 0, 20, + 0, 8, 0, 0, 0, 7, 0, 12, 0, 0, 0, 16, 0, 16, 0, 0, 0, 0, 0, 0, 2, 16, 0, 0, 0, 32, 0, + 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49, 0, 0, 0, 0, 6, + 0, 8, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, + ]; + let ipc = crate::root_as_message(&bytes).unwrap(); + let schema = ipc.header_as_schema().unwrap(); + + // generate same message with Rust + let data_gen = crate::writer::IpcDataGenerator::default(); + let mut dictionary_tracker = DictionaryTracker::new(true); + let arrow_schema = Schema::new(vec![Field::new("field1", DataType::UInt32, false)]); + let bytes = data_gen + .schema_to_bytes_with_dictionary_tracker( + &arrow_schema, + &mut dictionary_tracker, + &crate::writer::IpcWriteOptions::default(), + ) + .ipc_message; + + let ipc2 = crate::root_as_message(&bytes).unwrap(); + let schema2 = ipc2.header_as_schema().unwrap(); + + // can't compare schema directly as it compares the underlying bytes, which can differ + assert!(schema.custom_metadata().is_none()); + assert!(schema2.custom_metadata().is_none()); + assert_eq!(schema.endianness(), schema2.endianness()); + assert!(schema.features().is_none()); + assert!(schema2.features().is_none()); + assert_eq!(fb_to_schema(schema), fb_to_schema(schema2)); + + assert_eq!(ipc.version(), ipc2.version()); + assert_eq!(ipc.header_type(), ipc2.header_type()); + assert_eq!(ipc.bodyLength(), ipc2.bodyLength()); + assert!(ipc.custom_metadata().is_none()); + assert!(ipc2.custom_metadata().is_none()); + } +} diff --git a/arrow/src/ipc/gen/File.rs b/arrow-ipc/src/gen/File.rs similarity index 69% rename from arrow/src/ipc/gen/File.rs rename to 
arrow-ipc/src/gen/File.rs index 04cbc6441377..c0c2fb183237 100644 --- a/arrow/src/ipc/gen/File.rs +++ b/arrow-ipc/src/gen/File.rs @@ -18,7 +18,7 @@ #![allow(dead_code)] #![allow(unused_imports)] -use crate::ipc::gen::Schema::*; +use crate::gen::Schema::*; use flatbuffers::EndianScalar; use std::{cmp::Ordering, mem}; // automatically generated by the FlatBuffers compiler, do not modify @@ -27,8 +27,13 @@ use std::{cmp::Ordering, mem}; #[repr(transparent)] #[derive(Clone, Copy, PartialEq)] pub struct Block(pub [u8; 24]); -impl std::fmt::Debug for Block { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Default for Block { + fn default() -> Self { + Self([0; 24]) + } +} +impl core::fmt::Debug for Block { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { f.debug_struct("Block") .field("offset", &self.offset()) .field("metaDataLength", &self.metaDataLength()) @@ -38,39 +43,25 @@ impl std::fmt::Debug for Block { } impl flatbuffers::SimpleToVerifyInSlice for Block {} -impl flatbuffers::SafeSliceAccess for Block {} impl<'a> flatbuffers::Follow<'a> for Block { type Inner = &'a Block; #[inline] - fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { <&'a Block>::follow(buf, loc) } } impl<'a> flatbuffers::Follow<'a> for &'a Block { type Inner = &'a Block; #[inline] - fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { flatbuffers::follow_cast_ref::(buf, loc) } } impl<'b> flatbuffers::Push for Block { type Output = Block; #[inline] - fn push(&self, dst: &mut [u8], _rest: &[u8]) { - let src = unsafe { - ::std::slice::from_raw_parts(self as *const Block as *const u8, Self::size()) - }; - dst.copy_from_slice(src); - } -} -impl<'b> flatbuffers::Push for &'b Block { - type Output = Block; - - #[inline] - fn push(&self, dst: &mut [u8], _rest: &[u8]) { - let src = unsafe { - ::std::slice::from_raw_parts(*self as 
*const Block as *const u8, Self::size()) - }; + unsafe fn push(&self, dst: &mut [u8], _written_len: usize) { + let src = ::core::slice::from_raw_parts(self as *const Block as *const u8, Self::size()); dst.copy_from_slice(src); } } @@ -85,7 +76,8 @@ impl<'a> flatbuffers::Verifiable for Block { v.in_buffer::(pos) } } -impl Block { + +impl<'a> Block { #[allow(clippy::too_many_arguments)] pub fn new(offset: i64, metaDataLength: i32, bodyLength: i64) -> Self { let mut s = Self([0; 24]); @@ -97,50 +89,60 @@ impl Block { /// Index to the start of the RecordBlock (note this is past the Message header) pub fn offset(&self) -> i64 { - let mut mem = core::mem::MaybeUninit::::uninit(); - unsafe { + let mut mem = core::mem::MaybeUninit::<::Scalar>::uninit(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot + EndianScalar::from_little_endian(unsafe { core::ptr::copy_nonoverlapping( self.0[0..].as_ptr(), mem.as_mut_ptr() as *mut u8, - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); mem.assume_init() - } - .from_little_endian() + }) } pub fn set_offset(&mut self, x: i64) { let x_le = x.to_little_endian(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot unsafe { core::ptr::copy_nonoverlapping( - &x_le as *const i64 as *const u8, + &x_le as *const _ as *const u8, self.0[0..].as_mut_ptr(), - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); } } /// Length of the metadata pub fn metaDataLength(&self) -> i32 { - let mut mem = core::mem::MaybeUninit::::uninit(); - unsafe { + let mut mem = core::mem::MaybeUninit::<::Scalar>::uninit(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot + EndianScalar::from_little_endian(unsafe { core::ptr::copy_nonoverlapping( self.0[8..].as_ptr(), mem.as_mut_ptr() as *mut u8, - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); 
mem.assume_init() - } - .from_little_endian() + }) } pub fn set_metaDataLength(&mut self, x: i32) { let x_le = x.to_little_endian(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot unsafe { core::ptr::copy_nonoverlapping( - &x_le as *const i32 as *const u8, + &x_le as *const _ as *const u8, self.0[8..].as_mut_ptr(), - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); } } @@ -148,25 +150,30 @@ impl Block { /// Length of the data (this is aligned so there can be a gap between this and /// the metadata). pub fn bodyLength(&self) -> i64 { - let mut mem = core::mem::MaybeUninit::::uninit(); - unsafe { + let mut mem = core::mem::MaybeUninit::<::Scalar>::uninit(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot + EndianScalar::from_little_endian(unsafe { core::ptr::copy_nonoverlapping( self.0[16..].as_ptr(), mem.as_mut_ptr() as *mut u8, - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); mem.assume_init() - } - .from_little_endian() + }) } pub fn set_bodyLength(&mut self, x: i64) { let x_le = x.to_little_endian(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot unsafe { core::ptr::copy_nonoverlapping( - &x_le as *const i64 as *const u8, + &x_le as *const _ as *const u8, self.0[16..].as_mut_ptr(), - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); } } @@ -185,16 +192,22 @@ pub struct Footer<'a> { impl<'a> flatbuffers::Follow<'a> for Footer<'a> { type Inner = Footer<'a>; #[inline] - fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { Self { - _tab: flatbuffers::Table { buf, loc }, + _tab: flatbuffers::Table::new(buf, loc), } } } impl<'a> Footer<'a> { + pub const VT_VERSION: flatbuffers::VOffsetT = 4; + pub const VT_SCHEMA: flatbuffers::VOffsetT = 6; + pub const VT_DICTIONARIES: 
flatbuffers::VOffsetT = 8; + pub const VT_RECORDBATCHES: flatbuffers::VOffsetT = 10; + pub const VT_CUSTOM_METADATA: flatbuffers::VOffsetT = 12; + #[inline] - pub fn init_from_table(table: flatbuffers::Table<'a>) -> Self { + pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { Footer { _tab: table } } #[allow(unused_mut)] @@ -219,49 +232,66 @@ impl<'a> Footer<'a> { builder.finish() } - pub const VT_VERSION: flatbuffers::VOffsetT = 4; - pub const VT_SCHEMA: flatbuffers::VOffsetT = 6; - pub const VT_DICTIONARIES: flatbuffers::VOffsetT = 8; - pub const VT_RECORDBATCHES: flatbuffers::VOffsetT = 10; - pub const VT_CUSTOM_METADATA: flatbuffers::VOffsetT = 12; - #[inline] pub fn version(&self) -> MetadataVersion { - self._tab - .get::(Footer::VT_VERSION, Some(MetadataVersion::V1)) - .unwrap() + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(Footer::VT_VERSION, Some(MetadataVersion::V1)) + .unwrap() + } } #[inline] pub fn schema(&self) -> Option> { - self._tab - .get::>(Footer::VT_SCHEMA, None) + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Footer::VT_SCHEMA, None) + } } #[inline] - pub fn dictionaries(&self) -> Option<&'a [Block]> { - self._tab - .get::>>( - Footer::VT_DICTIONARIES, - None, - ) - .map(|v| v.safe_slice()) + pub fn dictionaries(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + Footer::VT_DICTIONARIES, + None, + ) + } } #[inline] - pub fn recordBatches(&self) -> Option<&'a [Block]> { - self._tab - .get::>>( - Footer::VT_RECORDBATCHES, - None, - ) - .map(|v| v.safe_slice()) + pub fn recordBatches(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + 
.get::>>( + Footer::VT_RECORDBATCHES, + None, + ) + } } /// User-defined metadata #[inline] pub fn custom_metadata( &self, ) -> Option>>> { - self._tab.get::>, - >>(Footer::VT_CUSTOM_METADATA, None) + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab.get::>, + >>(Footer::VT_CUSTOM_METADATA, None) + } } } @@ -273,25 +303,21 @@ impl flatbuffers::Verifiable for Footer<'_> { ) -> Result<(), flatbuffers::InvalidFlatbuffer> { use flatbuffers::Verifiable; v.visit_table(pos)? - .visit_field::(&"version", Self::VT_VERSION, false)? - .visit_field::>( - &"schema", - Self::VT_SCHEMA, - false, - )? + .visit_field::("version", Self::VT_VERSION, false)? + .visit_field::>("schema", Self::VT_SCHEMA, false)? .visit_field::>>( - &"dictionaries", + "dictionaries", Self::VT_DICTIONARIES, false, )? .visit_field::>>( - &"recordBatches", + "recordBatches", Self::VT_RECORDBATCHES, false, )? .visit_field::>, - >>(&"custom_metadata", Self::VT_CUSTOM_METADATA, false)? + >>("custom_metadata", Self::VT_CUSTOM_METADATA, false)? 
.finish(); Ok(()) } @@ -302,9 +328,7 @@ pub struct FooterArgs<'a> { pub dictionaries: Option>>, pub recordBatches: Option>>, pub custom_metadata: Option< - flatbuffers::WIPOffset< - flatbuffers::Vector<'a, flatbuffers::ForwardsUOffset>>, - >, + flatbuffers::WIPOffset>>>, >, } impl<'a> Default for FooterArgs<'a> { @@ -319,6 +343,7 @@ impl<'a> Default for FooterArgs<'a> { } } } + pub struct FooterBuilder<'a: 'b, 'b> { fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a>, start_: flatbuffers::WIPOffset, @@ -326,39 +351,29 @@ pub struct FooterBuilder<'a: 'b, 'b> { impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { #[inline] pub fn add_version(&mut self, version: MetadataVersion) { - self.fbb_.push_slot::( - Footer::VT_VERSION, - version, - MetadataVersion::V1, - ); + self.fbb_ + .push_slot::(Footer::VT_VERSION, version, MetadataVersion::V1); } #[inline] pub fn add_schema(&mut self, schema: flatbuffers::WIPOffset>) { self.fbb_ - .push_slot_always::>( - Footer::VT_SCHEMA, - schema, - ); + .push_slot_always::>(Footer::VT_SCHEMA, schema); } #[inline] pub fn add_dictionaries( &mut self, dictionaries: flatbuffers::WIPOffset>, ) { - self.fbb_.push_slot_always::>( - Footer::VT_DICTIONARIES, - dictionaries, - ); + self.fbb_ + .push_slot_always::>(Footer::VT_DICTIONARIES, dictionaries); } #[inline] pub fn add_recordBatches( &mut self, recordBatches: flatbuffers::WIPOffset>, ) { - self.fbb_.push_slot_always::>( - Footer::VT_RECORDBATCHES, - recordBatches, - ); + self.fbb_ + .push_slot_always::>(Footer::VT_RECORDBATCHES, recordBatches); } #[inline] pub fn add_custom_metadata( @@ -373,9 +388,7 @@ impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { ); } #[inline] - pub fn new( - _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>, - ) -> FooterBuilder<'a, 'b> { + pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>) -> FooterBuilder<'a, 'b> { let start = _fbb.start_table(); FooterBuilder { fbb_: _fbb, @@ -389,8 +402,8 @@ impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { } } -impl std::fmt::Debug for Footer<'_> { 
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for Footer<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("Footer"); ds.field("version", &self.version()); ds.field("schema", &self.schema()); @@ -400,18 +413,6 @@ impl std::fmt::Debug for Footer<'_> { ds.finish() } } -#[inline] -#[deprecated(since = "2.0.0", note = "Deprecated in favor of `root_as...` methods.")] -pub fn get_root_as_footer<'a>(buf: &'a [u8]) -> Footer<'a> { - unsafe { flatbuffers::root_unchecked::>(buf) } -} - -#[inline] -#[deprecated(since = "2.0.0", note = "Deprecated in favor of `root_as...` methods.")] -pub fn get_size_prefixed_root_as_footer<'a>(buf: &'a [u8]) -> Footer<'a> { - unsafe { flatbuffers::size_prefixed_root_unchecked::>(buf) } -} - #[inline] /// Verifies that a buffer of bytes contains a `Footer` /// and returns it. @@ -429,9 +430,7 @@ pub fn root_as_footer(buf: &[u8]) -> Result Result { +pub fn size_prefixed_root_as_footer(buf: &[u8]) -> Result { flatbuffers::size_prefixed_root::