Skip to content

Commit

Permalink
Change JsonSpecifiedCompressor parameter types to match Zarr v3 codecs
Browse files Browse the repository at this point in the history
This simplifies reuse of riegeli Reader/Writer implementations without
needing templates, and is the first step in unifying
JsonSpecifiedCompressor with Zarr v3 bytes-to-bytes codecs.

PiperOrigin-RevId: 691924016
Change-Id: I0bd636a581a4147f7885bc2e25ad0de03d1b33ab
  • Loading branch information
jbms authored and copybara-github committed Oct 31, 2024
1 parent 00319ae commit 948d5f3
Show file tree
Hide file tree
Showing 15 changed files with 92 additions and 86 deletions.
30 changes: 19 additions & 11 deletions tensorstore/driver/n5/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ Result<SharedArray<const void>> DecodeChunk(const N5Metadata& metadata,
tensorstore::StrCat("Expected header of length ", header_size,
", but chunk has size ", buffer.size()));
}
std::unique_ptr<riegeli::Reader> reader =
std::make_unique<riegeli::CordReader<>>(&buffer);
riegeli::CordReader<> base_reader(&buffer);
riegeli::Reader* reader = &base_reader;
uint16_t mode;
uint16_t num_dims;
if (!riegeli::ReadBigEndian16(*reader, mode) ||
Expand Down Expand Up @@ -271,9 +271,11 @@ Result<SharedArray<const void>> DecodeChunk(const N5Metadata& metadata,
tensorstore::span(metadata.chunk_shape)));
}
}
std::unique_ptr<riegeli::Reader> compressed_reader;
if (metadata.compressor) {
reader = metadata.compressor->GetReader(std::move(reader),
metadata.dtype.size());
compressed_reader =
metadata.compressor->GetReader(base_reader, metadata.dtype.size());
reader = compressed_reader.get();
}
SharedArray<const void> decoded_array;
if (absl::c_equal(encoded_shape, metadata.chunk_shape)) {
Expand All @@ -292,18 +294,19 @@ Result<SharedArray<const void>> DecodeChunk(const N5Metadata& metadata,
*reader, endian::big, fortran_order, partial_decoded_array));
decoded_array = std::move(array);
}
if (!reader->VerifyEndAndClose()) {
return reader->status();
if (compressed_reader && !compressed_reader->VerifyEndAndClose()) {
return compressed_reader->status();
}
if (!base_reader.VerifyEndAndClose()) return base_reader.status();
return decoded_array;
}

Result<absl::Cord> EncodeChunk(const N5Metadata& metadata,
SharedArrayView<const void> array) {
assert(absl::c_equal(metadata.chunk_shape, array.shape()));
absl::Cord encoded;
std::unique_ptr<riegeli::Writer> writer =
std::make_unique<riegeli::CordWriter<>>(&encoded);
riegeli::CordWriter<> base_writer(&encoded);
riegeli::Writer* writer = &base_writer;

// Write header
// mode: 0x0 = default
Expand All @@ -316,17 +319,22 @@ Result<absl::Cord> EncodeChunk(const N5Metadata& metadata,
return writer->status();
}
}
std::unique_ptr<riegeli::Writer> compressed_writer;
if (metadata.compressor) {
writer = metadata.compressor->GetWriter(std::move(writer),
metadata.dtype.size());
compressed_writer =
metadata.compressor->GetWriter(base_writer, metadata.dtype.size());
writer = compressed_writer.get();
}
// Always write chunks as full size, to avoid race conditions or data loss
// in the event of a concurrent resize.
if (!internal::EncodeArrayEndian(array, endian::big, fortran_order,
*writer)) {
return writer->status();
}
if (!writer->Close()) return writer->status();
if (compressed_writer && !compressed_writer->Close()) {
return compressed_writer->status();
}
if (!base_writer.Close()) return base_writer.status();
return encoded;
}

