Skip to content

Commit

Permalink
Rework PartitionIndexTransformOverGrid to avoid recursion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695486670
Change-Id: Id0ad384bba73227b0bc286796aa491ee4b443382
  • Loading branch information
laramiel authored and copybara-github committed Nov 11, 2024
1 parent 059955b commit 0411334
Show file tree
Hide file tree
Showing 11 changed files with 1,139 additions and 936 deletions.
1 change: 1 addition & 0 deletions tensorstore/driver/neuroglancer_precomputed/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
#include "tensorstore/internal/chunk_grid_specification.h"
#include "tensorstore/internal/grid_chunk_key_ranges_base10.h"
#include "tensorstore/internal/grid_partition.h"
#include "tensorstore/internal/grid_partition_iterator.h"
#include "tensorstore/internal/grid_storage_statistics.h"
#include "tensorstore/internal/json_binding/bindable.h"
#include "tensorstore/internal/json_binding/json_binding.h"
Expand Down
3 changes: 2 additions & 1 deletion tensorstore/driver/stack/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

#include "tensorstore/driver/driver.h"

#include <assert.h>
#include <stddef.h>

#include <algorithm>
#include <cassert>
#include <numeric>
#include <optional>
#include <string_view>
Expand Down Expand Up @@ -52,6 +52,7 @@
#include "tensorstore/internal/concurrency_resource.h"
#include "tensorstore/internal/data_copy_concurrency_resource.h"
#include "tensorstore/internal/grid_partition.h"
#include "tensorstore/internal/grid_partition_iterator.h"
#include "tensorstore/internal/intrusive_ptr.h"
#include "tensorstore/internal/irregular_grid.h"
#include "tensorstore/internal/json_binding/json_binding.h"
Expand Down
1 change: 1 addition & 0 deletions tensorstore/driver/zarr3/chunk_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "tensorstore/internal/cache/kvs_backed_chunk_cache.h"
#include "tensorstore/internal/chunk_grid_specification.h"
#include "tensorstore/internal/grid_partition.h"
#include "tensorstore/internal/grid_partition_iterator.h"
#include "tensorstore/internal/grid_storage_statistics.h"
#include "tensorstore/internal/intrusive_ptr.h"
#include "tensorstore/internal/lexicographical_grid_index_key.h"
Expand Down
27 changes: 25 additions & 2 deletions tensorstore/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -491,13 +491,16 @@ tensorstore_cc_test(

tensorstore_cc_library(
name = "grid_partition",
srcs = ["grid_partition.cc"],
srcs = [
"grid_partition.cc",
"grid_partition_iterator.cc",
],
hdrs = [
"grid_partition.h",
"grid_partition_iterator.h",
],
deps = [
":grid_partition_impl",
":intrusive_ptr",
"//tensorstore:box",
"//tensorstore:index",
"//tensorstore:index_interval",
Expand All @@ -512,6 +515,7 @@ tensorstore_cc_library(
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
],
)
Expand Down Expand Up @@ -1869,3 +1873,22 @@ tensorstore_cc_test(
"@com_google_googletest//:gtest_main",
],
)

tensorstore_cc_test(
name = "grid_partition_iterator_test",
srcs = ["grid_partition_iterator_test.cc"],
deps = [
":grid_partition",
":grid_partition_impl",
":regular_grid",
"//tensorstore:array",
"//tensorstore:index",
"//tensorstore:index_interval",
"//tensorstore/index_space:index_transform",
"//tensorstore/util:result",
"//tensorstore/util:span",
"//tensorstore/util:status",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
],
)
1 change: 1 addition & 0 deletions tensorstore/internal/cache/chunk_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "tensorstore/internal/chunk_grid_specification.h"
#include "tensorstore/internal/elementwise_function.h"
#include "tensorstore/internal/grid_partition.h"
#include "tensorstore/internal/grid_partition_iterator.h"
#include "tensorstore/internal/intrusive_ptr.h"
#include "tensorstore/internal/lock_collection.h"
#include "tensorstore/internal/memory.h"
Expand Down
193 changes: 0 additions & 193 deletions tensorstore/internal/grid_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
#include <algorithm>
#include <array>
#include <cassert>
#include <utility>
#include <vector>

#include "absl/container/fixed_array.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
Expand All @@ -34,10 +32,8 @@
#include "tensorstore/index_space/output_index_map.h"
#include "tensorstore/index_space/output_index_method.h"
#include "tensorstore/internal/grid_partition_impl.h"
#include "tensorstore/internal/intrusive_ptr.h"
#include "tensorstore/rank.h"
#include "tensorstore/util/dimension_set.h"
#include "tensorstore/util/iterate.h"
#include "tensorstore/util/result.h"
#include "tensorstore/util/span.h"
#include "tensorstore/util/status.h"
Expand All @@ -49,7 +45,6 @@ namespace tensorstore {
namespace internal_grid_partition {
namespace {


using IndexArraySet = IndexTransformGridPartition::IndexArraySet;
using StridedSet = IndexTransformGridPartition::StridedSet;

Expand All @@ -58,10 +53,6 @@ struct ConnectedSetIterateParameters {
tensorstore::span<const DimensionIndex> grid_output_dimensions;
OutputToGridCellFn output_to_grid_cell;
IndexTransformView<> transform;
absl::FunctionRef<absl::Status(
tensorstore::span<const Index> grid_cell_indices,
IndexTransformView<> cell_transform)>
func;
};

/// Sets the fixed grid cell indices for all grid dimensions that do not
Expand Down Expand Up @@ -149,165 +140,6 @@ class StridedSetGridCellIterator {
Index input_index_;
};

// Java-style iterator for computing the grid cells that intersect the original
// input domain for a given `StridedSet`.
//
/// For each grid cell, updates the `output_grid_cell_indices` for the given
/// index array.
class IndexArraySetIterator {
public:
IndexArraySetIterator(const IndexArraySet& index_array_set)
: grid_dimensions_(index_array_set.grid_dimensions),
grid_cell_indices_(index_array_set.grid_cell_indices),
partition_end_index_(index_array_set.num_partitions()),
partition_index_(0) {}

void Reset() { partition_index_ = 0; }

bool AtEnd() const { return partition_index_ == partition_end_index_; }

Index Next(tensorstore::span<Index> output_grid_cell_indices) {
assert(!AtEnd());

// Assign the grid_cell_indices to the precomputed grid cell indices for
// this partition.
const Index grid_cell_indices_offset =
partition_index_ * grid_dimensions_.count();
DimensionIndex grid_i = 0;
for (DimensionIndex grid_dim : grid_dimensions_.index_view()) {
output_grid_cell_indices[grid_dim] =
grid_cell_indices_[grid_cell_indices_offset + grid_i++];
}

return partition_index_++;
}

private:
DimensionSet grid_dimensions_;
tensorstore::span<const Index> grid_cell_indices_;
Index partition_end_index_;
Index partition_index_;
};

/// Helper class for iterating over the grid cell index vectors and computing
/// the `cell_transform` for each grid cell, based on precomputed
/// `IndexTransformGridPartition` data.
class ConnectedSetIterateHelper {
public:
explicit ConnectedSetIterateHelper(ConnectedSetIterateParameters params)
: params_(std::move(params)),
grid_cell_indices_(params_.grid_output_dimensions.size()),
cell_transform_(internal_grid_partition::InitializeCellTransform(
params_.info, params_.transform)) {
InitializeConstantGridCellIndices(
params_.transform, params_.grid_output_dimensions,
params_.output_to_grid_cell, grid_cell_indices_);
}

/// Iterates over all grid cells and invokes the iteration callback function.
///
/// This is implemented by recursively iterating over the partitions of each
/// connected set.
absl::Status Iterate() { return IterateOverIndexArraySets(0); }

private:
/// Recursively iterates over the partial grid cells corresponding to the
/// index array connected sets, starting with `set_i`.
///
/// For each grid cell, updates the `grid_cell_indices` for all grid
/// dimensions in the connected set and updates the `cell_transform` array
/// output index maps corresponding to each original input dimension in the
/// connected set.
///
/// If there are no remaining index array connected sets over which to
/// recurse, starts recusing over the strided connected sets.
///
/// Iteration is aborted if `InvokeCallback` returns an error.
///
/// \param set_i The next index array connected set over which to iterate, in
/// the range `[0, info.index_array_sets().size()]`.
/// \returns The return value of the last recursively call.
absl::Status IterateOverIndexArraySets(DimensionIndex set_i) {
if (set_i == params_.info.index_array_sets().size()) {
return IterateOverStridedSets(0);
}
const IndexArraySet& index_array_set =
params_.info.index_array_sets()[set_i];
IndexArraySetIterator iterator(index_array_set);
while (!iterator.AtEnd()) {
Index partition_i = iterator.Next(grid_cell_indices_);
UpdateCellTransformForIndexArraySetPartition(
index_array_set, set_i, partition_i, cell_transform_.get());
TENSORSTORE_RETURN_IF_ERROR(IterateOverIndexArraySets(set_i + 1));
}
return absl::OkStatus();
}

/// Recursively iterates over the partial grid cells corresponding to the
/// strided connected sets, starting with `set_i`.
///
/// For each grid cell, updates the `grid_cell_indices` for all grid
/// dimensions in the connected set, and updates the input domain of the
/// corresponding synthetic input dimension of `cell_transform`. The output
/// index maps do not need to be updated.
///
/// If there are no remaining strided sets over which to recurse, just invokes
/// the iteration callback function.
///
/// Iteration is aborted if `InvokeCallback` returns an error.
///
/// \param set_i The next strided connected set over which to iterate, in the
/// range `[0, info.strided_sets().size()]`.
/// \returns The return value of the last recursive call, or the last call to
/// `InvokeCallback`.
absl::Status IterateOverStridedSets(DimensionIndex set_i) {
if (set_i == params_.info.strided_sets().size()) return InvokeCallback();
StridedSetGridCellIterator iterator(
params_.transform, params_.grid_output_dimensions,
params_.output_to_grid_cell, params_.info.strided_sets()[set_i]);
const DimensionIndex cell_input_dim =
set_i + params_.info.index_array_sets().size();
while (!iterator.AtEnd()) {
auto restricted_domain = iterator.Next(grid_cell_indices_);
// Set the input domain for `cell_input_dim` for the duration of the
// subsequent recursive call to IterateOverStridedSets.
cell_transform_->input_origin()[cell_input_dim] =
restricted_domain.inclusive_min();
cell_transform_->input_shape()[cell_input_dim] = restricted_domain.size();
// Recursively iterate over the next strided connected set.
TENSORSTORE_RETURN_IF_ERROR(IterateOverStridedSets(set_i + 1));
}
return absl::OkStatus();
}

/// Calls the iteration callback function.
///
/// If an error `absl::Status` is returned, iteration should stop.
///
/// \error Any error returned by the iteration callback function.
absl::Status InvokeCallback() {
internal_index_space::DebugCheckInvariants(cell_transform_.get());
auto status = params_.func(
grid_cell_indices_,
TransformAccess::Make<IndexTransformView<>>(cell_transform_.get()));
// If `func` created and is still holding a reference to `cell_transform_`,
// we need to make a copy before modifying it.
cell_transform_ = MutableRep(std::move(cell_transform_));
return status;
}

ConnectedSetIterateParameters params_;

// Current grid cell index vector `h` as defined in grid_partition.h, modified
// in place while iterating over all index vectors in `H`.
absl::FixedArray<Index, internal::kNumInlinedDims> grid_cell_indices_;

// This stores the current value of `cell_transform[h]`, as defined in
// grid_partition.h, for `h = grid_cell_indices_`. This is modified in
// place while iterating over all values for grid_cell_indices_.
internal_index_space::TransformRep::Ptr<> cell_transform_;
};

bool GetStridedGridCellRanges(
IndexTransformView<> transform, OutputToGridCellFn output_to_grid_cell,
DimensionIndex grid_dim, DimensionIndex output_dim,
Expand Down Expand Up @@ -511,31 +343,6 @@ class GetGridCellRangesIterateHelper {
} // namespace
} // namespace internal_grid_partition

namespace internal {

absl::Status PartitionIndexTransformOverGrid(
tensorstore::span<const DimensionIndex> grid_output_dimensions,
OutputToGridCellFn output_to_grid_cell, IndexTransformView<> transform,
absl::FunctionRef<
absl::Status(tensorstore::span<const Index> grid_cell_indices,
IndexTransformView<> cell_transform)>
func) {
internal_grid_partition::IndexTransformGridPartition partition_info;
auto status = internal_grid_partition::PrePartitionIndexTransformOverGrid(
transform, grid_output_dimensions, output_to_grid_cell, partition_info);

if (!status.ok()) return status;
return internal_grid_partition::ConnectedSetIterateHelper(
{/*.info=*/partition_info,
/*.grid_output_dimensions=*/grid_output_dimensions,
/*.output_to_grid_cell=*/output_to_grid_cell,
/*.transform=*/transform,
/*.func=*/std::move(func)})
.Iterate();
}

} // namespace internal

namespace internal_grid_partition {
absl::Status GetGridCellRanges(
const IndexTransformGridPartition& grid_partition,
Expand Down
34 changes: 0 additions & 34 deletions tensorstore/internal/grid_partition.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,40 +113,6 @@ namespace internal {
using OutputToGridCellFn = absl::FunctionRef<Index(
DimensionIndex grid_dim, Index output_index, IndexInterval* cell_bounds)>;

/// Partitions the input domain of a given `transform` from an input space
/// "full" to an output space "output" based on the grid (potentially irregular)
/// specified by `output_to_grid_cell`, which maps from a given dimension and
/// output_index to a grid cell and optional cell bounds.
///
/// For each grid cell index vector `h` in `H`, calls
/// `func(h, cell_transform[h])`.
///
/// To partition over a regular grid, `output_to_grid_cell` can be
/// internal_grid_partition::RegularGridRef.
///
/// \param grid_output_dimensions The sequence of dimensions of the index space
/// "output" corresponding to the grid by which to partition "full".
/// \param output_to_grid_cell Function returning, for the provided grid
/// dimension, the cell index corresponding to output_index, optionally
/// filling the bounds for the cell.
/// \param transform The index transform from "full" to "output". Must be
/// valid.
/// \param func The function to be called for each partition. May return an
/// error `absl::Status` to abort the iteration.
/// \returns `absl::Status()` on success, or the last error returned by `func`.
/// \error `absl::StatusCode::kInvalidArgument` if any input dimension of
/// `transform` has an unbounded domain.
/// \error `absl::StatusCode::kInvalidArgument` if integer overflow occurs.
/// \error `absl::StatusCode::kOutOfRange` if an index array contains an
/// out-of-bounds index.
absl::Status PartitionIndexTransformOverGrid(
tensorstore::span<const DimensionIndex> grid_output_dimensions,
OutputToGridCellFn output_to_grid_cell, IndexTransformView<> transform,
absl::FunctionRef<
absl::Status(tensorstore::span<const Index> grid_cell_indices,
IndexTransformView<> cell_transform)>
func);

absl::Status GetGridCellRanges(
tensorstore::span<const DimensionIndex> grid_output_dimensions,
BoxView<> grid_bounds, OutputToGridCellFn output_to_grid_cell,
Expand Down
Loading

0 comments on commit 0411334

Please sign in to comment.