partition() no longer needs an array shape to chunk indices
Added two tests for appending to a dataset
peytondmurray committed Mar 15, 2024
1 parent 89f4577 commit 48ea00f
Showing 2 changed files with 61 additions and 19 deletions.
29 changes: 10 additions & 19 deletions versioned_hdf5/backend.py
@@ -728,7 +728,7 @@ def apply(self, f: File, name: str, version: str) -> Dict[Tuple, Tuple]:

         for arr_chunk, virtual_chunk in zip(
             partition(arr, chunk_size),
-            partition(index, chunk_size, arr.shape),
+            partition(index, chunk_size),
             strict=True,
         ):
             data_dict[virtual_chunk] = arr[arr_chunk.raw]
@@ -1022,7 +1022,7 @@ def write_to_dataset(
     with Hashtable(f, name) as hashtable:
         for data_slice, vchunk in zip(
             partition(data, chunk_size),
-            partition(virtual_slice, chunk_size, shape=data.shape),
+            partition(virtual_slice, chunk_size),
             strict=True,
         ):
             arr = data[data_slice.raw]
@@ -1031,16 +1031,20 @@
             if data_hash in hashtable:
                 slices[vchunk] = hashtable[data_hash]
             else:
-                # TODO do I need to resize to get another chunk? Probably
+                new_chunk_axis_size = raw_data.shape[0] + len(data_slice.args[0])
+
+                # There's new data; compute the bounds of a new raw chunk to hold it
                 rchunk = Tuple(
                     Slice(
                         raw_data.shape[0],
-                        raw_data.shape[0] + data_slice.shape[0],
+                        new_chunk_axis_size,
                     ),
                     *[Slice(None, None) for _ in raw_data.shape[1:]],
                 )

+                # Resize the dataset to include a new chunk
                 raw_data.resize(raw_data.shape[0] + chunk_size, axis=0)

+                # Map the virtual chunk to the raw data chunk
                 slices[vchunk] = rchunk

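To make the chunk allocation above concrete, here is a hedged worked example; the shapes, chunk size, and row count below are illustrative assumptions, not values taken from the codebase.

```python
from ndindex import Slice, Tuple

# Illustrative values: suppose raw_data currently holds 20 rows of width 5,
# chunks are 10 rows long, and the incoming data_slice covers 4 rows.
raw_shape = (20, 5)
chunk_size = 10
incoming_rows = 4

# Mirrors the diff: the new raw chunk spans rows 20..24, all columns.
new_chunk_axis_size = raw_shape[0] + incoming_rows
rchunk = Tuple(
    Slice(raw_shape[0], new_chunk_axis_size),
    *[Slice(None, None) for _ in raw_shape[1:]],
)

# raw_data.resize(raw_shape[0] + chunk_size, axis=0) would then grow the
# dataset to 30 rows, one whole extra chunk, even though only 4 of those
# rows are filled by this write.
print(rchunk)  # a Tuple of Slices covering rows 20..24 and every column
```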
@@ -1170,7 +1174,6 @@ def split_across_unused(
 def partition(
     obj: Union[np.ndarray, Tuple],
     chunk_size: int,
-    shape: Optional[tuple[int, ...]] = None,
 ) -> Iterator[Tuple]:
     """Break an array or a Tuple of slices into chunks of the given chunk size.
@@ -1180,14 +1183,6 @@
         Array or Tuple index to partition
     chunk_size : int
         The size of each partitioned chunk
-    shape: Optional[tuple[int]]
-        Shape that the index should be partitioned onto. To partition an array,
-        this should be the array shape; if None, the shape of the array is used.
-        To partition an index, this must be the shape of the array the index will
-        be indexing into.
-
-        This is needed because the index is chunked along the first axis, but the
-        shape of the other axes is needed to produce the partitioned slices.

     Returns
     -------
@@ -1196,13 +1191,9 @@
     """
     if isinstance(obj, np.ndarray):
         index = Tuple(*[Slice(0, dim) for dim in obj.shape])
-        if shape is None:
-            shape = obj.shape
+        shape = obj.shape
     else:
-        if shape is None:
-            raise ValueError(
-                "A shape must be specified to partition the index {obj} onto."
-            )
         index = obj
+        shape = tuple(dim.stop for dim in index.args)

     yield from ChunkSize((chunk_size,)).as_subchunks(index, shape)
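For readers who want to try the new signature, the following self-contained sketch reproduces `partition` as it stands after this commit; the function body is copied from the diff above, and only the demo index at the bottom is invented.

```python
from typing import Iterator, Union

import numpy as np
from ndindex import ChunkSize, Slice, Tuple


def partition(obj: Union[np.ndarray, Tuple], chunk_size: int) -> Iterator[Tuple]:
    """Break an array or a Tuple of slices into chunks of the given chunk size."""
    if isinstance(obj, np.ndarray):
        index = Tuple(*[Slice(0, dim) for dim in obj.shape])
        shape = obj.shape
    else:
        # The index supplies its own shape via the stops of its slices, which
        # is why callers no longer need to pass a separate `shape` argument.
        index = obj
        shape = tuple(dim.stop for dim in index.args)

    yield from ChunkSize((chunk_size,)).as_subchunks(index, shape)


# Invented demo: a 1-D index covering elements 0..24, chunked by 10, yields
# three chunks spanning [0, 10), [10, 20), and [20, 25).
for chunk in partition(Tuple(Slice(0, 25)), 10):
    print(chunk)
```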
51 changes: 51 additions & 0 deletions versioned_hdf5/tests/test_api.py
@@ -8,6 +8,7 @@
 import h5py
 import numpy as np
 from h5py._hl.filters import guess_chunk
+from ndindex import ndindex
 from numpy.testing import assert_equal
 from pytest import mark, raises

@@ -2667,6 +2668,29 @@ def test_append_small_dataset(tmp_path):
     with vf.stage_version("r1") as sv:
         sv["values"].append(np.array([1, 2, 3]))

+    raw_data = f["_version_data"]["values"]["raw_data"]
+    chunks = list(raw_data.iter_chunks())
+
+    # Raw data should have a single chunk of length 10
+    assert len(chunks) == 1
+    assert_equal(
+        raw_data[ndindex(chunks[0]).raw],
+        np.array([0, 1, 2, 3, 0, 0, 0, 0, 0, 0]),
+    )
+
+    # Virtual datasets should only have the numbers 0 -> 3
+    assert_equal(
+        f["_version_data"]["versions"]["r0"]["values"],
+        np.array([0]),
+    )
+    assert_equal(
+        f["_version_data"]["versions"]["r1"]["values"],
+        np.array([0, 1, 2, 3]),
+    )
+
+    # 4 elements were written
+    assert raw_data.attrs["last_element"] == 4
+

 @mark.append
 def test_append_big_dataset(tmp_path):
@@ -2686,3 +2710,30 @@

     with vf.stage_version("r1") as sv:
         sv["values"].append(np.arange(1, 12))
+
+    raw_data = f["_version_data"]["values"]["raw_data"]
+    chunks = list(raw_data.iter_chunks())
+
+    # Raw data should have two chunks of length 10
+    assert len(chunks) == 2
+    assert_equal(
+        raw_data[ndindex(chunks[0]).raw],
+        np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+    )
+    assert_equal(
+        raw_data[ndindex(chunks[1]).raw],
+        np.array([10, 11, 0, 0, 0, 0, 0, 0, 0, 0]),
+    )
+
+    # Virtual datasets should only have the numbers 0 -> 11
+    assert_equal(
+        f["_version_data"]["versions"]["r0"]["values"],
+        np.array([0]),
+    )
+    assert_equal(
+        f["_version_data"]["versions"]["r1"]["values"],
+        np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
+    )
+
+    # 12 elements were written
+    assert raw_data.attrs["last_element"] == 12
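For orientation, the end-to-end append workflow these tests exercise might look like the following reconstruction; the dataset name, chunk size, and version names come from the tests above, while the setup hidden in the collapsed portion of the diff is assumed rather than copied.

```python
import h5py
import numpy as np
from versioned_hdf5 import VersionedHDF5File

with h5py.File("data.h5", "w") as f:
    vf = VersionedHDF5File(f)

    # Assumed setup: an initial version with a single element, chunked by 10.
    with vf.stage_version("r0") as sv:
        sv.create_dataset("values", data=np.array([0]), chunks=(10,))

    # The append under test in this commit: r1 extends the dataset in place.
    with vf.stage_version("r1") as sv:
        sv["values"].append(np.array([1, 2, 3]))

    # Each committed version keeps its own view of the data.
    print(vf["r0"]["values"][:])  # [0]
    print(vf["r1"]["values"][:])  # [0 1 2 3]
```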
