From 1c7c7772d5c6b2aaa4e8a76254d0dc002a5ee9cd Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 28 Oct 2024 12:53:11 +0000 Subject: [PATCH] [InMemoryDataset redesign] EntireChunksMapper --- versioned_hdf5/subchunk_map.py | 46 +++++++++++++++++++++ versioned_hdf5/tests/test_subchunk_map.py | 50 +++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/versioned_hdf5/subchunk_map.py b/versioned_hdf5/subchunk_map.py index 7838ea37..24e8c79b 100755 --- a/versioned_hdf5/subchunk_map.py +++ b/versioned_hdf5/subchunk_map.py @@ -514,6 +514,52 @@ def value_view_idx(self) -> slice | None: return None +@cython.cclass +class EntireChunksMapper(BasicChunkMapper): + """Special mapper that selects all points on the chunks selected by another mapper. + + This is used to load the entire chunk for the purpose of caching when the + actual selection may only target part of it. + """ + + _chunks_indexer: slice | NDArray[np.intp] + + def __init__(self, other: IndexChunkMapper): + self._chunks_indexer = other.chunks_indexer() + super().__init__(other.chunk_indices, other.dset_size, other.chunk_size) + + @cython.ccall + def chunk_submap( + self, chunk_idx: hsize_t + ) -> tuple[Slice, AnySlicer | DropAxis, AnySlicer]: + raise NotImplementedError( # pragma: nocover + "not used in legacy as_subchunk_map" + ) + + @cython.cfunc + @cython.nogil + @cython.exceptval(check=False) + def _read_many_slices_param( + self, chunk_idx: hsize_t + ) -> tuple[hsize_t, hsize_t, hsize_t, hsize_t]: + chunk_start, chunk_stop = self._chunk_start_stop(chunk_idx) + + return ( + 0, # chunk_sub_start + chunk_start, # value_sub_start + chunk_stop - chunk_start, # count + 1, # chunk_sub_stride + ) + + @cython.ccall + def chunks_indexer(self): + return self._chunks_indexer + + @cython.ccall + def whole_chunks_idxidx(self): + return slice(0, len(self.chunk_indices), 1) + + @cython.cclass class IntegerArrayMapper(IndexChunkMapper): """IndexChunkMapper for one-dimensional fancy integer array indices. diff --git a/versioned_hdf5/tests/test_subchunk_map.py b/versioned_hdf5/tests/test_subchunk_map.py index 37546ec3..feb5e4a1 100644 --- a/versioned_hdf5/tests/test_subchunk_map.py +++ b/versioned_hdf5/tests/test_subchunk_map.py @@ -14,6 +14,7 @@ from ..slicetools import read_many_slices from ..subchunk_map import ( DROP_AXIS, + EntireChunksMapper, SliceMapper, TransferType, as_subchunk_map, @@ -409,6 +410,55 @@ def test_read_many_slices_param_nd(args): assert_array_equal(getitem_dst3, expect) +@pytest.mark.slow +@given(idx_shape_chunks_st(max_ndim=1)) +@hypothesis.settings(max_examples=max_examples, deadline=None) +def test_entire_chunks_mapper(args): + idx, shape, chunks = args + _, mappers = index_chunk_mappers(idx, shape, chunks) + if not mappers: + return # Early exit for empty index + assert len(shape) == len(chunks) == len(mappers) == 1 + chunk_size = chunks[0] + orig_mapper = mappers[0] + entire_mapper = EntireChunksMapper(orig_mapper) + + # Test chunks_indexer() and whole_chunks_idxidx() + all_chunks = np.arange(orig_mapper.n_chunks) + sel_chunks = all_chunks[orig_mapper.chunks_indexer()] + np.testing.assert_array_equal( + all_chunks[entire_mapper.chunks_indexer()], sel_chunks + ) + np.testing.assert_array_equal( + sel_chunks[entire_mapper.whole_chunks_idxidx()], sel_chunks + ) + + # Test read_many_slices_params() + entire_mapper = EntireChunksMapper(orig_mapper) + entire_slices, entire_chunks_to_slices = entire_mapper.read_many_slices_params() + assert entire_chunks_to_slices is None + + n_sel_chunks = len(orig_mapper.chunk_indices) + expect_count = np.full(n_sel_chunks, chunk_size, dtype=np_hsize_t) + if orig_mapper.chunk_indices[-1] == orig_mapper.n_chunks - 1: + expect_count[-1] = orig_mapper.last_chunk_size + + assert_array_equal( + entire_slices, + np.stack( + [ + np.zeros(n_sel_chunks, dtype=np_hsize_t), # src_start, + np.asarray(orig_mapper.chunk_indices) * chunk_size, # dst_start, + expect_count, # count + np.ones(n_sel_chunks, dtype=np_hsize_t), # src_stride, + np.ones(n_sel_chunks, dtype=np_hsize_t), # dst_stride, + ], + axis=1, + ), + strict=True, + ) + + def test_simplify_indices(): """Test that a fancy index that can be redefined globally as a slice results in a SliceMapper