neuroglancer_precomputed: Ensure that shard-aligned writes are atomic
PiperOrigin-RevId: 599337105
Change-Id: I6d45917e902c84e2116226a82eb7934bdcca17f7
jbms authored and copybara-github committed Jan 18, 2024
1 parent 70b16ae commit 956d0a2
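In short: a write that exactly covers one shard is now committed as a single, unconditional write of that shard object, even without an explicit transaction. A rough sketch of the usage this enables, modeled on the new FullShardWriteTest below (`json_spec` and `context` are assumed to be set up as in that test; illustrative only, not part of the commit):

// Open a sharded neuroglancer_precomputed volume: chunk_size {2, 2, 2},
// shard shape {4, 4, 4} voxels, volume size {4, 6, 10}.
TENSORSTORE_ASSERT_OK_AND_ASSIGN(
    auto store, tensorstore::Open(json_spec, context).result());
// This region exactly covers shard 5 (clipped to the volume bounds), so the
// driver now issues a single unconditional write of "5.shard".
auto future = tensorstore::Write(
    tensorstore::MakeScalarArray<uint16_t>(42),
    store | tensorstore::Dims(0, 1, 2).SizedInterval({0, 4, 8}, {4, 2, 2}));
TENSORSTORE_ASSERT_OK(future);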
Showing 4 changed files with 235 additions and 34 deletions.
4 changes: 4 additions & 0 deletions tensorstore/driver/neuroglancer_precomputed/BUILD
@@ -201,12 +201,15 @@ tensorstore_cc_library(
"//tensorstore:strided_layout",
"//tensorstore:transaction",
"//tensorstore/driver",
"//tensorstore/driver:chunk",
"//tensorstore/driver:chunk_cache_driver",
"//tensorstore/driver:chunk_receiver_utils",
"//tensorstore/driver:kvs_backed_chunk_driver",
"//tensorstore/index_space:dimension_units",
"//tensorstore/index_space:index_transform",
"//tensorstore/internal:chunk_grid_specification",
"//tensorstore/internal:grid_chunk_key_ranges_base10",
"//tensorstore/internal:grid_partition",
"//tensorstore/internal:grid_storage_statistics",
"//tensorstore/internal:json_fwd",
"//tensorstore/internal:lexicographical_grid_index_key",
@@ -224,6 +227,7 @@ tensorstore_cc_library(
"//tensorstore/util:status",
"//tensorstore/util:str_cat",
"//tensorstore/util:unit",
"//tensorstore/util/execution:any_receiver",
"//tensorstore/util/garbage_collection",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
169 changes: 139 additions & 30 deletions tensorstore/driver/neuroglancer_precomputed/driver.cc
@@ -44,7 +44,9 @@
#include "tensorstore/context.h"
#include "tensorstore/contiguous_layout.h"
#include "tensorstore/data_type.h"
#include "tensorstore/driver/chunk.h"
#include "tensorstore/driver/chunk_cache_driver.h"
#include "tensorstore/driver/chunk_receiver_utils.h"
#include "tensorstore/driver/kvs_backed_chunk_driver.h"
#include "tensorstore/driver/neuroglancer_precomputed/chunk_encoding.h"
#include "tensorstore/driver/neuroglancer_precomputed/metadata.h"
@@ -59,6 +61,7 @@
#include "tensorstore/internal/cache_key/cache_key.h"
#include "tensorstore/internal/chunk_grid_specification.h"
#include "tensorstore/internal/grid_chunk_key_ranges_base10.h"
#include "tensorstore/internal/grid_partition.h"
#include "tensorstore/internal/grid_storage_statistics.h"
#include "tensorstore/internal/json_binding/bindable.h"
#include "tensorstore/internal/json_binding/json_binding.h"
@@ -75,6 +78,7 @@
#include "tensorstore/util/constant_vector.h"
#include "tensorstore/util/dimension_set.h"
#include "tensorstore/util/division.h"
#include "tensorstore/util/execution/any_receiver.h"
#include "tensorstore/util/future.h"
#include "tensorstore/util/garbage_collection/fwd.h"
#include "tensorstore/util/result.h"
@@ -592,36 +596,18 @@ class ShardedDataCache : public DataCacheBase {
const void* metadata_ptr, size_t component_index) override {
const auto& metadata = this->metadata();
const auto& scale = metadata.scales[scale_index_];
const auto& sharding = *std::get_if<ShardingSpec>(&scale.sharding);
TENSORSTORE_ASSIGN_OR_RETURN(
auto layout, GetBaseChunkLayout(metadata, ChunkLayout::kRead));
if (ShardChunkHierarchy hierarchy; GetShardChunkHierarchy(
sharding, scale.box.shape(), scale.chunk_sizes[0], hierarchy)) {
// Each shard corresponds to a rectangular region.
Index write_chunk_shape[4];
write_chunk_shape[0] = metadata.num_channels;
for (int dim = 0; dim < 3; ++dim) {
const Index chunk_size = scale.chunk_sizes[0][dim];
const Index volume_size = scale.box.shape()[dim];
write_chunk_shape[3 - dim] = RoundUpTo(
std::min(hierarchy.shard_shape_in_chunks[dim] * chunk_size,
volume_size),
chunk_size);
}
TENSORSTORE_RETURN_IF_ERROR(
layout.Set(ChunkLayout::WriteChunkShape(write_chunk_shape)));
} else {
// Each shard does not correspond to a rectangular region. The write
// chunk shape is equal to the full domain.
Index write_chunk_shape[4];
write_chunk_shape[0] = metadata.num_channels;
for (int dim = 0; dim < 3; ++dim) {
write_chunk_shape[3 - dim] =
RoundUpTo(scale.box.shape()[dim], scale.chunk_sizes[0][dim]);
}
TENSORSTORE_RETURN_IF_ERROR(
layout.Set(ChunkLayout::WriteChunkShape(write_chunk_shape)));
// Each shard does not correspond to a rectangular region. The write
// chunk shape is equal to the full domain.
Index write_chunk_shape[4];
write_chunk_shape[0] = metadata.num_channels;
for (int dim = 0; dim < 3; ++dim) {
write_chunk_shape[3 - dim] =
RoundUpTo(scale.box.shape()[dim], scale.chunk_sizes[0][dim]);
}
TENSORSTORE_RETURN_IF_ERROR(
layout.Set(ChunkLayout::WriteChunkShape(write_chunk_shape)));
TENSORSTORE_RETURN_IF_ERROR(layout.Finalize());
return layout;
}
@@ -637,6 +623,121 @@ class ShardedDataCache : public DataCacheBase {
std::array<int, 3> compressed_z_index_bits_;
};

// DataCache for sharded format in the case that shards correspond to
// rectangular regions.
class RegularlyShardedDataCache : public ShardedDataCache {
public:
RegularlyShardedDataCache(Initializer initializer,
std::string_view key_prefix,
const MultiscaleMetadata& metadata,
size_t scale_index,
std::array<Index, 3> chunk_size_xyz,
ShardChunkHierarchy hierarchy)
: ShardedDataCache(std::move(initializer), key_prefix, metadata,
scale_index, chunk_size_xyz),
hierarchy_(hierarchy) {}

Result<ChunkLayout> GetChunkLayoutFromMetadata(
const void* metadata_ptr, size_t component_index) override {
const auto& metadata = this->metadata();
const auto& scale = metadata.scales[scale_index_];
TENSORSTORE_ASSIGN_OR_RETURN(
auto layout, GetBaseChunkLayout(metadata, ChunkLayout::kRead));
// Each shard corresponds to a rectangular region.
Index write_chunk_shape[4];
write_chunk_shape[0] = metadata.num_channels;
for (int dim = 0; dim < 3; ++dim) {
const Index chunk_size = scale.chunk_sizes[0][dim];
const Index volume_size = scale.box.shape()[dim];
write_chunk_shape[3 - dim] =
RoundUpTo(std::min(hierarchy_.shard_shape_in_chunks[dim] * chunk_size,
volume_size),
chunk_size);
}
TENSORSTORE_RETURN_IF_ERROR(
layout.Set(ChunkLayout::WriteChunkShape(write_chunk_shape)));
TENSORSTORE_RETURN_IF_ERROR(layout.Finalize());
return layout;
}

void Read(internal::OpenTransactionPtr transaction, size_t component_index,
IndexTransform<> transform, absl::Time staleness,
AnyFlowReceiver<absl::Status, internal::ReadChunk, IndexTransform<>>
receiver) override {
return ShardedReadOrWrite(
std::move(transaction), std::move(transform), std::move(receiver),
[&](internal::OpenTransactionPtr transaction,
IndexTransform<> transform,
AnyFlowReceiver<absl::Status, internal::ReadChunk, IndexTransform<>>
receiver) {
return ShardedDataCache::Read(std::move(transaction), component_index,
std::move(transform), staleness,
std::move(receiver));
});
}

void Write(
internal::OpenTransactionPtr transaction, size_t component_index,
IndexTransform<> transform,
AnyFlowReceiver<absl::Status, internal::WriteChunk, IndexTransform<>>
receiver) override {
return ShardedReadOrWrite(
std::move(transaction), std::move(transform), std::move(receiver),
[&](internal::OpenTransactionPtr transaction,
IndexTransform<> transform,
AnyFlowReceiver<absl::Status, internal::WriteChunk,
IndexTransform<>>
receiver) {
return ShardedDataCache::Write(std::move(transaction),
component_index, std::move(transform),
std::move(receiver));
});
}

private:
template <typename ChunkType, typename Callback>
void ShardedReadOrWrite(
internal::OpenTransactionPtr transaction, IndexTransform<> transform,
AnyFlowReceiver<absl::Status, ChunkType, IndexTransform<>> receiver,
Callback callback) {
const auto& metadata = this->metadata();
const auto& scale = metadata.scales[scale_index_];
const DimensionIndex chunked_to_cell_dimensions[] = {3, 2, 1};
Index shard_shape_in_elements[3];
for (DimensionIndex dim = 0; dim < 3; ++dim) {
shard_shape_in_elements[dim] =
scale.chunk_sizes[0][dim] * hierarchy_.shard_shape_in_chunks[dim];
}
using State = internal::ChunkOperationState<ChunkType>;
using ForwardingReceiver =
internal::ForwardingChunkOperationReceiver<State>;
auto state = internal::MakeIntrusivePtr<State>(std::move(receiver));
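// Partition the request over the regular shard grid so that each shard is
// handled by a single nested read or write on the base ShardedDataCache.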
auto status = internal::PartitionIndexTransformOverRegularGrid(
chunked_to_cell_dimensions, shard_shape_in_elements, transform,
[&](span<const Index> grid_cell_indices,
IndexTransformView<> cell_transform) -> absl::Status {
if (state->cancelled()) {
return absl::CancelledError("");
}
TENSORSTORE_ASSIGN_OR_RETURN(
auto cell_to_source,
ComposeTransforms(transform, cell_transform));
internal::OpenTransactionPtr shard_transaction = transaction;
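// For a write issued outside an explicit transaction, give each shard its
// own implicit transaction and request commit, so a write covering an
// entire shard is written back atomically (and unconditionally).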
if constexpr (std::is_same_v<ChunkType, internal::WriteChunk>) {
if (!shard_transaction) {
shard_transaction = internal::TransactionState::MakeImplicit();
shard_transaction->RequestCommit();
}
}
callback(std::move(shard_transaction), std::move(cell_to_source),
ForwardingReceiver{state, cell_transform});
return absl::OkStatus();
});
}

ShardChunkHierarchy hierarchy_;
};

class NeuroglancerPrecomputedDriver;
using NeuroglancerPrecomputedDriverBase =
internal_kvs_backed_chunk_driver::RegisteredKvsDriver<
@@ -744,9 +845,17 @@ class NeuroglancerPrecomputedDriver::OpenState
assert(scale_index_);
const auto& scale = metadata.scales[scale_index_.value()];
if (std::holds_alternative<ShardingSpec>(scale.sharding)) {
return std::make_unique<ShardedDataCache>(
std::move(initializer), spec().store.path, metadata,
scale_index_.value(), chunk_size_xyz_);
if (ShardChunkHierarchy hierarchy; GetShardChunkHierarchy(
std::get<ShardingSpec>(scale.sharding), scale.box.shape(),
scale.chunk_sizes[0], hierarchy)) {
return std::make_unique<RegularlyShardedDataCache>(
std::move(initializer), spec().store.path, metadata,
scale_index_.value(), chunk_size_xyz_, hierarchy);
} else {
return std::make_unique<ShardedDataCache>(
std::move(initializer), spec().store.path, metadata,
scale_index_.value(), chunk_size_xyz_);
}
} else {
return std::make_unique<UnshardedDataCache>(
std::move(initializer), spec().store.path, metadata,
88 changes: 88 additions & 0 deletions tensorstore/driver/neuroglancer_precomputed/driver_test.cc
@@ -1595,6 +1595,94 @@ TEST(FullShardWriteTest, WithTransaction) {
TENSORSTORE_ASSERT_OK(txn.future());
}

TEST(FullShardWriteTest, WithoutTransaction) {
auto context = Context::Default();

TENSORSTORE_ASSERT_OK_AND_ASSIGN(
auto mock_key_value_store_resource,
context.GetResource<tensorstore::internal::MockKeyValueStoreResource>());
auto mock_key_value_store = *mock_key_value_store_resource;

::nlohmann::json json_spec{
{"driver", "neuroglancer_precomputed"},
{"kvstore",
{
{"driver", "mock_key_value_store"},
{"path", "prefix/"},
}},
{"create", true},
{"multiscale_metadata",
{
{"data_type", "uint16"},
{"num_channels", 1},
{"type", "image"},
}},
{"scale_metadata",
{
{"key", "1_1_1"},
{"resolution", {1, 1, 1}},
{"encoding", "raw"},
{"chunk_size", {2, 2, 2}},
{"size", {4, 6, 10}},
{"voxel_offset", {0, 0, 0}},
{"sharding",
{{"@type", "neuroglancer_uint64_sharded_v1"},
{"preshift_bits", 1},
{"minishard_bits", 2},
{"shard_bits", 3},
{"data_encoding", "raw"},
{"minishard_index_encoding", "raw"},
{"hash", "identity"}}},
}},
};

// Grid shape: {2, 3, 5}
// Full shard shape is {2, 2, 2} in chunks.
// Full shard shape is {4, 4, 4} in voxels.
// Shard 0 origin: {0, 0, 0}
// Shard 1 origin: {0, 4, 0}
// Shard 2 origin: {0, 0, 4}
// Shard 3 origin: {0, 4, 4}
// Shard 4 origin: {0, 0, 8}
// Shard 5 origin: {0, 4, 8}
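// Illustrative derivation of the expected shard, assuming the standard
// neuroglancer_uint64_sharded_v1 key scheme (not part of the original test):
// the chunk at voxel origin {0, 4, 8} is grid cell (x, y, z) = (0, 2, 4).
// With grid shape {2, 3, 5} the compressed Morton code uses 1 x-bit,
// 2 y-bits, and 3 z-bits, so chunk_id = 0b101000 = 40. Then
//   preshifted = 40 >> preshift_bits(=1)                      = 20
//   minishard  = 20 & (2^minishard_bits - 1)                  = 0
//   shard      = (20 >> minishard_bits) & (2^shard_bits - 1)  = 5,
// which matches the expected key "prefix/1_1_1/5.shard" below.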

auto store_future = tensorstore::Open(json_spec, context);
store_future.Force();

{
auto req = mock_key_value_store->read_requests.pop();
EXPECT_EQ("prefix/info", req.key);
req.promise.SetResult(kvstore::ReadResult::Missing(absl::Now()));
}

{
auto req = mock_key_value_store->write_requests.pop();
EXPECT_EQ("prefix/info", req.key);
EXPECT_EQ(StorageGeneration::NoValue(), req.options.if_equal);
req.promise.SetResult(TimestampedStorageGeneration{
StorageGeneration::FromString("g0"), absl::Now()});
}

TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto store, store_future.result());

auto future = tensorstore::Write(
tensorstore::MakeScalarArray<uint16_t>(42),
store | tensorstore::Dims(0, 1, 2).SizedInterval({0, 4, 8}, {4, 2, 2}));

future.Force();

{
auto req = mock_key_value_store->write_requests.pop();
ASSERT_EQ("prefix/1_1_1/5.shard", req.key);
// Writeback is unconditional because the entire shard is being written.
ASSERT_EQ(StorageGeneration::Unknown(), req.options.if_equal);
req.promise.SetResult(TimestampedStorageGeneration{
StorageGeneration::FromString("g0"), absl::Now()});
}

TENSORSTORE_ASSERT_OK(future);
}

// Tests that an empty path is handled correctly.
TEST(DriverTest, NoPrefix) {
auto context = Context::Default();
8 changes: 4 additions & 4 deletions tensorstore/internal/cache/chunk_cache.h
@@ -217,8 +217,8 @@ class ChunkCache : public AsyncCache {
/// \param staleness Cached data older than `staleness` will not be returned
/// without being rechecked.
/// \param receiver Receiver for the chunks.
void Read(
internal::OpenTransactionPtr transaction, std::size_t component_index,
virtual void Read(
internal::OpenTransactionPtr transaction, size_t component_index,
IndexTransform<> transform, absl::Time staleness,
AnyFlowReceiver<absl::Status, ReadChunk, IndexTransform<>> receiver);

Expand All @@ -234,8 +234,8 @@ class ChunkCache : public AsyncCache {
/// `[0, grid().components.size())`.
/// \param transform The transform to apply.
/// \param receiver Receiver for the chunks.
void Write(
internal::OpenTransactionPtr transaction, std::size_t component_index,
virtual void Write(
internal::OpenTransactionPtr transaction, size_t component_index,
IndexTransform<> transform,
AnyFlowReceiver<absl::Status, WriteChunk, IndexTransform<>> receiver);

