diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 37b6e2f1d..d88098f99 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -51,11 +51,11 @@ jobs: shell: bash run: | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - uses: actions/cache@v2 + - uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: pip-${{ runner.os }}-buildwheel-${{ hashFiles('tools/ci/*_requirements.txt', 'third_party/pypa/*_requirements_frozen.txt') }} - - uses: actions/cache@v2 + - uses: actions/cache@v4 with: path: | ~/.cache/cibuildwheel_bazel_cache/cache/repos diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 7d80d4937..dd0386f03 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -7,45 +7,46 @@ jobs: strategy: matrix: python-version: - - '3.9' + - "3.12" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - # Need full history to determine version number. - fetch-depth: 0 - - name: 'Set up Python ${{ matrix.python-version }}' - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: 'Configure bazel remote cache write credentials' - env: - BAZEL_CACHE_SERVICE_ACCOUNT_KEY: ${{ secrets.BAZEL_CACHE_SERVICE_ACCOUNT_KEY }} - run: python ./tools/ci/configure_bazel_remote_cache.py --bazelrc ~/ci_bazelrc docs - shell: bash - - name: Get pip cache dir - id: pip-cache - shell: bash - run: | - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - uses: actions/cache@v2 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: pip-${{ runner.os }}-docs-${{ matrix.python-version }}-${{ hashFiles('third_party/pypa/workspace.bzl') }} - - uses: actions/cache@v2 - with: - path: | - ~/.cache/bazel/_bazel_*/cache/repos - ~/.cache/bazelisk - key: bazel-docs-${{ hashFiles('.bazelversion', 'WORKSPACE', 'external.bzl', 'third_party/**') }} - - name: Build documentation - run: CC=gcc-10 python -u bazelisk.py --bazelrc ~/ci_bazelrc run --announce_rc --show_timestamps --keep_going --color=yes --verbose_failures //docs:build_docs -- --output docs_output - shell: bash - - name: Upload docs as artifact - uses: actions/upload-artifact@v4 - with: - name: docs - path: docs_output + - uses: actions/checkout@v4 + with: + # Need full history to determine version number. + fetch-depth: 0 + - name: "Set up Python ${{ matrix.python-version }}" + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: "Configure bazel remote cache write credentials" + env: + BAZEL_CACHE_SERVICE_ACCOUNT_KEY: ${{ secrets.BAZEL_CACHE_SERVICE_ACCOUNT_KEY }} + run: python ./tools/ci/configure_bazel_remote_cache.py --bazelrc ~/ci_bazelrc docs + shell: bash + - name: Get pip cache dir + id: pip-cache + shell: bash + run: | + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: pip-${{ runner.os }}-docs-${{ matrix.python-version }}-${{ hashFiles('third_party/pypa/workspace.bzl') }} + - uses: actions/cache@v4 + with: + path: | + ~/.cache/bazel/_bazel_*/cache/repos + ~/.cache/bazelisk + key: bazel-docs-${{ hashFiles('.bazelversion', 'WORKSPACE', 'external.bzl', 'third_party/**') }} + - name: Build documentation + run: CC=gcc-10 python -u bazelisk.py --bazelrc ~/ci_bazelrc run --announce_rc --show_timestamps --keep_going --color=yes --verbose_failures //docs:build_docs -- --output docs_output + shell: bash + - run: zip -r docs_output.zip docs_output + - name: Upload docs as artifact + uses: actions/upload-artifact@v4 + with: + name: docs + path: docs_output.zip publish-docs: # Only publish package on push to tag or default branch. @@ -54,12 +55,12 @@ jobs: needs: - build-docs steps: - - uses: actions/download-artifact@v4 - with: - name: docs - path: docs_output - - name: Publish to gh-pages - uses: peaceiris/actions-gh-pages@bbdfb200618d235585ad98e965f4aafc39b4c501 # v3.7.3 (2020-10-20) - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./docs_output + - uses: actions/download-artifact@v4 + with: + name: docs + - run: unzip docs_output.zip + - name: Publish to gh-pages + uses: peaceiris/actions-gh-pages@bbdfb200618d235585ad98e965f4aafc39b4c501 # v3.7.3 (2020-10-20) + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./docs_output diff --git a/docs/context_schema.yml b/docs/context_schema.yml index 4ff4a4034..4e9cc5f2d 100644 --- a/docs/context_schema.yml +++ b/docs/context_schema.yml @@ -17,35 +17,35 @@ properties: :literal:``. The resource specification must be compatible with :literal:``. examples: -- "cache_pool": - total_bytes_limit: 10000000 - "cache_pool#remote": - total_bytes_limit: 100000000 - "data_copy_concurrency": - limit: 8 + - "cache_pool": + total_bytes_limit: 10000000 + "cache_pool#remote": + total_bytes_limit: 100000000 + "data_copy_concurrency": + limit: 8 definitions: resource: $id: ContextResource description: |- Specifies a context resource of a particular :literal:``. oneOf: - - oneOf: - - type: object - - type: boolean - - type: number - description: |- - Specifies the resource directly. Any constraints on the value are - determined by the particular :literal:``. - - type: string - description: |- - References another resource of the same type in the current or parent - context using the syntax ``""`` or - ``"#"``, where :literal:`` - matches the type of this resource. - - type: 'null' - description: |- - Specifies a new instance of the default resource of the given - :literal:``. Only valid within a `Context` specification. + - oneOf: + - type: object + - type: boolean + - type: number + description: |- + Specifies the resource directly. Any constraints on the value are + determined by the particular :literal:``. + - type: string + description: |- + References another resource of the same type in the current or parent + context using the syntax ``""`` or + ``"#"``, where :literal:`` + matches the type of this resource. + - type: "null" + description: |- + Specifies a new instance of the default resource of the given + :literal:``. Only valid within a `Context` specification. cache_pool: $id: Context.cache_pool description: |- @@ -53,14 +53,33 @@ definitions: :literal:`cache_pool` resource specifies a separate memory pool. type: object properties: + disabled: + type: boolean + default: false + title: | + May be set to ``true`` to disable the cache entirely. + description: | + If set to ``true``, no other properties may be specified. Compared to + setting `.total_bytes_limit` to ``0``, multiple concurrent reads (e.g. + of the same chunk of an array) won't be coalesced. total_bytes_limit: type: integer minimum: 0 - description: |- - Soft limit on the total number of bytes in the cache. The - least-recently used data that is not in use is evicted from the cache - when this limit is reached. + title: |- + Soft limit on the total number of bytes in the cache. + description: | + The least-recently used data that is not in use is evicted from the + cache when this limit is reached. In-use data remains cached + regardless of the limit. default: 0 + queued_for_writeback_bytes_limit: + type: integer + minimum: 0 + description: |- + Soft limit on the total number of bytes of data pending writeback. + Writeback is initated on the least-recently used data that is pending + writeback when this limit is reached. Defaults to half of + `.total_bytes_limit`. data_copy_concurrency: $id: Context.data_copy_concurrency description: |- @@ -70,9 +89,9 @@ definitions: properties: limit: oneOf: - - type: integer - minimum: 1 - - const: "shared" + - type: integer + minimum: 1 + - const: "shared" description: |- The maximum number of CPU cores that may be used. If the special value of ``"shared"`` is specified, a shared global limit equal to the diff --git a/tensorstore/driver/driver_testutil.cc b/tensorstore/driver/driver_testutil.cc index b0a890f14..c2ba99890 100644 --- a/tensorstore/driver/driver_testutil.cc +++ b/tensorstore/driver/driver_testutil.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -582,9 +583,13 @@ absl::Status TestDriverWriteReadChunks( absl::BitGenRef gen, const TestDriverWriteReadChunksOptions& options) { Context context(options.context_spec); const auto is_write = options.total_write_bytes != 0; - tensorstore::OpenMode open_mode = is_write - ? tensorstore::OpenMode::open_or_create - : tensorstore::OpenMode::open; + tensorstore::OpenMode open_mode = + is_write + ? (options.delete_existing ? (tensorstore::OpenMode::create | + tensorstore::OpenMode::delete_existing) + : (tensorstore::OpenMode::open | + tensorstore::OpenMode::create)) + : tensorstore::OpenMode::open; tensorstore::ReadWriteMode read_write_mode = is_write ? tensorstore::ReadWriteMode::read_write @@ -629,11 +634,20 @@ absl::Status TestDriverWriteReadChunks( ABSL_LOG(INFO) << "read/write shape " << span(chunk_shape, ts.rank()); ABSL_LOG(INFO) << "Starting writes: " << options.repeat_writes << ", total_write_bytes=" << options.total_write_bytes; + + auto result_callback = options.result_callback; + if (!result_callback) { + result_callback = + [](const TestDriverWriteReadChunksOptions::Results& results) { + ABSL_LOG(INFO) << results.FormatSummary(); + return absl::OkStatus(); + }; + } for (int64_t i = 0; i < options.repeat_writes; i++) { TENSORSTORE_RETURN_IF_ERROR( TestDriverReadOrWriteChunks(gen, ts, span(chunk_shape, ts.rank()), options.total_write_bytes, options.strategy, - /*read=*/false)); + /*read=*/false, result_callback)); } ABSL_LOG(INFO) << "Starting reads: " << options.repeat_reads @@ -642,7 +656,7 @@ absl::Status TestDriverWriteReadChunks( TENSORSTORE_RETURN_IF_ERROR( TestDriverReadOrWriteChunks(gen, ts, span(chunk_shape, ts.rank()), options.total_read_bytes, options.strategy, - /*read=*/true)); + /*read=*/true, result_callback)); } return absl::OkStatus(); } @@ -696,10 +710,22 @@ void ForEachChunk(BoxView<> domain, DataType dtype, absl::BitGenRef gen, } // namespace +std::string TestDriverWriteReadChunksOptions::Results::FormatSummary() const { + auto elapsed_s = absl::FDivDuration(elapsed_time, absl::Seconds(1)); + double bytes_mb = static_cast(total_bytes) / 1e6; + + return absl::StrFormat( + "%s summary: %d bytes in %.0f ms: %.3f MB/second (%d chunks of %d " + "bytes)", + (read ? "Read" : "Write"), total_bytes, elapsed_s * 1e3, + bytes_mb / elapsed_s, num_chunks, chunk_bytes); +} + absl::Status TestDriverReadOrWriteChunks( absl::BitGenRef gen, tensorstore::TensorStore<> ts, span chunk_shape, int64_t total_bytes, - TestDriverWriteReadChunksOptions::Strategy strategy, bool read) { + TestDriverWriteReadChunksOptions::Strategy strategy, bool read, + const TestDriverWriteReadChunksOptions::ResultCallback& result_callback) { if (total_bytes == 0) return absl::OkStatus(); if (total_bytes < 0) { @@ -742,18 +768,15 @@ absl::Status TestDriverReadOrWriteChunks( op.future.Wait(); TENSORSTORE_RETURN_IF_ERROR(op.future.result()); - auto elapsed_s = - absl::FDivDuration(absl::Now() - start_time, absl::Seconds(1)); - double bytes_mb = static_cast(bytes_completed.load()) / 1e6; - - ABSL_LOG(INFO) - << (read ? "Read" : "Write") << " summary: " - << absl::StrFormat( - "%d bytes in %.0f ms: %.3f MB/second (%d chunks of %d bytes)", - bytes_completed.load(), elapsed_s * 1e3, bytes_mb / elapsed_s, - chunks_completed.load(), chunk_bytes); + TestDriverWriteReadChunksOptions::Results results; + results.chunk_shape = chunk_shape; + results.total_bytes = bytes_completed.load(); + results.chunk_bytes = chunk_bytes; + results.num_chunks = chunks_completed.load(); + results.elapsed_time = absl::Now() - start_time; + results.read = read; - return absl::OkStatus(); + return result_callback(results); } void RegisterTensorStoreDriverBasicFunctionalityTest( diff --git a/tensorstore/driver/driver_testutil.h b/tensorstore/driver/driver_testutil.h index cb6540258..2d6ffb3c6 100644 --- a/tensorstore/driver/driver_testutil.h +++ b/tensorstore/driver/driver_testutil.h @@ -278,6 +278,25 @@ struct TestDriverWriteReadChunksOptions { // Number of times to repeat the writes. int64_t repeat_writes = 1; + + // Delete existing data before writing. + bool delete_existing = true; + + struct Results { + span chunk_shape; + int64_t total_bytes; + int64_t chunk_bytes; + int64_t num_chunks; + absl::Duration elapsed_time; + bool read; + + std::string FormatSummary() const; + }; + + using ResultCallback = std::function; + + // Callback to invoke instead of logging results. + ResultCallback result_callback; }; // Tests concurrently reading and/or writing multiple chunks. @@ -296,7 +315,8 @@ absl::Status TestDriverWriteReadChunks( absl::Status TestDriverReadOrWriteChunks( absl::BitGenRef gen, tensorstore::TensorStore<> ts, span chunk_shape, int64_t total_bytes, - TestDriverWriteReadChunksOptions::Strategy strategy, bool read); + TestDriverWriteReadChunksOptions::Strategy strategy, bool read, + const TestDriverWriteReadChunksOptions::ResultCallback& result_callback); void TestTensorStoreCreateWithSchemaImpl(::nlohmann::json json_spec, const Schema& schema); diff --git a/tensorstore/driver/zarr3/BUILD b/tensorstore/driver/zarr3/BUILD index 20f8e66af..89c33c020 100644 --- a/tensorstore/driver/zarr3/BUILD +++ b/tensorstore/driver/zarr3/BUILD @@ -1,3 +1,4 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("//bazel:tensorstore.bzl", "tensorstore_cc_binary", "tensorstore_cc_library", "tensorstore_cc_test") load("//docs:doctest.bzl", "doctest_test") @@ -190,10 +191,27 @@ tensorstore_cc_test( ], ) +bool_flag( + name = "disable_automatic_shard_batching", + build_setting_default = False, +) + +config_setting( + name = "disable_automatic_shard_batching_setting", + flag_values = { + ":disable_automatic_shard_batching": "True", + }, + visibility = ["//visibility:private"], +) + tensorstore_cc_library( name = "chunk_cache", srcs = ["chunk_cache.cc"], hdrs = ["chunk_cache.h"], + local_defines = select({ + ":disable_automatic_shard_batching_setting": ["TENSORSTORE_ZARR3_AUTOMATIC_SHARD_BATCHING=0"], + "//conditions:default": [], + }), deps = [ "//tensorstore:array", "//tensorstore:array_storage_statistics", diff --git a/tensorstore/driver/zarr3/chunk_cache.cc b/tensorstore/driver/zarr3/chunk_cache.cc index e3f00e24b..b98921ac4 100644 --- a/tensorstore/driver/zarr3/chunk_cache.cc +++ b/tensorstore/driver/zarr3/chunk_cache.cc @@ -63,6 +63,10 @@ #include "tensorstore/util/result.h" #include "tensorstore/util/span.h" +#ifndef TENSORSTORE_ZARR3_AUTOMATIC_SHARD_BATCHING +#define TENSORSTORE_ZARR3_AUTOMATIC_SHARD_BATCHING 1 +#endif + namespace tensorstore { namespace internal_zarr3 { @@ -317,9 +321,11 @@ void ZarrShardedChunkCache::Read( batch = std::move(request.batch), staleness_bound = request.staleness_bound](auto entry) { Batch shard_batch = batch; +#if TENSORSTORE_ZARR3_AUTOMATIC_SHARD_BATCHING if (!shard_batch) { shard_batch = Batch::New(); } +#endif return [=, shard_batch = std::move(shard_batch), entry = std::move(entry)]( span decoded_shape, IndexTransform<> transform, diff --git a/tensorstore/driver/zarr3/driver_test.cc b/tensorstore/driver/zarr3/driver_test.cc index 4ae374777..aeb4ed57e 100644 --- a/tensorstore/driver/zarr3/driver_test.cc +++ b/tensorstore/driver/zarr3/driver_test.cc @@ -340,6 +340,35 @@ TENSORSTORE_GLOBAL_INITIALIZER { std::move(options)); } +TENSORSTORE_GLOBAL_INITIALIZER { + tensorstore::internal::TensorStoreDriverBasicFunctionalityTestOptions options; + options.test_name = "zarr3/cache_disabled"; + options.create_spec = { + {"driver", "zarr3"}, + {"kvstore", {{"driver", "memory"}}}, + {"path", "prefix/"}, + {"cache_pool", {{"disabled", true}}}, + {"metadata", + { + {"data_type", "uint16"}, + {"shape", {10, 11}}, + {"chunk_grid", + {{"name", "regular"}, + {"configuration", {{"chunk_shape", {4, 5}}}}}}, + }}, + }; + options.expected_domain = tensorstore::IndexDomainBuilder(2) + .shape({10, 11}) + .implicit_upper_bounds({1, 1}) + .Finalize() + .value(); + options.initial_value = tensorstore::AllocateArray( + tensorstore::BoxView({10, 11}), tensorstore::c_order, + tensorstore::value_init); + tensorstore::internal::RegisterTensorStoreDriverBasicFunctionalityTest( + std::move(options)); +} + TENSORSTORE_GLOBAL_INITIALIZER { tensorstore::internal::TestTensorStoreDriverResizeOptions options; options.test_name = "zarr3/metadata"; diff --git a/tensorstore/driver/zarr3/sharding_benchmark_test.cc b/tensorstore/driver/zarr3/sharding_benchmark_test.cc index b761e20fd..aa8de2d92 100644 --- a/tensorstore/driver/zarr3/sharding_benchmark_test.cc +++ b/tensorstore/driver/zarr3/sharding_benchmark_test.cc @@ -21,11 +21,8 @@ // // cache_pool_size: // -// Size of the context "cache_pool" "total_bytes_limit" -// -// total_size: -// -// Indicates a volume of shape `total_size^3`. +// Size of the context "cache_pool" "total_bytes_limit", or `-1` to disable +// the cache completely. // // write_chunk_size: // @@ -85,8 +82,11 @@ struct BenchmarkHelper { {"driver", "zarr3"}, {"kvstore", "memory://"}, }; + if (cache_pool_size > 0) { json_spec["cache_pool"] = {{"total_bytes_limit", cache_pool_size}}; + } else if (cache_pool_size == -1) { + json_spec["cache_pool"] = {{"disabled", true}}; } TENSORSTORE_CHECK_OK_AND_ASSIGN(spec, Spec::FromJson(json_spec)); diff --git a/tensorstore/index_space/index_domain.h b/tensorstore/index_space/index_domain.h index a6e7d6840..4d1b096ae 100644 --- a/tensorstore/index_space/index_domain.h +++ b/tensorstore/index_space/index_domain.h @@ -525,9 +525,7 @@ explicit IndexDomain(const BoxType& box) explicit IndexDomain() -> IndexDomain<>; -/// Specializes the HasBoxDomain metafunction for IndexTransform. -/// -/// \relates IndexDomain +// Specializes the HasBoxDomain metafunction for IndexTransform. template constexpr inline bool HasBoxDomain> = true; diff --git a/tensorstore/index_space/index_transform.h b/tensorstore/index_space/index_transform.h index 46d16521e..cc6fb39a8 100644 --- a/tensorstore/index_space/index_transform.h +++ b/tensorstore/index_space/index_transform.h @@ -572,9 +572,9 @@ class IndexTransform { Ptr rep_{}; }; -/// Specializes the HasBoxDomain metafunction for `IndexTransform`. -/// -/// \relates IndexTransform +// Specializes the HasBoxDomain metafunction for `IndexTransform`. +// +// \relates IndexTransform template constexpr inline bool diff --git a/tensorstore/internal/BUILD b/tensorstore/internal/BUILD index 46b6a9908..515452118 100644 --- a/tensorstore/internal/BUILD +++ b/tensorstore/internal/BUILD @@ -23,6 +23,19 @@ config_setting( }, ) +bool_flag( + name = "nditerable_disable_2d_block", + build_setting_default = False, +) + +config_setting( + name = "nditerable_disable_2d_block_setting", + flag_values = { + ":nditerable_disable_2d_block": "True", + }, + visibility = ["//visibility:private"], +) + bool_flag( name = "grid_storage_statistics_debug", build_setting_default = False, @@ -1471,7 +1484,10 @@ tensorstore_cc_library( name = "nditerable_util", srcs = ["nditerable_util.cc"], hdrs = ["nditerable_util.h"], - local_defines = NDITERABLE_TEST_UNIT_BLOCK_SIZE_DEFINES, + local_defines = NDITERABLE_TEST_UNIT_BLOCK_SIZE_DEFINES + select({ + "nditerable_disable_2d_block_setting": ["TENSORSTORE_NDITERABLE_2D_BLOCK=0"], + "//conditions:default": [], + }), deps = [ ":arena", ":elementwise_function", diff --git a/tensorstore/internal/benchmark/ts_benchmark.cc b/tensorstore/internal/benchmark/ts_benchmark.cc index d05991c8e..5725c66b6 100644 --- a/tensorstore/internal/benchmark/ts_benchmark.cc +++ b/tensorstore/internal/benchmark/ts_benchmark.cc @@ -160,6 +160,8 @@ ABSL_FLAG(tensorstore::JsonAbslFlag, context_spec, ABSL_FLAG(std::string, strategy, "random", "Specifies the strategy to use: 'sequential' or 'random'."); +ABSL_FLAG(std::string, benchmark_id, "", "Identifier for the benchmark."); + ABSL_FLAG(tensorstore::VectorFlag, chunk_shape, {}, "Read/write chunks of --chunk_shape dimensions."); @@ -180,6 +182,8 @@ ABSL_FLAG(int64_t, repeat_reads, 1, ABSL_FLAG(int64_t, repeat_writes, 0, "Number of times to repeat write benchmark."); +ABSL_FLAG(bool, dump_metrics, true, "Print metrics to stdout."); + namespace tensorstore { namespace { @@ -213,6 +217,13 @@ void DoTsBenchmark() { options.repeat_writes = absl::GetFlag(FLAGS_repeat_writes); options.total_write_bytes = absl::GetFlag(FLAGS_total_write_bytes); options.total_read_bytes = absl::GetFlag(FLAGS_total_read_bytes); + options.result_callback = [&](const auto& results) { + auto summary_line = results.FormatSummary(); + auto benchmark_id = absl::GetFlag(FLAGS_benchmark_id); + std::cout << "[" << benchmark_id << "] " << summary_line << std::endl; + ABSL_LOG(INFO) << "[" << benchmark_id << "] " << summary_line; + return absl::OkStatus(); + }; if (options.total_write_bytes == 0 && options.total_read_bytes == 0) { ABSL_LOG(FATAL) @@ -223,7 +234,9 @@ void DoTsBenchmark() { absl::InsecureBitGen gen; TENSORSTORE_CHECK_OK(TestDriverWriteReadChunks(gen, options)); - internal::DumpMetrics(""); + if (absl::GetFlag(FLAGS_dump_metrics)) { + internal::DumpMetrics(""); + } } } // namespace diff --git a/tensorstore/internal/cache/BUILD b/tensorstore/internal/cache/BUILD index 690450ef0..36c7fdffc 100644 --- a/tensorstore/internal/cache/BUILD +++ b/tensorstore/internal/cache/BUILD @@ -275,6 +275,7 @@ tensorstore_cc_library( ":cache", "//tensorstore:context", "//tensorstore/internal:intrusive_ptr", + "//tensorstore/internal/cache_key", "//tensorstore/internal/json_binding", "//tensorstore/internal/json_binding:bindable", "//tensorstore/util:result", diff --git a/tensorstore/internal/cache/cache_pool_resource.cc b/tensorstore/internal/cache/cache_pool_resource.cc index e3785f257..ea617e7de 100644 --- a/tensorstore/internal/cache/cache_pool_resource.cc +++ b/tensorstore/internal/cache/cache_pool_resource.cc @@ -14,10 +14,14 @@ #include "tensorstore/internal/cache/cache_pool_resource.h" +#include +#include + #include #include "tensorstore/context.h" #include "tensorstore/context_resource_provider.h" #include "tensorstore/internal/cache/cache.h" +#include "tensorstore/internal/cache_key/std_optional.h" #include "tensorstore/internal/intrusive_ptr.h" #include "tensorstore/internal/json_binding/bindable.h" #include "tensorstore/internal/json_binding/json_binding.h" @@ -29,29 +33,53 @@ namespace { struct CachePoolResourceTraits : public ContextResourceTraits { - using Spec = CachePool::Limits; + // Specifies cache pool limits, or `nullopt` to indicate the cache is disabled + // completely. + using Spec = std::optional; using Resource = typename CachePoolResource::Resource; - static constexpr Spec Default() { return {}; } + static constexpr Spec Default() { return Spec{std::in_place}; } static constexpr auto JsonBinder() { namespace jb = tensorstore::internal_json_binding; return jb::Object( - jb::Member("total_bytes_limit", - jb::Projection(&Spec::total_bytes_limit, - jb::DefaultValue([](auto* v) { *v = 0; })))); + jb::Member( + "disabled", + jb::GetterSetter( + [](auto& obj) { return !obj.has_value(); }, + [](auto& obj, bool disabled) { + if (disabled) { + obj = std::nullopt; + } else { + obj.emplace(); + } + }, + jb::DefaultInitializedValue())), + [](auto is_loading, const auto& options, auto* obj, auto* j) { + if (!*obj) return absl::OkStatus(); + return jb::Member( + "total_bytes_limit", + jb::Projection(&CachePool::Limits::total_bytes_limit, + jb::DefaultValue([](auto* v) { *v = 0; })))( + is_loading, options, obj, j); + }); } static Result Create(const Spec& limits, ContextResourceCreationContext context) { - return CachePool::WeakPtr(CachePool::Make(limits)); + if (!limits) return CachePool::WeakPtr(); + return CachePool::WeakPtr(CachePool::Make(*limits)); } static Spec GetSpec(const Resource& pool, const ContextSpecBuilder& builder) { - return pool->limits(); + return pool ? Spec(pool->limits()) : Spec(std::nullopt); } static void AcquireContextReference(const Resource& p) { - internal_cache::StrongPtrTraitsCachePool::increment(p.get()); + if (p) { + internal_cache::StrongPtrTraitsCachePool::increment(p.get()); + } } static void ReleaseContextReference(const Resource& p) { - internal_cache::StrongPtrTraitsCachePool::decrement(p.get()); + if (p) { + internal_cache::StrongPtrTraitsCachePool::decrement(p.get()); + } } }; diff --git a/tensorstore/internal/cache/cache_pool_resource_test.cc b/tensorstore/internal/cache/cache_pool_resource_test.cc index 9e628ed91..a80222ec2 100644 --- a/tensorstore/internal/cache/cache_pool_resource_test.cc +++ b/tensorstore/internal/cache/cache_pool_resource_test.cc @@ -36,6 +36,13 @@ TEST(CachePoolResourceTest, Default) { EXPECT_EQ(0u, (*cache)->limits().total_bytes_limit); } +TEST(CachePoolResourceTest, Disable) { + TENSORSTORE_ASSERT_OK_AND_ASSIGN( + auto cache, + Context::Default().GetResource({{"disabled", true}})); + EXPECT_FALSE(cache->get()); +} + TEST(CachePoolResourceTest, EmptyObject) { TENSORSTORE_ASSERT_OK_AND_ASSIGN( auto resource_spec, Context::Resource::FromJson( diff --git a/tensorstore/internal/json_binding/BUILD b/tensorstore/internal/json_binding/BUILD index 5082a59d8..45b92b5ee 100644 --- a/tensorstore/internal/json_binding/BUILD +++ b/tensorstore/internal/json_binding/BUILD @@ -106,7 +106,10 @@ tensorstore_cc_test( tensorstore_cc_library( name = "bindable", - hdrs = ["bindable.h"], + hdrs = [ + "bindable.h", + "json_binding_fwd.h", + ], deps = [ "//tensorstore:json_serialization_options_base", "//tensorstore/util:result", diff --git a/tensorstore/internal/json_binding/json_binding.h b/tensorstore/internal/json_binding/json_binding.h index 8b2bab5d6..a2bda257a 100644 --- a/tensorstore/internal/json_binding/json_binding.h +++ b/tensorstore/internal/json_binding/json_binding.h @@ -530,6 +530,21 @@ constexpr auto GetterSetter(Get get, Set set, Binder binder = DefaultBinder<>) { }; } +template )> +constexpr auto Setter(Set set, Binder binder = DefaultBinder<>) { + return [set = std::move(set), binder = std::move(binder)]( + auto is_loading, const auto& options, auto* obj, + auto* j) -> absl::Status { + if constexpr (is_loading) { + T projected; + TENSORSTORE_RETURN_IF_ERROR(binder(is_loading, options, &projected, j)); + return internal::InvokeForStatus(set, *obj, std::move(projected)); + } else { + return absl::OkStatus(); + } + }; +} + // Binder parameterized by distinct load and save objects. // Invokes LoadBinder when loading and SaveBinder when saving. template + +#include +#include "tensorstore/json_serialization_options_base.h" + +namespace tensorstore { +namespace internal_json_binding { + +/// Helper type used by `TENSORSTORE_DECLARE_JSON_BINDER_METHODS` for "parsing" +/// the macro varargs and handling default arguments. +template +struct JsonBindingMethodTypeHelper { + using Value = T; + using FromJsonOptions = FromJsonOptionsType; + using ToJsonOptions = ToJsonOptionsType; + using JsonValue = JsonValueType; +}; + +/// Declares that an arbitrary class may be used as a JSON binder type. +/// +/// This allows the members of the class to be used as parameters of the binder, +/// without requiring the JSON binder to be defined inline. +/// +/// The macro should be invoked with up to 4 types as arguments: +/// +/// - ValueType (required) +/// - FromJsonOptionsType = NoOptions +/// - ToJsonOptionsType = IncludeDefaults +/// - JsonValueType = ::nlohmann::json +/// +/// This macro handles the arguments specially such that commas within <> +/// brackets are not a problem. +/// +/// Example: +/// +/// struct Foo { +/// int value; +/// }; +/// +/// struct FooBinder { +/// int default_value = 10; +/// +/// TENSORSTORE_DECLARE_JSON_BINDER_METHODS(Foo) +/// }; +/// +/// namespace jb = tensorstore::internal_json_binding; +/// TENSORSTORE_DEFINE_JSON_BINDER_METHODS( +/// FooBinder, +/// jb::Object( +/// jb::Member("value", jb::Projection<&Foo::value>( +/// jb::DefaultValue([&](int *v) { *v = default_value; }))))) +/// +#define TENSORSTORE_DECLARE_JSON_BINDER_METHODS(...) \ + using BindingHelperType = \ + ::tensorstore::internal_json_binding::JsonBindingMethodTypeHelper< \ + __VA_ARGS__>; \ + using Value = typename BindingHelperType::Value; \ + using JsonValue = typename BindingHelperType::JsonValue; \ + using JsonBinderFromJsonOptions = \ + typename BindingHelperType::FromJsonOptions; \ + using JsonBinderToJsonOptions = typename BindingHelperType::ToJsonOptions; \ + absl::Status operator()(std::true_type is_loading, \ + const JsonBinderFromJsonOptions& options, \ + Value* value, JsonValue* j) const { \ + return this->DoLoadSaveJson(is_loading, options, value, j); \ + } \ + absl::Status operator()(std::false_type is_loading, \ + const JsonBinderToJsonOptions& options, \ + const Value* value, JsonValue* j) const { \ + return this->DoLoadSaveJson(is_loading, options, value, j); \ + } \ + template \ + absl::Status DoLoadSaveJson( \ + std::integral_constant is_loading, \ + const std::conditional_t& options, \ + std::conditional_t* value, JsonValue* j) \ + const; \ + /**/ + +/// Defines the JSON binder methods declared by +/// `TENSORSTORE_DECLARE_JSON_BINDER_METHODS`. See the documentation of that +/// macro for details. +#define TENSORSTORE_DEFINE_JSON_BINDER_METHODS(NAME, ...) \ + template \ + absl::Status NAME::DoLoadSaveJson( \ + std::integral_constant json_binder_is_loading, \ + const std::conditional_t& json_binder_options, \ + std::conditional_t* json_binder_value, \ + JsonValue* json_binder_j) const { \ + return (__VA_ARGS__)(json_binder_is_loading, json_binder_options, \ + json_binder_value, json_binder_j); \ + } \ + template absl::Status NAME::DoLoadSaveJson( \ + std::true_type, const JsonBinderFromJsonOptions&, Value*, JsonValue*) \ + const; \ + template absl::Status NAME::DoLoadSaveJson( \ + std::false_type, const JsonBinderToJsonOptions&, const Value*, \ + JsonValue*) const; \ + /**/ + +} // namespace internal_json_binding +} // namespace tensorstore + +#endif // TENSORSTORE_INTERNAL_JSON_BINDING_JSON_BINDING_FWD_H_ diff --git a/tensorstore/internal/nditerable_util.cc b/tensorstore/internal/nditerable_util.cc index eed86c05a..8af01ab4a 100644 --- a/tensorstore/internal/nditerable_util.cc +++ b/tensorstore/internal/nditerable_util.cc @@ -30,6 +30,10 @@ #include "tensorstore/util/iterate.h" #include "tensorstore/util/span.h" +#ifndef TENSORSTORE_NDITERABLE_2D_BLOCK +#define TENSORSTORE_NDITERABLE_2D_BLOCK 1 +#endif + namespace tensorstore { namespace internal { @@ -184,10 +188,12 @@ IterationBufferShape GetNDIterationBlockShape( const Index block_inner_size = std::max(Index(1), std::min(last_dimension_size, target_size)); Index block_outer_size = 1; +#if TENSORSTORE_NDITERABLE_2D_BLOCK if (block_inner_size < target_size) { block_outer_size = std::min(penultimate_dimension_size, target_size / block_inner_size); } +#endif return {block_outer_size, block_inner_size}; } #endif diff --git a/tensorstore/kvstore/ocdbt/io/manifest_cache.cc b/tensorstore/kvstore/ocdbt/io/manifest_cache.cc index c5752f2ee..05f94435a 100644 --- a/tensorstore/kvstore/ocdbt/io/manifest_cache.cc +++ b/tensorstore/kvstore/ocdbt/io/manifest_cache.cc @@ -171,6 +171,7 @@ void DoReadImpl(EntryOrNode* entry_or_node, internal::AsyncCache::AsyncCacheReadRequest request) { kvstore::ReadOptions kvstore_options; kvstore_options.staleness_bound = request.staleness_bound; + kvstore_options.batch = request.batch; auto read_state = internal::AsyncCache::ReadLock(*entry_or_node).read_state(); kvstore_options.generation_conditions.if_not_equal = diff --git a/tensorstore/serialization/BUILD b/tensorstore/serialization/BUILD index 2ca021c92..7dc5ff7ea 100644 --- a/tensorstore/serialization/BUILD +++ b/tensorstore/serialization/BUILD @@ -157,8 +157,8 @@ tensorstore_cc_library( deps = [ ":serialization", ":status", - "//tensorstore/internal:attributes", "//tensorstore/util:result", + "@com_google_absl//absl/base:core_headers", ], ) diff --git a/tensorstore/serialization/result.h b/tensorstore/serialization/result.h index 7a30787de..0e00d18b5 100644 --- a/tensorstore/serialization/result.h +++ b/tensorstore/serialization/result.h @@ -15,6 +15,7 @@ #ifndef TENSORSTORE_SERIALIZATION_RESULT_H_ #define TENSORSTORE_SERIALIZATION_RESULT_H_ +#include "absl/base/attributes.h" #include "tensorstore/serialization/serialization.h" #include "tensorstore/serialization/status.h" #include "tensorstore/util/result.h" @@ -22,19 +23,19 @@ namespace tensorstore { namespace serialization { -template -struct Serializer> { - [[nodiscard]] static bool Encode(EncodeSink& sink, const Result& value) { +template > +struct ResultSerializer { + [[nodiscard]] bool Encode(EncodeSink& sink, const Result& value) const { return serialization::Encode(sink, value.ok()) && - (value.ok() ? serialization::Encode(sink, *value) + (value.ok() ? value_serializer.Encode(sink, *value) : serialization::Encode(sink, value.status())); } - [[nodiscard]] static bool Decode(DecodeSource& source, Result& value) { + [[nodiscard]] bool Decode(DecodeSource& source, Result& value) const { bool has_value; if (!serialization::Decode(source, has_value)) return false; if (has_value) { - return serialization::Decode(source, value.emplace()); + return value_serializer.Decode(source, value.emplace()); } else { absl::Status status; if (!ErrorStatusSerializer::Decode(source, status)) return false; @@ -42,8 +43,13 @@ struct Serializer> { return true; } } + + ABSL_ATTRIBUTE_NO_UNIQUE_ADDRESS ValueSerializer value_serializer; }; +template +struct Serializer> : public ResultSerializer {}; + } // namespace serialization } // namespace tensorstore diff --git a/tensorstore/transaction.cc b/tensorstore/transaction.cc index 41d138fea..7094abdfa 100644 --- a/tensorstore/transaction.cc +++ b/tensorstore/transaction.cc @@ -104,7 +104,9 @@ TransactionState::OpenPtr TransactionState::AcquireImplicitOpenPtr() { // Future reference was already released. Try to obtain another future // reference. future = promise_.future(); - if (future.null()) return {}; + if (future.null()) { + return {}; + } } if (!future.null()) { future_ = std::move(future); @@ -118,9 +120,14 @@ TransactionState::TransactionState(TransactionMode mode, commit_reference_count_{kFutureReferenceIncrement + kCommitReferenceIncrement}, open_reference_count_{implicit_transaction ? 1u : 0u}, - // Two weak references initially, one owned by the promise callback - // attached to `future_`, and one for the initial `Transaction` object. - weak_reference_count_{2}, + // Initial weak references: + // + // - one owned by the promise force callback attached to `future_`, + // + // - one owned by the promise not needed callback, + // + // - and one for the initial `Transaction` object. + weak_reference_count_{3}, total_bytes_{0}, commit_state_{kOpen}, implicit_transaction_(implicit_transaction) { @@ -130,12 +137,14 @@ TransactionState::TransactionState(TransactionMode mode, phase_ = 0; } auto [promise, future] = PromiseFuturePair::Make(MakeResult()); - promise_callback_ = promise.ExecuteWhenForced( + promise_force_callback_ = promise.ExecuteWhenForced( [self = IntrusivePtr>( this, adopt_object_ref)](Promise promise) { self->RequestCommit(); }); + promise_not_needed_callback_ = promise.ExecuteWhenNotNeeded( + [self = WeakPtr(this, adopt_object_ref)] { self->RequestAbort(); }); promise_ = std::move(promise); future_ = std::move(future); } @@ -182,25 +191,28 @@ void TransactionState::RequestAbort(const absl::Status& error, UniqueWriterLock lock) { auto commit_state = commit_state_; if (commit_state > kOpenAndCommitRequested) return; - SetDeferredResult(promise_, error); if (open_reference_count_.load(std::memory_order_relaxed) != 0) { // A thread is not permitted to increase `open_reference_count_` from 0 // except while owning a lock on `mutex_`. If another thread concurrently // decreases `open_reference_count_` to 0, `NoMoreOpenReferences` will take // care of calling `ExecuteAbort()`. commit_state_ = kAbortRequested; + lock.unlock(); + SetDeferredResult(promise_, error); return; } else { commit_state_ = kAborted; } // Ensure `ExecuteAbort` is run with the lock released. lock.unlock(); + SetDeferredResult(promise_, error); ExecuteAbort(); } void TransactionState::ExecuteAbort() { // Release the promise callback to break the reference cycle. - promise_callback_.Unregister(); + promise_force_callback_.Unregister(); + promise_not_needed_callback_.Unregister(); if (nodes_.empty()) { // Nothing to abort, just release `promise_` so that it becomes ready with // the error set previously by `SetDeferredResult`. @@ -246,7 +258,8 @@ void TransactionState::DecrementNodesPendingAbort(size_t count) { void TransactionState::ExecuteCommit() { assert(commit_state_ == kCommitStarted); // Release the promise callback to break the reference cycle. - promise_callback_.Unregister(); + promise_force_callback_.Unregister(); + promise_not_needed_callback_.Unregister(); ExecuteCommitPhase(); } diff --git a/tensorstore/transaction_impl.h b/tensorstore/transaction_impl.h index f8f77f667..89b51125f 100644 --- a/tensorstore/transaction_impl.h +++ b/tensorstore/transaction_impl.h @@ -823,7 +823,11 @@ class TransactionState { /// transaction. The callback holds a commit reference. This is unregistered /// when all other commit references have been released, in order to break the /// reference cycle. - FutureCallbackRegistration promise_callback_; + FutureCallbackRegistration promise_force_callback_; + + /// Registration of "not needed" callback on `promise_` that aborts the + /// transaction. + FutureCallbackRegistration promise_not_needed_callback_; /// Retained until all write handles are released. Promise promise_; diff --git a/tensorstore/transaction_test.cc b/tensorstore/transaction_test.cc index 393b5c082..b80e661fc 100644 --- a/tensorstore/transaction_test.cc +++ b/tensorstore/transaction_test.cc @@ -744,4 +744,31 @@ TEST(TransactionTest, ReleaseTransactionReferenceDuringAbort) { MatchesStatus(absl::StatusCode::kUnknown, "failed")); } +// Tests that releasing all of the `Future` references after requesting a +// commit, but before the commit actually starts, results in the transaction +// being aborted. +TEST(TransactionTest, ReleaseFutureReferencesAfterRequestCommit) { + NodeLog log; + auto txn = Transaction(tensorstore::isolated); + TransactionState::WeakPtr weak_txn(TransactionState::get(txn)); + WeakTransactionNodePtr weak_node(new TestNode(&log, 1)); + weak_txn->AcquireCommitBlock(); + { + TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto open_ptr, + AcquireOpenTransactionPtrOrError(txn)); + { + weak_node->SetTransaction(*open_ptr); + TENSORSTORE_EXPECT_OK(weak_node->Register()); + } + // Request commit. + txn.CommitAsync().IgnoreFuture(); + txn = no_transaction; + } + weak_txn->ReleaseCommitBlock(); + + EXPECT_THAT(log, ::testing::ElementsAre("abort:1")); + EXPECT_TRUE(weak_txn->aborted()); + weak_node->AbortDone(); +} + } // namespace diff --git a/tensorstore/util/BUILD b/tensorstore/util/BUILD index d628ee66c..6d61ff6c5 100644 --- a/tensorstore/util/BUILD +++ b/tensorstore/util/BUILD @@ -582,6 +582,23 @@ tensorstore_cc_test( ], ) +tensorstore_cc_library( + name = "split_box", + srcs = ["split_box.cc"], + hdrs = ["split_box.h"], + deps = ["//tensorstore:box"], +) + +tensorstore_cc_test( + name = "split_box_test", + size = "small", + srcs = ["split_box_test.cc"], + deps = [ + ":split_box", + "@com_google_googletest//:gtest_main", + ], +) + tensorstore_cc_library( name = "status", srcs = [ diff --git a/tensorstore/util/constant_vector.cc b/tensorstore/util/constant_vector.cc index b3ecfc9f8..73bb8f798 100644 --- a/tensorstore/util/constant_vector.cc +++ b/tensorstore/util/constant_vector.cc @@ -23,5 +23,10 @@ namespace internal_constant_vector { const std::string kStringArray[kMaxRank] = {}; +const DimensionIndex kIdentityPermutation[kMaxRank] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, +}; + } // namespace internal_constant_vector } // namespace tensorstore diff --git a/tensorstore/util/constant_vector.h b/tensorstore/util/constant_vector.h index e25159bea..e8f81d262 100644 --- a/tensorstore/util/constant_vector.h +++ b/tensorstore/util/constant_vector.h @@ -44,6 +44,8 @@ constexpr inline std::array kConstantArray = extern const std::string kStringArray[kMaxRank]; +extern const DimensionIndex kIdentityPermutation[kMaxRank]; + } // namespace internal_constant_vector /// Returns a `tensorstore::span` of length `length` filled with @@ -117,6 +119,16 @@ GetDefaultStringVector(std::integral_constant = {}) { return {internal_constant_vector::kStringArray, Length}; } +inline constexpr span GetIdentityPermutation( + DimensionIndex rank) { + assert(IsValidRank(rank)); + return {internal_constant_vector::kIdentityPermutation, rank}; +} + +inline constexpr const DimensionIndex* GetIdentityPermutation() { + return internal_constant_vector::kIdentityPermutation; +} + } // namespace tensorstore #endif // TENSORSTORE_UTIL_CONSTANT_VECTOR_H_ diff --git a/tensorstore/util/split_box.cc b/tensorstore/util/split_box.cc new file mode 100644 index 000000000..744581b87 --- /dev/null +++ b/tensorstore/util/split_box.cc @@ -0,0 +1,62 @@ +// Copyright 2020 The TensorStore Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorstore/util/split_box.h" + +#include + +#include "tensorstore/box.h" + +namespace tensorstore { + +bool SplitBoxByGrid(BoxView<> input, BoxView<> grid_cell_template, + std::array, 2> split_output) { + const DimensionIndex rank = input.rank(); + assert(rank == grid_cell_template.rank()); + assert(rank == split_output[0].rank()); + assert(rank == split_output[1].rank()); + + DimensionIndex split_dim = -1; + Index split_dim_min_cell = 0; + Index split_dim_max_cell = 1; + for (DimensionIndex dim = 0; dim < rank; ++dim) { + const IndexInterval input_interval = input[dim]; + const IndexInterval cell = grid_cell_template[dim]; + assert(tensorstore::IsFinite(input_interval) || + tensorstore::Contains(cell, input_interval)); + assert(!cell.empty()); + const Index min_cell = FloorOfRatio( + input_interval.inclusive_min() - cell.inclusive_min(), cell.size()); + const Index max_cell = CeilOfRatio( + input_interval.inclusive_max() - cell.inclusive_min() + 1, cell.size()); + if (max_cell - min_cell > split_dim_max_cell - split_dim_min_cell) { + split_dim = dim; + split_dim_max_cell = max_cell; + split_dim_min_cell = min_cell; + } + } + if (split_dim == -1) return false; + const Index split_cell = (split_dim_min_cell + split_dim_max_cell) / 2; + const Index split_index = grid_cell_template[split_dim].inclusive_min() + + split_cell * grid_cell_template[split_dim].size(); + split_output[0].DeepAssign(input); + split_output[1].DeepAssign(input); + split_output[0][split_dim] = IndexInterval::UncheckedHalfOpen( + input[split_dim].inclusive_min(), split_index); + split_output[1][split_dim] = IndexInterval::UncheckedHalfOpen( + split_index, input[split_dim].exclusive_max()); + return true; +} + +} // namespace tensorstore diff --git a/tensorstore/util/split_box.h b/tensorstore/util/split_box.h new file mode 100644 index 000000000..48470dbed --- /dev/null +++ b/tensorstore/util/split_box.h @@ -0,0 +1,50 @@ +// Copyright 2020 The TensorStore Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORSTORE_UTIL_SPLIT_BOX_H_ +#define TENSORSTORE_UTIL_SPLIT_BOX_H_ + +#include + +#include "tensorstore/box.h" + +namespace tensorstore { + +/// Splits an input box into two approximately equal-sized halves aligned to a +/// grid cell boundary. +/// +/// If multiple dimensions can be split, splits along the dimension `i` for +/// which the largest number of grid cells intersects `input[i]`. +/// +/// \param input The input box to split. +/// \param grid_cell_template Specifies the grid to which the split must be +/// aligned. Each grid cell has a shape of `grid_cell_template.shape()` and +/// extends infinitely in all directions from an origin of +/// `grid_cell_template.origin()`. +/// \param split_output[out] Location where split result is stored. +/// \dchecks `IsFinite(input[i]) || Contains(grid_cell_template[i], input[i])` +/// for all dimensions `i`. +/// \dchecks `input.rank() == grid_cell_template.rank()` +/// \dchecks `split_output[0].rank() == input.rank()` +/// \dchecks `split_output[1].rank() == input.rank()` +/// \returns `true` if `input` could be split, i.e. it intersects more than one +/// grid cell. In this case, `split_output` is set to the split result. +/// Returns `false` if `input` intersects only a single grid cell. In this +/// case `split_output` is unchanged. +bool SplitBoxByGrid(BoxView<> input, BoxView<> grid_cell_template, + std::array, 2> split_output); + +} // namespace tensorstore + +#endif // TENSORSTORE_UTIL_SPLIT_BOX_H_ diff --git a/tensorstore/util/split_box_test.cc b/tensorstore/util/split_box_test.cc new file mode 100644 index 000000000..a4c5db082 --- /dev/null +++ b/tensorstore/util/split_box_test.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The TensorStore Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorstore/util/split_box.h" + +#include +#include + +namespace { + +using ::tensorstore::Box; +using ::tensorstore::SplitBoxByGrid; + +TEST(SplitBoxByGridTest, NoSplitRank1) { + Box<> a(1), b(1); + EXPECT_FALSE(SplitBoxByGrid(/*input=*/Box<1>({3}), + /*grid_cell_template=*/Box<1>({3}), {{a, b}})); + EXPECT_FALSE(SplitBoxByGrid(/*input=*/Box<1>({3}), + /*grid_cell_template=*/Box<1>({4}), {{a, b}})); + EXPECT_FALSE(SplitBoxByGrid(/*input=*/Box<1>({2}, {3}), + /*grid_cell_template=*/Box<1>({1}, {4}), + {{a, b}})); + EXPECT_FALSE(SplitBoxByGrid(/*input=*/Box<1>({-2}, {3}), + /*grid_cell_template=*/Box<1>({-3}, {4}), + {{a, b}})); + EXPECT_FALSE(SplitBoxByGrid(/*input=*/Box<1>({-2}, {3}), + /*grid_cell_template=*/Box<1>({1}, {4}), + {{a, b}})); + EXPECT_FALSE(SplitBoxByGrid(/*input=*/Box<1>({2}, {3}), + /*grid_cell_template=*/Box<1>({0}, {5}), + {{a, b}})); +} + +TEST(SplitBoxByGridTest, SplitRank1) { + Box<> a(1), b(1); + EXPECT_TRUE(SplitBoxByGrid(/*input=*/Box<1>({3}), + /*grid_cell_template=*/Box<1>({2}), {{a, b}})); + EXPECT_EQ(Box<1>({0}, {2}), a); + EXPECT_EQ(Box<1>({2}, {1}), b); +} + +TEST(SplitBoxByGridTest, SplitRank2) { + Box<> a(2), b(2); + EXPECT_TRUE(SplitBoxByGrid(/*input=*/Box<2>({3, 10}), + /*grid_cell_template=*/Box<2>({2, 3}), {{a, b}})); + EXPECT_EQ(Box<2>({0, 0}, {3, 6}), a); + EXPECT_EQ(Box<2>({0, 6}, {3, 4}), b); +} + +} // namespace diff --git a/third_party/python/python_configure.bzl b/third_party/python/python_configure.bzl index 4d6bca862..74bf4e708 100644 --- a/third_party/python/python_configure.bzl +++ b/third_party/python/python_configure.bzl @@ -205,7 +205,8 @@ def _get_python_include(repository_ctx, python_bin): "import importlib; " + "import importlib.util; " + "print(importlib.import_module('distutils.sysconfig').get_python_inc() " + - "if importlib.util.find_spec('distutils.sysconfig') " + + "if (importlib.util.find_spec('distutils') and " + + " importlib.util.find_spec('distutils.sysconfig')) " + "else importlib.import_module('sysconfig').get_path('include'))", ], error_msg = "Problem getting python include path.", diff --git a/tools/ci/cibuildwheel.py b/tools/ci/cibuildwheel.py index 855e83659..6e4562946 100755 --- a/tools/ci/cibuildwheel.py +++ b/tools/ci/cibuildwheel.py @@ -96,7 +96,8 @@ def run(args, extra_args): env["CIBW_ARCHS_MACOS"] = "x86_64 arm64" env["CIBW_SKIP"] = ( - "cp27-* cp35-* cp36-* cp37-* cp38-* pp* *_i686 *-win32 *-musllinux*" + "cp27-* cp35-* cp36-* cp37-* cp38-* cp313-* pp* *_i686 *-win32" + " *-musllinux*" ) env["CIBW_TEST_COMMAND"] = ( "python -m pytest {project}/python/tensorstore/tests -vv -s"