From 04113345f38eacc0f886f85923cbf739d92097bb Mon Sep 17 00:00:00 2001 From: Laramie Leavitt Date: Mon, 11 Nov 2024 14:38:57 -0800 Subject: [PATCH] Rework PartitionIndexTransformOverGrid to avoid recursion. PiperOrigin-RevId: 695486670 Change-Id: Id0ad384bba73227b0bc286796aa491ee4b443382 --- .../driver/neuroglancer_precomputed/driver.cc | 1 + tensorstore/driver/stack/driver.cc | 3 +- tensorstore/driver/zarr3/chunk_cache.cc | 1 + tensorstore/internal/BUILD | 27 +- tensorstore/internal/cache/chunk_cache.cc | 1 + tensorstore/internal/grid_partition.cc | 193 ----- tensorstore/internal/grid_partition.h | 34 - .../internal/grid_partition_iterator.cc | 233 ++++++ .../internal/grid_partition_iterator.h | 177 +++++ .../internal/grid_partition_iterator_test.cc | 699 +++++++++++++++++ tensorstore/internal/grid_partition_test.cc | 706 ------------------ 11 files changed, 1139 insertions(+), 936 deletions(-) create mode 100644 tensorstore/internal/grid_partition_iterator.cc create mode 100644 tensorstore/internal/grid_partition_iterator.h create mode 100644 tensorstore/internal/grid_partition_iterator_test.cc diff --git a/tensorstore/driver/neuroglancer_precomputed/driver.cc b/tensorstore/driver/neuroglancer_precomputed/driver.cc index bc25a88c4..1572e4c68 100644 --- a/tensorstore/driver/neuroglancer_precomputed/driver.cc +++ b/tensorstore/driver/neuroglancer_precomputed/driver.cc @@ -66,6 +66,7 @@ #include "tensorstore/internal/chunk_grid_specification.h" #include "tensorstore/internal/grid_chunk_key_ranges_base10.h" #include "tensorstore/internal/grid_partition.h" +#include "tensorstore/internal/grid_partition_iterator.h" #include "tensorstore/internal/grid_storage_statistics.h" #include "tensorstore/internal/json_binding/bindable.h" #include "tensorstore/internal/json_binding/json_binding.h" diff --git a/tensorstore/driver/stack/driver.cc b/tensorstore/driver/stack/driver.cc index d246c1c75..184b7f0df 100644 --- a/tensorstore/driver/stack/driver.cc +++ b/tensorstore/driver/stack/driver.cc @@ -14,10 +14,10 @@ #include "tensorstore/driver/driver.h" -#include #include #include +#include #include #include #include @@ -52,6 +52,7 @@ #include "tensorstore/internal/concurrency_resource.h" #include "tensorstore/internal/data_copy_concurrency_resource.h" #include "tensorstore/internal/grid_partition.h" +#include "tensorstore/internal/grid_partition_iterator.h" #include "tensorstore/internal/intrusive_ptr.h" #include "tensorstore/internal/irregular_grid.h" #include "tensorstore/internal/json_binding/json_binding.h" diff --git a/tensorstore/driver/zarr3/chunk_cache.cc b/tensorstore/driver/zarr3/chunk_cache.cc index e3f00e24b..9ef48832b 100644 --- a/tensorstore/driver/zarr3/chunk_cache.cc +++ b/tensorstore/driver/zarr3/chunk_cache.cc @@ -44,6 +44,7 @@ #include "tensorstore/internal/cache/kvs_backed_chunk_cache.h" #include "tensorstore/internal/chunk_grid_specification.h" #include "tensorstore/internal/grid_partition.h" +#include "tensorstore/internal/grid_partition_iterator.h" #include "tensorstore/internal/grid_storage_statistics.h" #include "tensorstore/internal/intrusive_ptr.h" #include "tensorstore/internal/lexicographical_grid_index_key.h" diff --git a/tensorstore/internal/BUILD b/tensorstore/internal/BUILD index a6736b539..fb76c94a8 100644 --- a/tensorstore/internal/BUILD +++ b/tensorstore/internal/BUILD @@ -491,13 +491,16 @@ tensorstore_cc_test( tensorstore_cc_library( name = "grid_partition", - srcs = ["grid_partition.cc"], + srcs = [ + "grid_partition.cc", + "grid_partition_iterator.cc", + ], hdrs = [ "grid_partition.h", + "grid_partition_iterator.h", ], deps = [ ":grid_partition_impl", - ":intrusive_ptr", "//tensorstore:box", "//tensorstore:index", "//tensorstore:index_interval", @@ -512,6 +515,7 @@ tensorstore_cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", ], ) @@ -1869,3 +1873,22 @@ tensorstore_cc_test( "@com_google_googletest//:gtest_main", ], ) + +tensorstore_cc_test( + name = "grid_partition_iterator_test", + srcs = ["grid_partition_iterator_test.cc"], + deps = [ + ":grid_partition", + ":grid_partition_impl", + ":regular_grid", + "//tensorstore:array", + "//tensorstore:index", + "//tensorstore:index_interval", + "//tensorstore/index_space:index_transform", + "//tensorstore/util:result", + "//tensorstore/util:span", + "//tensorstore/util:status", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorstore/internal/cache/chunk_cache.cc b/tensorstore/internal/cache/chunk_cache.cc index 05b879457..fe86774f1 100644 --- a/tensorstore/internal/cache/chunk_cache.cc +++ b/tensorstore/internal/cache/chunk_cache.cc @@ -41,6 +41,7 @@ #include "tensorstore/internal/chunk_grid_specification.h" #include "tensorstore/internal/elementwise_function.h" #include "tensorstore/internal/grid_partition.h" +#include "tensorstore/internal/grid_partition_iterator.h" #include "tensorstore/internal/intrusive_ptr.h" #include "tensorstore/internal/lock_collection.h" #include "tensorstore/internal/memory.h" diff --git a/tensorstore/internal/grid_partition.cc b/tensorstore/internal/grid_partition.cc index a10a920d8..e14cdfe83 100644 --- a/tensorstore/internal/grid_partition.cc +++ b/tensorstore/internal/grid_partition.cc @@ -19,10 +19,8 @@ #include #include #include -#include #include -#include "absl/container/fixed_array.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" @@ -34,10 +32,8 @@ #include "tensorstore/index_space/output_index_map.h" #include "tensorstore/index_space/output_index_method.h" #include "tensorstore/internal/grid_partition_impl.h" -#include "tensorstore/internal/intrusive_ptr.h" #include "tensorstore/rank.h" #include "tensorstore/util/dimension_set.h" -#include "tensorstore/util/iterate.h" #include "tensorstore/util/result.h" #include "tensorstore/util/span.h" #include "tensorstore/util/status.h" @@ -49,7 +45,6 @@ namespace tensorstore { namespace internal_grid_partition { namespace { - using IndexArraySet = IndexTransformGridPartition::IndexArraySet; using StridedSet = IndexTransformGridPartition::StridedSet; @@ -58,10 +53,6 @@ struct ConnectedSetIterateParameters { tensorstore::span grid_output_dimensions; OutputToGridCellFn output_to_grid_cell; IndexTransformView<> transform; - absl::FunctionRef grid_cell_indices, - IndexTransformView<> cell_transform)> - func; }; /// Sets the fixed grid cell indices for all grid dimensions that do not @@ -149,165 +140,6 @@ class StridedSetGridCellIterator { Index input_index_; }; -// Java-style iterator for computing the grid cells that intersect the original -// input domain for a given `StridedSet`. -// -/// For each grid cell, updates the `output_grid_cell_indices` for the given -/// index array. -class IndexArraySetIterator { - public: - IndexArraySetIterator(const IndexArraySet& index_array_set) - : grid_dimensions_(index_array_set.grid_dimensions), - grid_cell_indices_(index_array_set.grid_cell_indices), - partition_end_index_(index_array_set.num_partitions()), - partition_index_(0) {} - - void Reset() { partition_index_ = 0; } - - bool AtEnd() const { return partition_index_ == partition_end_index_; } - - Index Next(tensorstore::span output_grid_cell_indices) { - assert(!AtEnd()); - - // Assign the grid_cell_indices to the precomputed grid cell indices for - // this partition. - const Index grid_cell_indices_offset = - partition_index_ * grid_dimensions_.count(); - DimensionIndex grid_i = 0; - for (DimensionIndex grid_dim : grid_dimensions_.index_view()) { - output_grid_cell_indices[grid_dim] = - grid_cell_indices_[grid_cell_indices_offset + grid_i++]; - } - - return partition_index_++; - } - - private: - DimensionSet grid_dimensions_; - tensorstore::span grid_cell_indices_; - Index partition_end_index_; - Index partition_index_; -}; - -/// Helper class for iterating over the grid cell index vectors and computing -/// the `cell_transform` for each grid cell, based on precomputed -/// `IndexTransformGridPartition` data. -class ConnectedSetIterateHelper { - public: - explicit ConnectedSetIterateHelper(ConnectedSetIterateParameters params) - : params_(std::move(params)), - grid_cell_indices_(params_.grid_output_dimensions.size()), - cell_transform_(internal_grid_partition::InitializeCellTransform( - params_.info, params_.transform)) { - InitializeConstantGridCellIndices( - params_.transform, params_.grid_output_dimensions, - params_.output_to_grid_cell, grid_cell_indices_); - } - - /// Iterates over all grid cells and invokes the iteration callback function. - /// - /// This is implemented by recursively iterating over the partitions of each - /// connected set. - absl::Status Iterate() { return IterateOverIndexArraySets(0); } - - private: - /// Recursively iterates over the partial grid cells corresponding to the - /// index array connected sets, starting with `set_i`. - /// - /// For each grid cell, updates the `grid_cell_indices` for all grid - /// dimensions in the connected set and updates the `cell_transform` array - /// output index maps corresponding to each original input dimension in the - /// connected set. - /// - /// If there are no remaining index array connected sets over which to - /// recurse, starts recusing over the strided connected sets. - /// - /// Iteration is aborted if `InvokeCallback` returns an error. - /// - /// \param set_i The next index array connected set over which to iterate, in - /// the range `[0, info.index_array_sets().size()]`. - /// \returns The return value of the last recursively call. - absl::Status IterateOverIndexArraySets(DimensionIndex set_i) { - if (set_i == params_.info.index_array_sets().size()) { - return IterateOverStridedSets(0); - } - const IndexArraySet& index_array_set = - params_.info.index_array_sets()[set_i]; - IndexArraySetIterator iterator(index_array_set); - while (!iterator.AtEnd()) { - Index partition_i = iterator.Next(grid_cell_indices_); - UpdateCellTransformForIndexArraySetPartition( - index_array_set, set_i, partition_i, cell_transform_.get()); - TENSORSTORE_RETURN_IF_ERROR(IterateOverIndexArraySets(set_i + 1)); - } - return absl::OkStatus(); - } - - /// Recursively iterates over the partial grid cells corresponding to the - /// strided connected sets, starting with `set_i`. - /// - /// For each grid cell, updates the `grid_cell_indices` for all grid - /// dimensions in the connected set, and updates the input domain of the - /// corresponding synthetic input dimension of `cell_transform`. The output - /// index maps do not need to be updated. - /// - /// If there are no remaining strided sets over which to recurse, just invokes - /// the iteration callback function. - /// - /// Iteration is aborted if `InvokeCallback` returns an error. - /// - /// \param set_i The next strided connected set over which to iterate, in the - /// range `[0, info.strided_sets().size()]`. - /// \returns The return value of the last recursive call, or the last call to - /// `InvokeCallback`. - absl::Status IterateOverStridedSets(DimensionIndex set_i) { - if (set_i == params_.info.strided_sets().size()) return InvokeCallback(); - StridedSetGridCellIterator iterator( - params_.transform, params_.grid_output_dimensions, - params_.output_to_grid_cell, params_.info.strided_sets()[set_i]); - const DimensionIndex cell_input_dim = - set_i + params_.info.index_array_sets().size(); - while (!iterator.AtEnd()) { - auto restricted_domain = iterator.Next(grid_cell_indices_); - // Set the input domain for `cell_input_dim` for the duration of the - // subsequent recursive call to IterateOverStridedSets. - cell_transform_->input_origin()[cell_input_dim] = - restricted_domain.inclusive_min(); - cell_transform_->input_shape()[cell_input_dim] = restricted_domain.size(); - // Recursively iterate over the next strided connected set. - TENSORSTORE_RETURN_IF_ERROR(IterateOverStridedSets(set_i + 1)); - } - return absl::OkStatus(); - } - - /// Calls the iteration callback function. - /// - /// If an error `absl::Status` is returned, iteration should stop. - /// - /// \error Any error returned by the iteration callback function. - absl::Status InvokeCallback() { - internal_index_space::DebugCheckInvariants(cell_transform_.get()); - auto status = params_.func( - grid_cell_indices_, - TransformAccess::Make>(cell_transform_.get())); - // If `func` created and is still holding a reference to `cell_transform_`, - // we need to make a copy before modifying it. - cell_transform_ = MutableRep(std::move(cell_transform_)); - return status; - } - - ConnectedSetIterateParameters params_; - - // Current grid cell index vector `h` as defined in grid_partition.h, modified - // in place while iterating over all index vectors in `H`. - absl::FixedArray grid_cell_indices_; - - // This stores the current value of `cell_transform[h]`, as defined in - // grid_partition.h, for `h = grid_cell_indices_`. This is modified in - // place while iterating over all values for grid_cell_indices_. - internal_index_space::TransformRep::Ptr<> cell_transform_; -}; - bool GetStridedGridCellRanges( IndexTransformView<> transform, OutputToGridCellFn output_to_grid_cell, DimensionIndex grid_dim, DimensionIndex output_dim, @@ -511,31 +343,6 @@ class GetGridCellRangesIterateHelper { } // namespace } // namespace internal_grid_partition -namespace internal { - -absl::Status PartitionIndexTransformOverGrid( - tensorstore::span grid_output_dimensions, - OutputToGridCellFn output_to_grid_cell, IndexTransformView<> transform, - absl::FunctionRef< - absl::Status(tensorstore::span grid_cell_indices, - IndexTransformView<> cell_transform)> - func) { - internal_grid_partition::IndexTransformGridPartition partition_info; - auto status = internal_grid_partition::PrePartitionIndexTransformOverGrid( - transform, grid_output_dimensions, output_to_grid_cell, partition_info); - - if (!status.ok()) return status; - return internal_grid_partition::ConnectedSetIterateHelper( - {/*.info=*/partition_info, - /*.grid_output_dimensions=*/grid_output_dimensions, - /*.output_to_grid_cell=*/output_to_grid_cell, - /*.transform=*/transform, - /*.func=*/std::move(func)}) - .Iterate(); -} - -} // namespace internal - namespace internal_grid_partition { absl::Status GetGridCellRanges( const IndexTransformGridPartition& grid_partition, diff --git a/tensorstore/internal/grid_partition.h b/tensorstore/internal/grid_partition.h index 26438a6e1..e158c6105 100644 --- a/tensorstore/internal/grid_partition.h +++ b/tensorstore/internal/grid_partition.h @@ -113,40 +113,6 @@ namespace internal { using OutputToGridCellFn = absl::FunctionRef; -/// Partitions the input domain of a given `transform` from an input space -/// "full" to an output space "output" based on the grid (potentially irregular) -/// specified by `output_to_grid_cell`, which maps from a given dimension and -/// output_index to a grid cell and optional cell bounds. -/// -/// For each grid cell index vector `h` in `H`, calls -/// `func(h, cell_transform[h])`. -/// -/// To partition over a regular grid, `output_to_grid_cell` can be -/// internal_grid_partition::RegularGridRef. -/// -/// \param grid_output_dimensions The sequence of dimensions of the index space -/// "output" corresponding to the grid by which to partition "full". -/// \param output_to_grid_cell Function returning, for the provided grid -/// dimension, the cell index corresponding to output_index, optionally -/// filling the bounds for the cell. -/// \param transform The index transform from "full" to "output". Must be -/// valid. -/// \param func The function to be called for each partition. May return an -/// error `absl::Status` to abort the iteration. -/// \returns `absl::Status()` on success, or the last error returned by `func`. -/// \error `absl::StatusCode::kInvalidArgument` if any input dimension of -/// `transform` has an unbounded domain. -/// \error `absl::StatusCode::kInvalidArgument` if integer overflow occurs. -/// \error `absl::StatusCode::kOutOfRange` if an index array contains an -/// out-of-bounds index. -absl::Status PartitionIndexTransformOverGrid( - tensorstore::span grid_output_dimensions, - OutputToGridCellFn output_to_grid_cell, IndexTransformView<> transform, - absl::FunctionRef< - absl::Status(tensorstore::span grid_cell_indices, - IndexTransformView<> cell_transform)> - func); - absl::Status GetGridCellRanges( tensorstore::span grid_output_dimensions, BoxView<> grid_bounds, OutputToGridCellFn output_to_grid_cell, diff --git a/tensorstore/internal/grid_partition_iterator.cc b/tensorstore/internal/grid_partition_iterator.cc new file mode 100644 index 000000000..798f06351 --- /dev/null +++ b/tensorstore/internal/grid_partition_iterator.cc @@ -0,0 +1,233 @@ +// Copyright 2024 The TensorStore Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorstore/internal/grid_partition_iterator.h" + +#include + +#include + +#include "absl/container/fixed_array.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "tensorstore/index.h" +#include "tensorstore/index_interval.h" +#include "tensorstore/index_space/index_transform.h" +#include "tensorstore/index_space/internal/transform_rep.h" +#include "tensorstore/index_space/output_index_map.h" +#include "tensorstore/index_space/output_index_method.h" +#include "tensorstore/internal/grid_partition_impl.h" +#include "tensorstore/util/span.h" +#include "tensorstore/util/status.h" + +namespace tensorstore { +namespace internal_grid_partition { + +using IndexArraySet = IndexTransformGridPartition::IndexArraySet; +using StridedSet = IndexTransformGridPartition::StridedSet; + +PartitionIndexTransformIterator::PartitionIndexTransformIterator( + internal_grid_partition::IndexTransformGridPartition&& partition_info, + tensorstore::span grid_output_dimensions, + OutputToGridCellFn output_to_grid_cell, IndexTransformView<> transform) + : partition_info_(std::move(partition_info)), + grid_output_dimensions_(grid_output_dimensions.begin(), + grid_output_dimensions.end()), + output_to_grid_cell_(std::move(output_to_grid_cell)), + transform_(std::move(transform)), + at_end_(false), + cell_transform_(internal_grid_partition::InitializeCellTransform( + partition_info_, transform_)), + output_grid_cell_indices_(grid_output_dimensions_.size()), + position_(rank()), + upper_bound_(rank()), + strided_next_position_(partition_info_.strided_sets().size()) { + // Initialize the output_grid_cell_indices for the constant outputs. + for (DimensionIndex grid_dim = 0; grid_dim < grid_output_dimensions_.size(); + ++grid_dim) { + const DimensionIndex output_dim = grid_output_dimensions_[grid_dim]; + const OutputIndexMapRef<> map = transform_.output_index_map(output_dim); + if (map.method() != OutputIndexMethod::constant) continue; + output_grid_cell_indices_[grid_dim] = + output_to_grid_cell_(grid_dim, map.offset(), nullptr); + } + // Initialize the iteration positions. + for (size_t i = 0; i < rank(); ++i) { + if (i < partition_info_.index_array_sets().size()) { + ResetIndexArraySet(i); + } else { + ResetStridedSet(i); + } + at_end_ = at_end_ || (position_[i] == upper_bound_[i]); + } + if (!at_end_) { + for (size_t i = 0; i < rank(); ++i) { + if (i < partition_info_.index_array_sets().size()) { + ApplyIndexArraySet(i); + } else { + ApplyStridedSet(i); + } + } + } +} + +void PartitionIndexTransformIterator::Advance() { + ABSL_DCHECK(!at_end_); + // If callers of `cell_transform()` still hold a reference, then make a copy + // before modifying it. + cell_transform_ = MutableRep(std::move(cell_transform_)); + + // Advance to the next iterator position; this is in c-order and + // will update strided sets before index array sets. + size_t i = rank(); + while (i--) { + // Advance the iteration position for the set at index `i`. + if (i < partition_info_.index_array_sets().size()) { + position_[i] = AdvanceIndexArraySet(i); + } else { + position_[i] = AdvanceStridedSet(i); + } + if (position_[i] == upper_bound_[i]) { + if (i == 0) break; + // Reset the iteration position for the set at index `i` + // and advance to the next set. + if (i < partition_info_.index_array_sets().size()) { + ResetIndexArraySet(i); + } else { + ResetStridedSet(i); + } + continue; + } + // Update cell transforms for all updated sets. + for (; i < rank(); ++i) { + if (i < partition_info_.index_array_sets().size()) { + ApplyIndexArraySet(i); + } else { + ApplyStridedSet(i); + } + } + return; + } + // Iteration has completed. + at_end_ = true; +} + +void PartitionIndexTransformIterator::ResetIndexArraySet(size_t i) { + ABSL_CHECK_LT(i, partition_info_.index_array_sets().size()); + const IndexArraySet& index_array_set = partition_info_.index_array_sets()[i]; + position_[i] = 0; + upper_bound_[i] = index_array_set.num_partitions(); +} + +void PartitionIndexTransformIterator::ApplyIndexArraySet(size_t i) { + ABSL_CHECK_LT(position_[i], upper_bound_[i]); + ABSL_CHECK_LT(i, partition_info_.index_array_sets().size()); + const IndexArraySet& index_array_set = partition_info_.index_array_sets()[i]; + + // Assign the grid_cell_indices to the precomputed grid cell indices for + // this partition. + const Index grid_cell_indices_offset = + (position_[i]) * index_array_set.grid_dimensions.count(); + + DimensionIndex grid_i = 0; + for (DimensionIndex grid_dim : index_array_set.grid_dimensions.index_view()) { + output_grid_cell_indices_[grid_dim] = + index_array_set.grid_cell_indices[grid_cell_indices_offset + grid_i++]; + } + // Updates the cell_transform for the current index array. + UpdateCellTransformForIndexArraySetPartition(index_array_set, i, position_[i], + cell_transform_.get()); +} + +void PartitionIndexTransformIterator::ResetStridedSet(size_t i) { + ABSL_DCHECK_GE(i, partition_info_.index_array_sets().size()); + auto set_i = i - partition_info_.index_array_sets().size(); + ABSL_DCHECK_LT(set_i, partition_info_.strided_sets().size()); + + const auto& strided_set = partition_info_.strided_sets()[set_i]; + const IndexInterval domain = + transform_.input_domain()[strided_set.input_dimension]; + position_[i] = domain.inclusive_min(); + upper_bound_[i] = domain.exclusive_max(); + strided_next_position_[set_i] = domain.inclusive_min(); +} + +void PartitionIndexTransformIterator::ApplyStridedSet(size_t i) { + ABSL_DCHECK_LT(position_[i], upper_bound_[i]); + ABSL_DCHECK_GE(i, partition_info_.index_array_sets().size()); + auto set_i = i - partition_info_.index_array_sets().size(); + ABSL_DCHECK_LT(set_i, partition_info_.strided_sets().size()); + + const StridedSet& strided_set = partition_info_.strided_sets()[set_i]; + + IndexInterval restricted_domain = + IndexInterval::UncheckedHalfOpen(position_[i], upper_bound_[i]); + + // For each grid dimension in the connected set, compute the grid cell + // index corresponding to `input_index`, and constrain `restricted_domain` + // to the range of this grid cell. + for (const DimensionIndex grid_dim : + strided_set.grid_dimensions.index_view()) { + const DimensionIndex output_dim = grid_output_dimensions_[grid_dim]; + const OutputIndexMapRef<> map = transform_.output_index_map(output_dim); + IndexInterval cell_range; + output_grid_cell_indices_[grid_dim] = output_to_grid_cell_( + grid_dim, position_[i] * map.stride() + map.offset(), &cell_range); + // The check in PrePartitionIndexTransformOverGrid guarantees + // that GetAffineTransformDomain is successful. + const IndexInterval cell_domain = + GetAffineTransformDomain(cell_range, map.offset(), map.stride()) + .value(); + restricted_domain = Intersect(restricted_domain, cell_domain); + } + + ABSL_DCHECK(!restricted_domain.empty()); + + // Updates the cell transform input domain of `i`. + cell_transform_->input_origin()[i] = restricted_domain.inclusive_min(); + cell_transform_->input_shape()[i] = restricted_domain.size(); + + strided_next_position_[set_i] = restricted_domain.exclusive_max(); +} + +} // namespace internal_grid_partition +namespace internal { + +absl::Status PartitionIndexTransformOverGrid( + tensorstore::span grid_output_dimensions, + internal_grid_partition::OutputToGridCellFn output_to_grid_cell, + IndexTransformView<> transform, + absl::FunctionRef< + absl::Status(tensorstore::span grid_cell_indices, + IndexTransformView<> cell_transform)> + func) { + internal_grid_partition::IndexTransformGridPartition partition_info; + auto status = internal_grid_partition::PrePartitionIndexTransformOverGrid( + transform, grid_output_dimensions, output_to_grid_cell, partition_info); + + internal_grid_partition::PartitionIndexTransformIterator iterator( + std::move(partition_info), grid_output_dimensions, output_to_grid_cell, + transform); + while (!iterator.AtEnd()) { + TENSORSTORE_RETURN_IF_ERROR( + func(iterator.output_grid_cell_indices(), iterator.cell_transform())); + iterator.Advance(); + } + return absl::OkStatus(); +} + +} // namespace internal +} // namespace tensorstore diff --git a/tensorstore/internal/grid_partition_iterator.h b/tensorstore/internal/grid_partition_iterator.h new file mode 100644 index 000000000..666738171 --- /dev/null +++ b/tensorstore/internal/grid_partition_iterator.h @@ -0,0 +1,177 @@ +// Copyright 2024 The TensorStore Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORSTORE_INTERNAL_GRID_PARTITION_ITERATOR_H_ +#define TENSORSTORE_INTERNAL_GRID_PARTITION_ITERATOR_H_ + +#include + +#include + +#include "absl/container/fixed_array.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "tensorstore/index.h" +#include "tensorstore/index_interval.h" +#include "tensorstore/index_space/index_transform.h" +#include "tensorstore/index_space/internal/transform_rep.h" +#include "tensorstore/internal/grid_partition_impl.h" +#include "tensorstore/util/iterate.h" +#include "tensorstore/util/span.h" + +namespace tensorstore { +namespace internal_grid_partition { + +/// For a given DimensionIndex dimension, returns the grid cell index +/// corresponding to the output_index, optionally filling the bounds for the +/// cell. +/// Implemented by `RegularGrid` and `IrregularGrid`, for example. +using OutputToGridCellFn = absl::FunctionRef; + +/// Iterator for constructing the `cell_transform` for each grid cell. +/// Requires IndexTransformGridPartition to be precomputed. +/// +/// After construction, the iterator is in an initial state where AtEnd may +/// be true, otherwise `output_grid_cell_indices()` and `cell_transform()` +/// are valid. +/// +/// To advance to the next grid cell, call `Advance()`. +class PartitionIndexTransformIterator { + public: + PartitionIndexTransformIterator( + internal_grid_partition::IndexTransformGridPartition&& partition_info, + tensorstore::span grid_output_dimensions, + OutputToGridCellFn output_to_grid_cell, IndexTransformView<> transform); + + // Indices to the current grid cell. + tensorstore::span output_grid_cell_indices() const { + return output_grid_cell_indices_; + } + + /// View of the current cell transform. + IndexTransformView<> cell_transform() { + return internal_index_space::TransformAccess::Make>( + cell_transform_.get()); + } + + /// Indicates whether iteration has completed. + /// When false, both cell_transform() and output_grid_cell_indices() are + /// valid. + bool AtEnd() const { return at_end_; } + + // Advance the iterator. + void Advance(); + + private: + size_t rank() const { + return partition_info_.index_array_sets_.size() + + partition_info_.strided_sets_.size(); + } + + // Advance the iteration position for the index array set at index `i`. + Index AdvanceIndexArraySet(size_t i) { return position_[i] + 1; } + + // Reset the iteration position for the index array set at index `i`. + void ResetIndexArraySet(size_t i); + + // For grid cell, `i`, updates the `output_grid_cell_indices` for the given + // index array as well as the cell_transform. + void ApplyIndexArraySet(size_t i); + + // Advance the iteration position for the strided set at index `i`. + // Assumes that ApplyStridedSet(i) was previously invoked. + Index AdvanceStridedSet(size_t i) { + ABSL_DCHECK_GE(i, partition_info_.index_array_sets().size()); + auto set_i = i - partition_info_.index_array_sets().size(); + ABSL_DCHECK_LT(set_i, partition_info_.strided_sets().size()); + return strided_next_position_[set_i]; + } + + // Reset the iteration position for the strided set at index `i`. + void ResetStridedSet(size_t i); + + // For grid cell, `i`, updates the `output_grid_cell_indices` and the + // `cell_transform` for the associated strided set. + void ApplyStridedSet(size_t i); + + internal_grid_partition::IndexTransformGridPartition partition_info_; + absl::FixedArray + grid_output_dimensions_; + + OutputToGridCellFn output_to_grid_cell_; + IndexTransformView<> transform_; + bool at_end_; + + // This stores the current value of `cell_transform[h]`, as defined in + // grid_partition.h, for `h = grid_cell_indices_`. This is modified in + // place while iterating over all values for grid_cell_indices_. + internal_index_space::TransformRep::Ptr<> cell_transform_; + + // Current grid cell index vector `h` as defined in grid_partition.h, modified + // in place while iterating over all index vectors in `H`. + absl::FixedArray output_grid_cell_indices_; + + // Iteration position for each connected set. + absl::FixedArray position_; + absl::FixedArray upper_bound_; + + // The next start position for each strided set. + absl::FixedArray strided_next_position_; +}; + +} // namespace internal_grid_partition +namespace internal { + +/// Partitions the input domain of a given `transform` from an input space +/// "full" to an output space "output" based on the grid (potentially irregular) +/// specified by `output_to_grid_cell`, which maps from a given dimension and +/// output_index to a grid cell and optional cell bounds. +/// +/// For each grid cell index vector `h` in `H`, calls +/// `func(h, cell_transform[h])`. +/// +/// To partition over a regular grid, `output_to_grid_cell` can be +/// internal_grid_partition::RegularGridRef. +/// +/// \param grid_output_dimensions The sequence of dimensions of the index space +/// "output" corresponding to the grid by which to partition "full". +/// \param output_to_grid_cell Function returning, for the provided grid +/// dimension, the cell index corresponding to output_index, optionally +/// filling the bounds for the cell. +/// \param transform The index transform from "full" to "output". Must be +/// valid. +/// \param func The function to be called for each partition. May return an +/// error `absl::Status` to abort the iteration. +/// \returns `absl::Status()` on success, or the last error returned by `func`. +/// \error `absl::StatusCode::kInvalidArgument` if any input dimension of +/// `transform` has an unbounded domain. +/// \error `absl::StatusCode::kInvalidArgument` if integer overflow occurs. +/// \error `absl::StatusCode::kOutOfRange` if an index array contains an +/// out-of-bounds index. +absl::Status PartitionIndexTransformOverGrid( + tensorstore::span grid_output_dimensions, + internal_grid_partition::OutputToGridCellFn output_to_grid_cell, + IndexTransformView<> transform, + absl::FunctionRef< + absl::Status(tensorstore::span grid_cell_indices, + IndexTransformView<> cell_transform)> + func); + +} // namespace internal +} // namespace tensorstore + +#endif // TENSORSTORE_INTERNAL_GRID_PARTITION_ITERATOR_H_ diff --git a/tensorstore/internal/grid_partition_iterator_test.cc b/tensorstore/internal/grid_partition_iterator_test.cc new file mode 100644 index 000000000..92dbb2236 --- /dev/null +++ b/tensorstore/internal/grid_partition_iterator_test.cc @@ -0,0 +1,699 @@ +// Copyright 2020 The TensorStore Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorstore/internal/grid_partition_iterator.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "tensorstore/array.h" +#include "tensorstore/index.h" +#include "tensorstore/index_interval.h" +#include "tensorstore/index_space/index_transform.h" +#include "tensorstore/index_space/index_transform_builder.h" +#include "tensorstore/internal/grid_partition_impl.h" +#include "tensorstore/internal/regular_grid.h" +#include "tensorstore/util/result.h" +#include "tensorstore/util/span.h" +#include "tensorstore/util/status.h" + +namespace { +using ::tensorstore::DimensionIndex; +using ::tensorstore::Index; +using ::tensorstore::IndexInterval; +using ::tensorstore::IndexTransform; +using ::tensorstore::IndexTransformBuilder; +using ::tensorstore::IndexTransformView; +using ::tensorstore::MakeArray; +using ::tensorstore::internal_grid_partition::IndexTransformGridPartition; +using ::tensorstore::internal_grid_partition::OutputToGridCellFn; +using ::tensorstore::internal_grid_partition::PartitionIndexTransformIterator; +using ::tensorstore::internal_grid_partition:: + PrePartitionIndexTransformOverGrid; +using ::tensorstore::internal_grid_partition::RegularGridRef; +using ::testing::ElementsAre; + +/// Representation of a partition, specifically the arguments supplied to the +/// callback passed to `PartitionIndexTransformOverRegularGrid`. This is a +/// pair of: +/// +/// 0. Grid cell index vector +/// 1. cell_transform transform +using R = std::pair, IndexTransform<>>; + +/// Returns the list of partitions generated by +/// `PartitionIndexTransformOverRegularGrid` when called with the specified +/// arguments. +/// +/// \param grid_output_dimensions The sequence of output dimensions of the index +/// space "output" corresponding to the grid. +/// \param grid_cell_shape Array of length `grid_output_dimensions.size()` +/// specifying the cell of a grid cell along each grid dimension. +/// \param transform A transform from the "full" input space to the "output" +/// index space. +/// \returns The list of partitions. +std::vector GetPartitions( + const std::vector& grid_output_dimensions, + const std::vector& grid_cell_shape, IndexTransformView<> transform) { + std::vector results; + + IndexTransformGridPartition info; + RegularGridRef grid{grid_cell_shape}; + TENSORSTORE_CHECK_OK(PrePartitionIndexTransformOverGrid( + transform, grid_output_dimensions, grid, info)); + TENSORSTORE_CHECK_OK(tensorstore::internal::PartitionIndexTransformOverGrid( + grid_output_dimensions, grid, transform, + [&](tensorstore::span grid_cell_indices, + IndexTransformView<> cell_transform) { + auto cell_transform_direct = info.GetCellTransform( + transform, grid_cell_indices, grid_output_dimensions, + [&](DimensionIndex dim, Index cell_index) { + return grid.GetCellOutputInterval(dim, cell_index); + }); + EXPECT_EQ(cell_transform_direct, cell_transform); + results.emplace_back(std::vector(grid_cell_indices.begin(), + grid_cell_indices.end()), + IndexTransform<>(cell_transform)); + return absl::OkStatus(); + })); + return results; +} + +std::vector GetPartitionsManual( + const std::vector& grid_output_dimensions, + const std::vector& grid_cell_shape, IndexTransformView<> transform) { + std::vector results; + + IndexTransformGridPartition info; + + RegularGridRef grid{grid_cell_shape}; + TENSORSTORE_CHECK_OK(PrePartitionIndexTransformOverGrid( + transform, grid_output_dimensions, grid, info)); + + PartitionIndexTransformIterator iterator( + std::move(info), grid_output_dimensions, grid, std::move(transform)); + while (!iterator.AtEnd()) { + results.emplace_back( + std::vector(iterator.output_grid_cell_indices().begin(), + iterator.output_grid_cell_indices().end()), + IndexTransform<>(iterator.cell_transform())); + iterator.Advance(); + } + + return results; +} + +// Tests that an empty input shape is handled correctly. +TEST(PartitionIndexTransformOverRegularGrid, EmptyDimension) { + const auto results = GetPartitions({0, 1}, {3, 3}, + IndexTransformBuilder<>(2, 2) + .input_shape({0, 0}) + .output_constant(0, 3) + .output_single_input_dimension(1, 1) + .Finalize() + .value()); + + EXPECT_THAT(results, ::testing::IsEmpty()); // +} + +// Tests that a one-dimensional transform with a constant output map is +// partitioned into 1 part. +TEST(PartitionIndexTransformOverRegularGrid, ConstantOneDimensional) { + const auto results = GetPartitions({0}, {2}, + IndexTransformBuilder<>(1, 1) + .input_origin({2}) + .input_shape({4}) + .output_constant(0, 3) + .Finalize() + .value()); + // Input index: 2 3 4 5 + // Output index: 3 + // Grid index: 1 + // = Output index / 2 + EXPECT_THAT( // + results, // + ElementsAre( // + R{{1}, + IndexTransformBuilder<>(1, 1) + .input_origin({2}) + .input_shape({4}) + .output_single_input_dimension(0, 0) + .Finalize() + .value()})); +} + +// Tests that a two-dimensional transform with constant output maps is +// partitioned into 1 part. +TEST(PartitionIndexTransformOverRegularGrid, ConstantTwoDimensional) { + const auto results = GetPartitions({0, 1}, {2, 3}, + IndexTransformBuilder<>(2, 2) + .input_origin({2, 3}) + .input_shape({4, 5}) + .output_constant(0, 3) + .output_constant(1, 7) + .Finalize() + .value()); + // Input index 0: 2 3 4 5 + // Input index 1: 3 4 5 6 7 + + // Output index 0: 3 + // Grid index 0: 1 + // = Output index / 2 + // + // Output index 1: 7 + // Grid index 0: 2 + // = Output index / 3 + + EXPECT_THAT( // + results, // + ElementsAre( // + R{{1, 2}, + IndexTransformBuilder<>(2, 2) + .input_origin({2, 3}) + .input_shape({4, 5}) + .output_identity_transform() + .Finalize() + .value()})); +} + +// Tests that a one-dimensional identity transform over the domain `[-4,1]` with +// a cell size of `2` is partitioned into 3 parts, with the domains: `[-4,-3]`, +// `[-2,-1]`, and `[0,0]`. +TEST(PartitionIndexTransformOverRegularGrid, OneDimensionalUnitStride) { + const auto results = GetPartitions({0}, {2}, + IndexTransformBuilder<>(1, 1) + .input_origin({-4}) + .input_shape({5}) + .output_identity_transform() + .Finalize() + .value()); + // Input index: -4 -3 -2 -1 0 + // Output index: -4 -3 -2 -1 0 + // = Input index + // Grid index: -2 -2 -1 -1 0 + // = Output index / 2 + EXPECT_THAT( // + results, // + ElementsAre( // + R{{-2}, + IndexTransformBuilder<>(1, 1) + .input_origin({-4}) + .input_shape({2}) + .output_identity_transform() + .Finalize() + .value()}, + R{{-1}, + IndexTransformBuilder<>(1, 1) + .input_origin({-2}) + .input_shape({2}) + .output_identity_transform() + .Finalize() + .value()}, + R{{0}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({1}) + .output_identity_transform() + .Finalize() + .value()})); +} + +// Tests that a 2-d identity-mapped input domain over `[0,30)*[0,30)` with a +// grid size of `{20,10}` is correctly partitioned in 6 parts, with domains: +// `[0,20)*[0,10)`, `[0,20)*[10,20)`, `[0,20)*[20,30)`, `[20,30)*[0,10)`, +// `[20,30)*[10,20)`, `[20,30)*[20,30)`, +TEST(PartitionIndexTransformOverRegularGrid, TwoDimensionalIdentity) { + const auto results = GetPartitions({0, 1}, {20, 10}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, 0}) + .input_shape({30, 30}) + .output_identity_transform() + .Finalize() + .value()); + EXPECT_THAT( // + results, // + ElementsAre( // + R{{0, 0}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, 0}) + .input_shape({20, 10}) + .output_identity_transform() + .Finalize() + .value()}, + R{{0, 1}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, 10}) + .input_shape({20, 10}) + .output_identity_transform() + .Finalize() + .value()}, + R{{0, 2}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, 20}) + .input_shape({20, 10}) + .output_identity_transform() + .Finalize() + .value()}, + R{{1, 0}, + IndexTransformBuilder<>(2, 2) + .input_origin({20, 0}) + .input_shape({10, 10}) + .output_identity_transform() + .Finalize() + .value()}, + R{{1, 1}, + IndexTransformBuilder<>(2, 2) + .input_origin({20, 10}) + .input_shape({10, 10}) + .output_identity_transform() + .Finalize() + .value()}, + R{{1, 2}, + IndexTransformBuilder<>(2, 2) + .input_origin({20, 20}) + .input_shape({10, 10}) + .output_identity_transform() + .Finalize() + .value()})); +} + +// Same as previous test, but with non-unit stride and a cell size of 10. The +// input domain `[-4,1]` is partitioned into 2 parts, with the domains `[-4,-2]` +// and `[-1,1]`. +TEST(PartitionIndexTransformOverRegularGrid, SingleStridedDimension) { + const auto results = + GetPartitions({0}, {10}, + IndexTransformBuilder<>(1, 1) + .input_origin({-4}) + .input_shape({6}) + .output_single_input_dimension(0, 5, 3, 0) + .Finalize() + .value()); + // Input index: -4 -3 -2 -1 0 1 + // Output index: -7 -4 -1 2 5 8 + // = 5 + 3 * Input index + // Grid index: -1 -1 -1 0 0 0 + // = Output index / 10 + EXPECT_THAT( // + results, // + ElementsAre( // + R{{-1}, + IndexTransformBuilder<>(1, 1) + .input_origin({-4}) + .input_shape({3}) + .output_identity_transform() + .Finalize() + .value()}, + R{{0}, + IndexTransformBuilder<>(1, 1) + .input_origin({-1}) + .input_shape({3}) + .output_identity_transform() + .Finalize() + .value()})); +} + +TEST(PartitionIndexTransformOverRegularGrid, SingleStridedDimensionManual) { + const auto results = + GetPartitionsManual({0}, {10}, + IndexTransformBuilder<>(1, 1) + .input_origin({-4}) + .input_shape({6}) + .output_single_input_dimension(0, 5, 3, 0) + .Finalize() + .value()); + // Input index: -4 -3 -2 -1 0 1 + // Output index: -7 -4 -1 2 5 8 + // = 5 + 3 * Input index + // Grid index: -1 -1 -1 0 0 0 + // = Output index / 10 + EXPECT_THAT( // + results, // + ElementsAre( // + R{{-1}, + IndexTransformBuilder<>(1, 1) + .input_origin({-4}) + .input_shape({3}) + .output_identity_transform() + .Finalize() + .value()}, + R{{0}, + IndexTransformBuilder<>(1, 1) + .input_origin({-1}) + .input_shape({3}) + .output_identity_transform() + .Finalize() + .value()})); +} + +// Tests that a diagonal transform that maps two different gridded output +// dimension to a single input dimension, where a different cell size is used +// for the two grid dimensions, is partitioned into 3 parts, with domains +// `[-4,-2]`, `[-1,-1]`, and `[0,1]`. +TEST(PartitionIndexTransformOverRegularGrid, DiagonalStridedDimensions) { + const auto results = + GetPartitions({0, 1}, {10, 8}, + IndexTransformBuilder<>(1, 2) + .input_origin({-4}) + .input_shape({6}) + .output_single_input_dimension(0, 5, 3, 0) + .output_single_input_dimension(1, 7, -2, 0) + .Finalize() + .value()); + // Input index: -4 -3 -2 -1 0 1 + // + // Output index 0: -7 -4 -1 2 5 8 + // = 5 + 3 * Input index 0 + // Grid index 0: -1 -1 -1 0 0 0 + // = Output index 0 / 10 + // + // Output index 1: 15 13 11 9 7 5 + // = 7 - 2 * Input index 1 + // Grid index 0: 1 1 1 1 0 0 + // = Output index 1 / 8 + EXPECT_THAT( // + results, // + ElementsAre( // + R{{-1, 1}, + IndexTransformBuilder<>(1, 1) + .input_origin({-4}) + .input_shape({3}) + .output_identity_transform() + .Finalize() + .value()}, + R{{0, 1}, + IndexTransformBuilder<>(1, 1) + .input_origin({-1}) + .input_shape({1}) + .output_identity_transform() + .Finalize() + .value()}, + R{{0, 0}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({2}) + .output_identity_transform() + .Finalize() + .value()})); +} + +// Tests that a transform that maps via an index array the domain `[100,107]` -> +// `[1,8]`, when partitioned using a grid cell size of 3, results in 3 parts +// with domains: {100, 101}, {102, 103, 104}, and {105, 106, 107}. +TEST(PartitionIndexTransformOverRegularGrid, SingleIndexArrayDimension) { + const auto results = + GetPartitions({0}, {3}, + IndexTransformBuilder<>(1, 1) + .input_origin({100}) + .input_shape({8}) + .output_index_array( + 0, 0, 1, MakeArray({1, 2, 3, 4, 5, 6, 7, 8})) + .Finalize() + .value()); + // Input index: 100 101 102 103 104 105 106 107 + // Index array : 1 2 3 4 5 6 7 8 + // Output index: 1 2 3 4 5 6 7 8 + // Grid index: 0 0 1 1 1 2 2 2 + EXPECT_THAT( // + results, // + ElementsAre( + R{{0}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({2}) + .output_index_array(0, 0, 1, MakeArray({100, 101})) + .Finalize() + .value()}, + R{{1}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({3}) + .output_index_array(0, 0, 1, MakeArray({102, 103, 104})) + .Finalize() + .value()}, + R{{2}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({3}) + .output_index_array(0, 0, 1, MakeArray({105, 106, 107})) + .Finalize() + .value()})); +} + +// Tests that a transform with a single gridded output dimension with an `array` +// map from a single input dimension with non-unit stride is correctly +// partitioned. +TEST(PartitionIndexTransformOverRegularGrid, SingleIndexArrayDimensionStrided) { + const auto results = GetPartitions( + {0}, {10}, + IndexTransformBuilder<>(1, 1) + .input_origin({100}) + .input_shape({6}) + .output_index_array(0, 5, 3, MakeArray({10, 3, 4, -5, -6, 11})) + .Finalize() + .value()); + // Input index: 100 101 102 103 104 105 + // Index array: 10 3 4 -5 -6 11 + // Output index: 35 14 17 -10 -13 38 + // = 5 + 3 * Index array + // Grid index: 3 1 1 -1 -2 3 + // = Output index / 3 + EXPECT_THAT( // + results, // + ElementsAre( // + R{{-2}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({1}) + .output_index_array(0, 0, 1, MakeArray({104})) + .Finalize() + .value()}, + R{{-1}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({1}) + .output_index_array(0, 0, 1, MakeArray({103})) + .Finalize() + .value()}, + R{{1}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({2}) + .output_index_array(0, 0, 1, MakeArray({101, 102})) + .Finalize() + .value()}, + R{{3}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({2}) + .output_index_array(0, 0, 1, MakeArray({100, 105})) + .Finalize() + .value()})); +} + +// Tests that an index transform with two gridded output dimensions that are +// mapped using an `array` output index map from a single input dimension, which +// leads to a single connected set, is correctly handled. +TEST(PartitionIndexTransformOverRegularGrid, TwoIndexArrayDimensions) { + const auto results = GetPartitions( + {0, 1}, {10, 8}, + IndexTransformBuilder<>(1, 2) + .input_origin({100}) + .input_shape({6}) + .output_index_array(0, 5, 3, MakeArray({10, 3, 4, -5, -6, 11})) + .output_index_array(1, 4, -2, MakeArray({5, 1, 7, -3, -2, 5})) + .Finalize() + .value()); + // Input index: 100 101 102 103 104 105 + // + // Index array 0: 10 3 4 -5 -6 11 + // Output index 0: 35 14 17 -10 -13 38 + // = 5 + 3 * Index array 0 + // Grid index 0: 3 1 1 -1 -2 3 + // = Output index 0 / 10 + // + // Index array 1: 5 1 7 -3 -2 5 + // Output index 1: -6 2 -10 10 8 -6 + // = 4 - 2 * Index array 1 + // Grid index 1: -1 0 -2 2 1 -1 + // = Output index 1 / 8 + EXPECT_THAT( + results, + ElementsAre( + R{{-2, 1}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({1}) + .output_index_array(0, 0, 1, MakeArray({104})) + .Finalize() + .value()}, + R{{-1, 1}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({1}) + .output_index_array(0, 0, 1, MakeArray({103})) + .Finalize() + .value()}, + R{{1, -2}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({1}) + .output_index_array(0, 0, 1, MakeArray({102})) + .Finalize() + .value()}, + R{{1, 0}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({1}) + .output_index_array(0, 0, 1, MakeArray({101})) + .Finalize() + .value()}, + R{{3, -1}, + IndexTransformBuilder<>(1, 1) + .input_origin({0}) + .input_shape({2}) + .output_index_array(0, 0, 1, MakeArray({100, 105})) + .Finalize() + .value()})); +} + +// Tests that a index transform with a gridded `array` output dimension that +// depends on one input dimension, and a gridded `single_input_dimension` output +// dimension that depends on the other input dimension, which leads to two +// connected sets, is handled correctly. +TEST(PartitionIndexTransformOverRegularGrid, IndexArrayAndStridedDimensions) { + const auto results = GetPartitions( + {0, 1}, {10, 8}, + IndexTransformBuilder<>(2, 2) + .input_origin({-4, 100}) + .input_shape({6, 3}) + .output_index_array(0, 5, 3, MakeArray({{10, 3, 4}})) + .output_single_input_dimension(1, 4, -2, 0) + .Finalize() + .value()); + + // Input index 1: 100 101 102 + // Index array 0: 10 3 4 + // Output index 0: 35 14 17 + // = 5 + 3 * Index array 0 + // Grid index 0: 3 1 1 + // = Output index 0 / 10 + // + // Input index 0: -4 -3 -2 -1 0 1 + // Output index 1: 12 10 8 6 4 2 + // = 4 - 2 * Input index 0 + // Grid index 1: 1 1 1 0 0 0 + // = Output index 1 / 8 + EXPECT_THAT( + results, + ElementsAre( + R{{1, 1}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -4}) + .input_shape({2, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{101}, {102}})) + .Finalize() + .value()}, + R{{1, 0}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -1}) + .input_shape({2, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{101}, {102}})) + .Finalize() + .value()}, + R{{3, 1}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -4}) + .input_shape({1, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{100}})) + .Finalize() + .value()}, + R{{3, 0}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -1}) + .input_shape({1, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{100}})) + .Finalize() + .value()})); +} + +TEST(PartitionIndexTransformOverRegularGrid, + IndexArrayAndStridedDimensionsManual) { + const auto results = GetPartitionsManual( + {0, 1}, {10, 8}, + IndexTransformBuilder<>(2, 2) + .input_origin({-4, 100}) + .input_shape({6, 3}) + .output_index_array(0, 5, 3, MakeArray({{10, 3, 4}})) + .output_single_input_dimension(1, 4, -2, 0) + .Finalize() + .value()); + + // Input index 1: 100 101 102 + // Index array 0: 10 3 4 + // Output index 0: 35 14 17 + // = 5 + 3 * Index array 0 + // Grid index 0: 3 1 1 + // = Output index 0 / 10 + // + // Input index 0: -4 -3 -2 -1 0 1 + // Output index 1: 12 10 8 6 4 2 + // = 4 - 2 * Input index 0 + // Grid index 1: 1 1 1 0 0 0 + // = Output index 1 / 8 + EXPECT_THAT( + results, + ElementsAre( + R{{1, 1}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -4}) + .input_shape({2, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{101}, {102}})) + .Finalize() + .value()}, + R{{1, 0}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -1}) + .input_shape({2, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{101}, {102}})) + .Finalize() + .value()}, + R{{3, 1}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -4}) + .input_shape({1, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{100}})) + .Finalize() + .value()}, + R{{3, 0}, + IndexTransformBuilder<>(2, 2) + .input_origin({0, -1}) + .input_shape({1, 3}) + .output_single_input_dimension(0, 1) + .output_index_array(1, 0, 1, MakeArray({{100}})) + .Finalize() + .value()})); +} + +} // namespace diff --git a/tensorstore/internal/grid_partition_test.cc b/tensorstore/internal/grid_partition_test.cc index 876d37d7e..d43c0b81b 100644 --- a/tensorstore/internal/grid_partition_test.cc +++ b/tensorstore/internal/grid_partition_test.cc @@ -14,7 +14,6 @@ #include "tensorstore/internal/grid_partition.h" -#include #include #include @@ -53,711 +52,6 @@ using ::tensorstore::internal_grid_partition:: using ::tensorstore::internal_grid_partition::RegularGridRef; using ::testing::ElementsAre; -namespace partition_tests { -/// Representation of a partition, specifically the arguments supplied to the -/// callback passed to `PartitionIndexTransformOverRegularGrid`. This is a -/// pair of: -/// -/// 0. Grid cell index vector -/// 1. cell_transform transform -using R = std::pair, IndexTransform<>>; - -/// Returns the list of partitions generated by -/// `PartitionIndexTransformOverRegularGrid` when called with the specified -/// arguments. -/// -/// \param grid_output_dimensions The sequence of output dimensions of the index -/// space "output" corresponding to the grid. -/// \param grid_cell_shape Array of length `grid_output_dimensions.size()` -/// specifying the cell of a grid cell along each grid dimension. -/// \param transform A transform from the "full" input space to the "output" -/// index space. -/// \returns The list of partitions. -std::vector GetPartitions( - const std::vector& grid_output_dimensions, - const std::vector& grid_cell_shape, IndexTransformView<> transform) { - std::vector results; - - IndexTransformGridPartition info; - RegularGridRef grid{grid_cell_shape}; - TENSORSTORE_CHECK_OK(PrePartitionIndexTransformOverGrid( - transform, grid_output_dimensions, grid, info)); - TENSORSTORE_CHECK_OK( - tensorstore::internal::PartitionIndexTransformOverGrid( - grid_output_dimensions, grid, transform, - [&](tensorstore::span grid_cell_indices, - IndexTransformView<> cell_transform) { - auto cell_transform_direct = info.GetCellTransform( - transform, grid_cell_indices, grid_output_dimensions, - [&](DimensionIndex dim, Index cell_index) { - return grid.GetCellOutputInterval(dim, cell_index); - }); - EXPECT_EQ(cell_transform_direct, cell_transform); - results.emplace_back(std::vector(grid_cell_indices.begin(), - grid_cell_indices.end()), - IndexTransform<>(cell_transform)); - return absl::OkStatus(); - })); - return results; -} - -// Tests that a one-dimensional transform with a constant output map is -// partitioned into 1 part. -TEST(PartitionIndexTransformOverRegularGrid, ConstantOneDimensional) { - const auto results = GetPartitions({0}, {2}, - IndexTransformBuilder<>(1, 1) - .input_origin({2}) - .input_shape({4}) - .output_constant(0, 3) - .Finalize() - .value()); - // Input index: 2 3 4 5 - // Output index: 3 - // Grid index: 1 - // = Output index / 2 - EXPECT_THAT( // - results, // - ElementsAre( // - R{{1}, - IndexTransformBuilder<>(1, 1) - .input_origin({2}) - .input_shape({4}) - .output_single_input_dimension(0, 0) - .Finalize() - .value()})); -} - -// Tests that a two-dimensional transform with constant output maps is -// partitioned into 1 part. -TEST(PartitionIndexTransformOverRegularGrid, ConstantTwoDimensional) { - const auto results = GetPartitions({0, 1}, {2, 3}, - IndexTransformBuilder<>(2, 2) - .input_origin({2, 3}) - .input_shape({4, 5}) - .output_constant(0, 3) - .output_constant(1, 7) - .Finalize() - .value()); - // Input index 0: 2 3 4 5 - // Input index 1: 3 4 5 6 7 - - // Output index 0: 3 - // Grid index 0: 1 - // = Output index / 2 - // - // Output index 1: 7 - // Grid index 0: 2 - // = Output index / 3 - - EXPECT_THAT( // - results, // - ElementsAre( // - R{{1, 2}, - IndexTransformBuilder<>(2, 2) - .input_origin({2, 3}) - .input_shape({4, 5}) - .output_identity_transform() - .Finalize() - .value()})); -} - -// Tests that a one-dimensional identity transform over the domain `[-4,1]` with -// a cell size of `2` is partitioned into 3 parts, with the domains: `[-4,-3]`, -// `[-2,-1]`, and `[0,0]`. -TEST(PartitionIndexTransformOverRegularGrid, OneDimensionalUnitStride) { - const auto results = GetPartitions({0}, {2}, - IndexTransformBuilder<>(1, 1) - .input_origin({-4}) - .input_shape({5}) - .output_identity_transform() - .Finalize() - .value()); - // Input index: -4 -3 -2 -1 0 - // Output index: -4 -3 -2 -1 0 - // = Input index - // Grid index: -2 -2 -1 -1 0 - // = Output index / 2 - EXPECT_THAT( // - results, // - ElementsAre( // - R{{-2}, - IndexTransformBuilder<>(1, 1) - .input_origin({-4}) - .input_shape({2}) - .output_identity_transform() - .Finalize() - .value()}, - R{{-1}, - IndexTransformBuilder<>(1, 1) - .input_origin({-2}) - .input_shape({2}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({1}) - .output_identity_transform() - .Finalize() - .value()})); -} - -// Tests that a 2-d identity-mapped input domain over `[0,30)*[0,30)` with a -// grid size of `{20,10}` is correctly partitioned in 6 parts, with domains: -// `[0,20)*[0,10)`, `[0,20)*[10,20)`, `[0,20)*[20,30)`, `[20,30)*[0,10)`, -// `[20,30)*[10,20)`, `[20,30)*[20,30)`, -TEST(PartitionIndexTransformOverRegularGrid, TwoDimensionalIdentity) { - const auto results = GetPartitions({0, 1}, {20, 10}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 0}) - .input_shape({30, 30}) - .output_identity_transform() - .Finalize() - .value()); - EXPECT_THAT( // - results, // - ElementsAre( // - R{{0, 0}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 0}) - .input_shape({20, 10}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 10}) - .input_shape({20, 10}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0, 2}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 20}) - .input_shape({20, 10}) - .output_identity_transform() - .Finalize() - .value()}, - R{{1, 0}, - IndexTransformBuilder<>(2, 2) - .input_origin({20, 0}) - .input_shape({10, 10}) - .output_identity_transform() - .Finalize() - .value()}, - R{{1, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({20, 10}) - .input_shape({10, 10}) - .output_identity_transform() - .Finalize() - .value()}, - R{{1, 2}, - IndexTransformBuilder<>(2, 2) - .input_origin({20, 20}) - .input_shape({10, 10}) - .output_identity_transform() - .Finalize() - .value()})); -} - -// Same as previous test, but with non-unit stride and a cell size of 10. The -// input domain `[-4,1]` is partitioned into 2 parts, with the domains `[-4,-2]` -// and `[-1,1]`. -TEST(PartitionIndexTransformOverRegularGrid, SingleStridedDimension) { - const auto results = - GetPartitions({0}, {10}, - IndexTransformBuilder<>(1, 1) - .input_origin({-4}) - .input_shape({6}) - .output_single_input_dimension(0, 5, 3, 0) - .Finalize() - .value()); - // Input index: -4 -3 -2 -1 0 1 - // Output index: -7 -4 -1 2 5 8 - // = 5 + 3 * Input index - // Grid index: -1 -1 -1 0 0 0 - // = Output index / 10 - EXPECT_THAT( // - results, // - ElementsAre( // - R{{-1}, - IndexTransformBuilder<>(1, 1) - .input_origin({-4}) - .input_shape({3}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0}, - IndexTransformBuilder<>(1, 1) - .input_origin({-1}) - .input_shape({3}) - .output_identity_transform() - .Finalize() - .value()})); -} - -// Tests that a diagonal transform that maps two different gridded output -// dimension to a single input dimension, where a different cell size is used -// for the two grid dimensions, is partitioned into 3 parts, with domains -// `[-4,-2]`, `[-1,-1]`, and `[0,1]`. -TEST(PartitionIndexTransformOverRegularGrid, DiagonalStridedDimensions) { - const auto results = - GetPartitions({0, 1}, {10, 8}, - IndexTransformBuilder<>(1, 2) - .input_origin({-4}) - .input_shape({6}) - .output_single_input_dimension(0, 5, 3, 0) - .output_single_input_dimension(1, 7, -2, 0) - .Finalize() - .value()); - // Input index: -4 -3 -2 -1 0 1 - // - // Output index 0: -7 -4 -1 2 5 8 - // = 5 + 3 * Input index 0 - // Grid index 0: -1 -1 -1 0 0 0 - // = Output index 0 / 10 - // - // Output index 1: 15 13 11 9 7 5 - // = 7 - 2 * Input index 1 - // Grid index 0: 1 1 1 1 0 0 - // = Output index 1 / 8 - EXPECT_THAT( // - results, // - ElementsAre( // - R{{-1, 1}, - IndexTransformBuilder<>(1, 1) - .input_origin({-4}) - .input_shape({3}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0, 1}, - IndexTransformBuilder<>(1, 1) - .input_origin({-1}) - .input_shape({1}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0, 0}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({2}) - .output_identity_transform() - .Finalize() - .value()})); -} - -// Tests that a transform that maps via an index array the domain `[100,107]` -> -// `[1,8]`, when partitioned using a grid cell size of 3, results in 3 parts -// with domains: {100, 101}, {102, 103, 104}, and {105, 106, 107}. -TEST(PartitionIndexTransformOverRegularGrid, SingleIndexArrayDimension) { - const auto results = - GetPartitions({0}, {3}, - IndexTransformBuilder<>(1, 1) - .input_origin({100}) - .input_shape({8}) - .output_index_array( - 0, 0, 1, MakeArray({1, 2, 3, 4, 5, 6, 7, 8})) - .Finalize() - .value()); - // Input index: 100 101 102 103 104 105 106 107 - // Index array : 1 2 3 4 5 6 7 8 - // Output index: 1 2 3 4 5 6 7 8 - // Grid index: 0 0 1 1 1 2 2 2 - EXPECT_THAT( // - results, // - ElementsAre( - R{{0}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({2}) - .output_index_array(0, 0, 1, MakeArray({100, 101})) - .Finalize() - .value()}, - R{{1}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({3}) - .output_index_array(0, 0, 1, MakeArray({102, 103, 104})) - .Finalize() - .value()}, - R{{2}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({3}) - .output_index_array(0, 0, 1, MakeArray({105, 106, 107})) - .Finalize() - .value()})); -} - -// Tests that a transform with a single gridded output dimension with an `array` -// map from a single input dimension with non-unit stride is correctly -// partitioned. -TEST(PartitionIndexTransformOverRegularGrid, SingleIndexArrayDimensionStrided) { - const auto results = GetPartitions( - {0}, {10}, - IndexTransformBuilder<>(1, 1) - .input_origin({100}) - .input_shape({6}) - .output_index_array(0, 5, 3, MakeArray({10, 3, 4, -5, -6, 11})) - .Finalize() - .value()); - // Input index: 100 101 102 103 104 105 - // Index array: 10 3 4 -5 -6 11 - // Output index: 35 14 17 -10 -13 38 - // = 5 + 3 * Index array - // Grid index: 3 1 1 -1 -2 3 - // = Output index / 3 - EXPECT_THAT( // - results, // - ElementsAre( // - R{{-2}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({1}) - .output_index_array(0, 0, 1, MakeArray({104})) - .Finalize() - .value()}, - R{{-1}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({1}) - .output_index_array(0, 0, 1, MakeArray({103})) - .Finalize() - .value()}, - R{{1}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({2}) - .output_index_array(0, 0, 1, MakeArray({101, 102})) - .Finalize() - .value()}, - R{{3}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({2}) - .output_index_array(0, 0, 1, MakeArray({100, 105})) - .Finalize() - .value()})); -} - -// Tests that an index transform with two gridded output dimensions that are -// mapped using an `array` output index map from a single input dimension, which -// leads to a single connected set, is correctly handled. -TEST(PartitionIndexTransformOverRegularGrid, TwoIndexArrayDimensions) { - const auto results = GetPartitions( - {0, 1}, {10, 8}, - IndexTransformBuilder<>(1, 2) - .input_origin({100}) - .input_shape({6}) - .output_index_array(0, 5, 3, MakeArray({10, 3, 4, -5, -6, 11})) - .output_index_array(1, 4, -2, MakeArray({5, 1, 7, -3, -2, 5})) - .Finalize() - .value()); - // Input index: 100 101 102 103 104 105 - // - // Index array 0: 10 3 4 -5 -6 11 - // Output index 0: 35 14 17 -10 -13 38 - // = 5 + 3 * Index array 0 - // Grid index 0: 3 1 1 -1 -2 3 - // = Output index 0 / 10 - // - // Index array 1: 5 1 7 -3 -2 5 - // Output index 1: -6 2 -10 10 8 -6 - // = 4 - 2 * Index array 1 - // Grid index 1: -1 0 -2 2 1 -1 - // = Output index 1 / 8 - EXPECT_THAT( - results, - ElementsAre( - R{{-2, 1}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({1}) - .output_index_array(0, 0, 1, MakeArray({104})) - .Finalize() - .value()}, - R{{-1, 1}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({1}) - .output_index_array(0, 0, 1, MakeArray({103})) - .Finalize() - .value()}, - R{{1, -2}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({1}) - .output_index_array(0, 0, 1, MakeArray({102})) - .Finalize() - .value()}, - R{{1, 0}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({1}) - .output_index_array(0, 0, 1, MakeArray({101})) - .Finalize() - .value()}, - R{{3, -1}, - IndexTransformBuilder<>(1, 1) - .input_origin({0}) - .input_shape({2}) - .output_index_array(0, 0, 1, MakeArray({100, 105})) - .Finalize() - .value()})); -} - -// Tests that a index transform with a gridded `array` output dimension that -// depends on one input dimension, and a gridded `single_input_dimension` output -// dimension that depends on the other input dimension, which leads to two -// connected sets, is handled correctly. -TEST(PartitionIndexTransformOverRegularGrid, IndexArrayAndStridedDimensions) { - const auto results = GetPartitions( - {0, 1}, {10, 8}, - IndexTransformBuilder<>(2, 2) - .input_origin({-4, 100}) - .input_shape({6, 3}) - .output_index_array(0, 5, 3, MakeArray({{10, 3, 4}})) - .output_single_input_dimension(1, 4, -2, 0) - .Finalize() - .value()); - - // Input index 1: 100 101 102 - // Index array 0: 10 3 4 - // Output index 0: 35 14 17 - // = 5 + 3 * Index array 0 - // Grid index 0: 3 1 1 - // = Output index 0 / 10 - // - // Input index 0: -4 -3 -2 -1 0 1 - // Output index 1: 12 10 8 6 4 2 - // = 4 - 2 * Input index 0 - // Grid index 1: 1 1 1 0 0 0 - // = Output index 1 / 8 - EXPECT_THAT( - results, - ElementsAre( - R{{1, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -4}) - .input_shape({2, 3}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{101}, {102}})) - .Finalize() - .value()}, - R{{1, 0}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -1}) - .input_shape({2, 3}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{101}, {102}})) - .Finalize() - .value()}, - R{{3, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -4}) - .input_shape({1, 3}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{100}})) - .Finalize() - .value()}, - R{{3, 0}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -1}) - .input_shape({1, 3}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{100}})) - .Finalize() - .value()})); -} - -/// Returns the list of partitions generated by -/// `PartitionIndexTransformOverRegularGrid` when called with the specified -/// arguments. -/// -/// \param grid_output_dimensions The sequence of output dimensions of the index -/// space "output" corresponding to the grid. -/// \param grid_cell_shape Array of length `grid_output_dimensions.size()` -/// specifying the cell of a grid cell along each grid dimension. -/// \param transform A transform from the "full" input space to the "output" -/// index space. -/// \returns The list of partitions. -std::vector GetIrregularPartitions( - const std::vector& grid_output_dimensions, - const IrregularGrid& grid, IndexTransformView<> transform) { - std::vector results; - TENSORSTORE_CHECK_OK(tensorstore::internal::PartitionIndexTransformOverGrid( - grid_output_dimensions, grid, transform, - [&](tensorstore::span grid_cell_indices, - IndexTransformView<> cell_transform) { - results.emplace_back(std::vector(grid_cell_indices.begin(), - grid_cell_indices.end()), - IndexTransform<>(cell_transform)); - return absl::OkStatus(); - })); - return results; -} - -// Tests that a 2-d identity-mapped input domain over `[0,30)*[0,30)` -TEST(PartitionIndexTransformOverIrregularGrid, TwoDimensionalIdentity) { - const std::vector grid_output_dimensions{0, 1}; - std::vector dimension0{15}; // single split point - std::vector dimension1{-10, 10, 100}; // multiple split points - IrregularGrid grid({dimension0, dimension1}); - - std::vector results = - GetIrregularPartitions(grid_output_dimensions, grid, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 0}) - .input_shape({30, 30}) - .output_identity_transform() - .Finalize() - .value()); - - // According to SimpleIrregularGrid, indices < 0 are below the minimum bound - // and in real code could be clipped. - EXPECT_THAT( // - results, // - ElementsAre( // - R{{-1, 0}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 0}) - .input_shape({15, 10}) - .output_identity_transform() - .Finalize() - .value()}, - R{{-1, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 10}) - .input_shape({15, 20}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0, 0}, - IndexTransformBuilder<>(2, 2) - .input_origin({15, 0}) - .input_shape({15, 10}) - .output_identity_transform() - .Finalize() - .value()}, - R{{0, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({15, 10}) - .input_shape({15, 20}) - .output_identity_transform() - .Finalize() - .value()} // - )); -} - -TEST(PartitionIndexTransformOverIrregularGrid, IndexArrayAndStridedDimensions) { - std::vector dimension0{10, 15, 20, 30, 50}; - std::vector dimension1{0, 1, 5, 10, 13}; - IrregularGrid grid({dimension0, dimension1}); - - std::vector results = GetIrregularPartitions( - {0, 1}, grid, - IndexTransformBuilder<>(2, 2) - .input_origin({-4, 100}) - .input_shape({6, 3}) - .output_index_array(0, 5, 3, MakeArray({{10, 3, 4}})) - .output_single_input_dimension(1, 4, -2, 0) - .Finalize() - .value()); - - // Input index 1: 100 101 102 - // Index array 0: 10 3 4 - // Output index 0: 35 14 17 - // = 5 + 3 * Index array 0 - // Grid index 0: 3 0 1 - // - // Input index 0: -4 -3 -2 -1 0 1 - // Output index 1: 12 10 8 6 4 2 - // = 4 - 2 * Input index 0 - // Grid index 1: 3 3 2 2 1 1 - EXPECT_THAT( - results, - ElementsAre(R{{0, 3}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -4}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{101}})) - .Finalize() - .value()}, - R{{0, 2}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -2}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{101}})) - .Finalize() - .value()}, - R{{0, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 0}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{101}})) - .Finalize() - .value()}, - - R{{1, 3}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -4}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{102}})) - .Finalize() - .value()}, - R{{1, 2}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -2}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{102}})) - .Finalize() - .value()}, - R{{1, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 0}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{102}})) - .Finalize() - .value()}, - - R{{3, 3}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -4}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{100}})) - .Finalize() - .value()}, - R{{3, 2}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, -2}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{100}})) - .Finalize() - .value()}, - R{{3, 1}, - IndexTransformBuilder<>(2, 2) - .input_origin({0, 0}) - .input_shape({1, 2}) - .output_single_input_dimension(0, 1) - .output_index_array(1, 0, 1, MakeArray({{100}})) - .Finalize() - .value()} // - )); -} - -} // namespace partition_tests - namespace get_grid_cell_ranges_tests { using R = Box<>;