Expand Down
37 changes: 21 additions & 16 deletions tensorstore/driver/zarr/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,16 +496,20 @@ Result<absl::InlinedVector<SharedArray<const void>, 1>> DecodeChunk(
std::reverse(c_order_shape, c_order_shape + metadata.rank);
c_order_shape_span = span(&c_order_shape[0], full_chunk_shape.size());
}
std::unique_ptr<riegeli::Reader> reader =
std::make_unique<riegeli::CordReader<absl::Cord>>(std::move(buffer));
riegeli::CordReader<absl::Cord> base_reader(std::move(buffer));
std::unique_ptr<riegeli::Reader> compressed_reader;
if (metadata.compressor) {
reader = metadata.compressor->GetReader(
std::move(reader), metadata.dtype.bytes_per_outer_element);
compressed_reader = metadata.compressor->GetReader(
base_reader, metadata.dtype.bytes_per_outer_element);
}
TENSORSTORE_ASSIGN_OR_RETURN(
auto array, internal::DecodeArrayEndian(*reader, dtype_field.dtype,
c_order_shape_span,
dtype_field.endian, c_order));
auto array, internal::DecodeArrayEndian(
compressed_reader ? *compressed_reader : base_reader,
dtype_field.dtype, c_order_shape_span,
dtype_field.endian, c_order));
if (compressed_reader) {
if (!base_reader.VerifyEndAndClose()) return base_reader.status();
}
if (metadata.order == fortran_order) {
std::reverse(array.shape().begin(),
array.shape().begin() + metadata.rank);
Expand All @@ -516,11 +520,12 @@ Result<absl::InlinedVector<SharedArray<const void>, 1>> DecodeChunk(
return field_arrays;
}
if (metadata.compressor) {
std::unique_ptr<riegeli::Reader> reader =
std::make_unique<riegeli::CordReader<absl::Cord>>(std::move(buffer));
reader = metadata.compressor->GetReader(
std::move(reader), metadata.dtype.bytes_per_outer_element);
TENSORSTORE_RETURN_IF_ERROR(riegeli::ReadAll(std::move(reader), buffer));
riegeli::CordReader<absl::Cord> base_reader(std::move(buffer));
auto compressed_reader = metadata.compressor->GetReader(
base_reader, metadata.dtype.bytes_per_outer_element);
TENSORSTORE_RETURN_IF_ERROR(
riegeli::ReadAll(std::move(compressed_reader), buffer));
if (!base_reader.VerifyEndAndClose()) return base_reader.status();
}
if (static_cast<Index>(buffer.size()) !=
metadata.chunk_layout.bytes_per_chunk) {
Expand Down Expand Up @@ -605,12 +610,12 @@ Result<absl::Cord> EncodeChunk(
}
if (metadata.compressor) {
absl::Cord encoded;
std::unique_ptr<riegeli::Writer> writer =
std::make_unique<riegeli::CordWriter<absl::Cord*>>(&encoded);
writer = metadata.compressor->GetWriter(
std::move(writer), metadata.dtype.bytes_per_outer_element);
riegeli::CordWriter<absl::Cord*> base_writer(&encoded);
auto writer = metadata.compressor->GetWriter(
base_writer, metadata.dtype.bytes_per_outer_element);
TENSORSTORE_RETURN_IF_ERROR(
riegeli::Write(std::move(output), std::move(writer)));
if (!base_writer.Close()) return base_writer.status();
return encoded;
}
return output;
Expand Down
1 change: 1 addition & 0 deletions tensorstore/internal/compression/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ tensorstore_cc_library(
"//tensorstore/internal:intrusive_ptr",
"//tensorstore/internal:json_registry",
"//tensorstore/util:status",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:cord",
"@com_google_riegeli//riegeli/bytes:cord_reader",
Expand Down
17 changes: 8 additions & 9 deletions tensorstore/internal/compression/blosc_compressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ namespace {
class BloscDeferredWriter : public riegeli::CordWriter<absl::Cord> {
public:
explicit BloscDeferredWriter(blosc::Options options,
std::unique_ptr<riegeli::Writer> base_writer)
riegeli::Writer& base_writer)
: CordWriter(riegeli::CordWriterBase::Options().set_max_block_size(
std::numeric_limits<size_t>::max())),
options_(std::move(options)),
base_writer_(std::move(base_writer)) {}
base_writer_(base_writer) {}

void Done() override {
CordWriter::Done();
Expand All @@ -55,7 +55,7 @@ class BloscDeferredWriter : public riegeli::CordWriter<absl::Cord> {
Fail(std::move(output).status());
return;
}
auto status = riegeli::Write(*std::move(output), std::move(base_writer_));
auto status = riegeli::Write(*std::move(output), base_writer_);
if (!status.ok()) {
Fail(std::move(status));
return;
Expand All @@ -64,23 +64,22 @@ class BloscDeferredWriter : public riegeli::CordWriter<absl::Cord> {

private:
blosc::Options options_;
std::unique_ptr<riegeli::Writer> base_writer_;
riegeli::Writer& base_writer_;
};

} // namespace

std::unique_ptr<riegeli::Writer> BloscCompressor::GetWriter(
std::unique_ptr<riegeli::Writer> base_writer, size_t element_bytes) const {
riegeli::Writer& base_writer, size_t element_bytes) const {
return std::make_unique<BloscDeferredWriter>(
blosc::Options{codec.c_str(), level, shuffle, blocksize, element_bytes},
std::move(base_writer));
base_writer);
}

std::unique_ptr<riegeli::Reader> BloscCompressor::GetReader(
std::unique_ptr<riegeli::Reader> base_reader, size_t element_bytes) const {
riegeli::Reader& base_reader, size_t element_bytes) const {
auto output = riegeli::ReadAll(
std::move(base_reader),
[](absl::string_view input) -> absl::StatusOr<std::string> {
base_reader, [](absl::string_view input) -> absl::StatusOr<std::string> {
auto output = blosc::Decode(input);
if (!output.ok()) return std::move(output).status();
return *std::move(output);
Expand Down
6 changes: 2 additions & 4 deletions tensorstore/internal/compression/blosc_compressor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@ namespace internal {
class BloscCompressor : public JsonSpecifiedCompressor {
public:
std::unique_ptr<riegeli::Writer> GetWriter(
std::unique_ptr<riegeli::Writer> base_writer,
size_t element_bytes) const override;
riegeli::Writer& base_writer, size_t element_bytes) const override;

std::unique_ptr<riegeli::Reader> GetReader(
std::unique_ptr<riegeli::Reader> base_reader,
size_t element_bytes) const override;
riegeli::Reader& base_reader, size_t element_bytes) const override;

static constexpr auto CodecBinder() {
namespace jb = tensorstore::internal_json_binding;
Expand Down
12 changes: 6 additions & 6 deletions tensorstore/internal/compression/bzip2_compressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@ namespace tensorstore {
namespace internal {

std::unique_ptr<riegeli::Writer> Bzip2Compressor::GetWriter(
std::unique_ptr<riegeli::Writer> base_writer, size_t element_bytes) const {
using Writer = riegeli::Bzip2Writer<std::unique_ptr<riegeli::Writer>>;
riegeli::Writer& base_writer, size_t element_bytes) const {
using Writer = riegeli::Bzip2Writer<riegeli::Writer*>;
Writer::Options options;
options.set_compression_level(level);
return std::make_unique<Writer>(std::move(base_writer), options);
return std::make_unique<Writer>(&base_writer, options);
}

std::unique_ptr<riegeli::Reader> Bzip2Compressor::GetReader(
std::unique_ptr<riegeli::Reader> base_reader, size_t element_bytes) const {
using Reader = riegeli::Bzip2Reader<std::unique_ptr<riegeli::Reader>>;
return std::make_unique<Reader>(std::move(base_reader));
riegeli::Reader& base_reader, size_t element_bytes) const {
using Reader = riegeli::Bzip2Reader<riegeli::Reader*>;
return std::make_unique<Reader>(&base_reader);
}

} // namespace internal
Expand Down
6 changes: 2 additions & 4 deletions tensorstore/internal/compression/bzip2_compressor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,10 @@ class Bzip2Compressor : public internal::JsonSpecifiedCompressor,
public Bzip2Options {
public:
std::unique_ptr<riegeli::Writer> GetWriter(
std::unique_ptr<riegeli::Writer> base_writer,
size_t element_bytes) const override;
riegeli::Writer& base_writer, size_t element_bytes) const override;

virtual std::unique_ptr<riegeli::Reader> GetReader(
std::unique_ptr<riegeli::Reader> base_reader,
size_t element_bytes) const override;
riegeli::Reader& base_reader, size_t element_bytes) const override;
};

} // namespace internal
Expand Down
10 changes: 6 additions & 4 deletions tensorstore/internal/compression/json_specified_compressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,27 @@ JsonSpecifiedCompressor::~JsonSpecifiedCompressor() = default;
absl::Status JsonSpecifiedCompressor::Encode(const absl::Cord& input,
absl::Cord* output,
size_t element_bytes) const {
auto base_writer = std::make_unique<riegeli::CordWriter<>>(
riegeli::CordWriter<> base_writer(
output, riegeli::CordWriterBase::Options().set_append(true));
auto writer = GetWriter(std::move(base_writer), element_bytes);
auto writer = GetWriter(base_writer, element_bytes);

TENSORSTORE_RETURN_IF_ERROR(
riegeli::Write(input, std::move(writer)),
MaybeConvertStatusTo(_, absl::StatusCode::kInvalidArgument));
if (!base_writer.Close()) return base_writer.status();
return absl::OkStatus();
}

absl::Status JsonSpecifiedCompressor::Decode(const absl::Cord& input,
absl::Cord* output,
size_t element_bytes) const {
auto base_reader = std::make_unique<riegeli::CordReader<>>(&input);
auto reader = GetReader(std::move(base_reader), element_bytes);
riegeli::CordReader<> base_reader(&input);
auto reader = GetReader(base_reader, element_bytes);

TENSORSTORE_RETURN_IF_ERROR(
riegeli::ReadAndAppendAll(std::move(reader), *output),
MaybeConvertStatusTo(_, absl::StatusCode::kInvalidArgument));
if (!base_reader.VerifyEndAndClose()) return base_reader.status();
return absl::OkStatus();
}

Expand Down
5 changes: 3 additions & 2 deletions tensorstore/internal/compression/json_specified_compressor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cstddef>
#include <memory>

#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "riegeli/bytes/reader.h"
Expand All @@ -40,7 +41,7 @@ class JsonSpecifiedCompressor

/// Returns a writer that encodes the compression format.
virtual std::unique_ptr<riegeli::Writer> GetWriter(
std::unique_ptr<riegeli::Writer> base_writer,
riegeli::Writer& base_writer ABSL_ATTRIBUTE_LIFETIME_BOUND,
size_t element_bytes) const = 0;

/// Returns a reader that decodes the compression format.
Expand All @@ -50,7 +51,7 @@ class JsonSpecifiedCompressor
/// compressor, e.g. `4` if `input` is actually a sequence of `int32_t`
/// values. Must be `> 0`.
virtual std::unique_ptr<riegeli::Reader> GetReader(
std::unique_ptr<riegeli::Reader> base_reader,
riegeli::Reader& base_reader ABSL_ATTRIBUTE_LIFETIME_BOUND,
size_t element_bytes) const = 0;

/// Encodes `input`.
Expand Down
12 changes: 6 additions & 6 deletions tensorstore/internal/compression/xz_compressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@ namespace tensorstore {
namespace internal {

std::unique_ptr<riegeli::Writer> XzCompressor::GetWriter(
std::unique_ptr<riegeli::Writer> base_writer, size_t element_bytes) const {
using Writer = riegeli::XzWriter<std::unique_ptr<riegeli::Writer>>;
riegeli::Writer& base_writer, size_t element_bytes) const {
using Writer = riegeli::XzWriter<riegeli::Writer*>;
Writer::Options options;
options.set_container(Writer::Container::kXz);
options.set_check(static_cast<Writer::Check>(check));
options.set_compression_level(level);
options.set_extreme(extreme);
return std::make_unique<Writer>(std::move(base_writer), options);
return std::make_unique<Writer>(&base_writer, options);
}

std::unique_ptr<riegeli::Reader> XzCompressor::GetReader(
std::unique_ptr<riegeli::Reader> base_reader, size_t element_bytes) const {
using Reader = riegeli::XzReader<std::unique_ptr<riegeli::Reader>>;
riegeli::Reader& base_reader, size_t element_bytes) const {
using Reader = riegeli::XzReader<riegeli::Reader*>;
Reader::Options options;
options.set_container(Reader::Container::kXzOrLzma);
options.set_concatenate(true);
return std::make_unique<Reader>(std::move(base_reader), options);
return std::make_unique<Reader>(&base_reader, options);
}

} // namespace internal
Expand Down
6 changes: 2 additions & 4 deletions tensorstore/internal/compression/xz_compressor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,10 @@ struct XzOptions {
class XzCompressor : public JsonSpecifiedCompressor, public XzOptions {
public:
std::unique_ptr<riegeli::Writer> GetWriter(
std::unique_ptr<riegeli::Writer> base_writer,
size_t element_bytes) const override;
riegeli::Writer& base_writer, size_t element_bytes) const override;

std::unique_ptr<riegeli::Reader> GetReader(
std::unique_ptr<riegeli::Reader> base_reader,
size_t element_bytes) const override;
riegeli::Reader& base_reader, size_t element_bytes) const override;
};

} // namespace internal
Expand Down
12 changes: 6 additions & 6 deletions tensorstore/internal/compression/zlib_compressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ namespace tensorstore {
namespace internal {

std::unique_ptr<riegeli::Writer> ZlibCompressor::GetWriter(
std::unique_ptr<riegeli::Writer> base_writer, size_t element_bytes) const {
using Writer = riegeli::ZlibWriter<std::unique_ptr<riegeli::Writer>>;
riegeli::Writer& base_writer, size_t element_bytes) const {
using Writer = riegeli::ZlibWriter<riegeli::Writer*>;
Writer::Options options;
if (level != -1) options.set_compression_level(level);
options.set_header(use_gzip_header ? Writer::Header::kGzip
: Writer::Header::kZlib);
return std::make_unique<Writer>(std::move(base_writer), options);
return std::make_unique<Writer>(&base_writer, options);
}

std::unique_ptr<riegeli::Reader> ZlibCompressor::GetReader(
std::unique_ptr<riegeli::Reader> base_reader, size_t element_bytes) const {
using Reader = riegeli::ZlibReader<std::unique_ptr<riegeli::Reader>>;
riegeli::Reader& base_reader, size_t element_bytes) const {
using Reader = riegeli::ZlibReader<riegeli::Reader*>;
Reader::Options options;
options.set_header(use_gzip_header ? Reader::Header::kGzip
: Reader::Header::kZlib);
return std::make_unique<Reader>(std::move(base_reader), options);
return std::make_unique<Reader>(&base_reader, options);
}

} // namespace internal
Expand Down
6 changes: 2 additions & 4 deletions tensorstore/internal/compression/zlib_compressor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,10 @@ namespace internal {
class ZlibCompressor : public JsonSpecifiedCompressor, public zlib::Options {
public:
std::unique_ptr<riegeli::Writer> GetWriter(
std::unique_ptr<riegeli::Writer> base_writer,
size_t element_bytes) const override;
riegeli::Writer& base_writer, size_t element_bytes) const override;

virtual std::unique_ptr<riegeli::Reader> GetReader(
std::unique_ptr<riegeli::Reader> base_reader,
size_t element_bytes) const override;
riegeli::Reader& base_reader, size_t element_bytes) const override;
};

} // namespace internal
Expand Down
Loading

0 comments on commit 948d5f3

Please sign in to comment.