From e1476572e6d313a2f2ba0a2cd6d2892fed6e0fb2 Mon Sep 17 00:00:00 2001 From: pdmurray Date: Sat, 2 Dec 2023 13:02:38 -0800 Subject: [PATCH] In the midst of fixing tests broken by this check... --- versioned_hdf5/backend.py | 142 ++++++++++++++++++++++++++++---------- 1 file changed, 104 insertions(+), 38 deletions(-) diff --git a/versioned_hdf5/backend.py b/versioned_hdf5/backend.py index 72899674..8a7fd1bb 100644 --- a/versioned_hdf5/backend.py +++ b/versioned_hdf5/backend.py @@ -1,3 +1,5 @@ +from typing import Dict + import numpy as np from numpy.testing import assert_array_equal from h5py._hl.filters import guess_chunk @@ -132,33 +134,23 @@ def write_dataset(f, name, data, chunks=None, dtype=None, compression=None, data_hash = hashtable.hash(data_s) if data_hash in hashtable: - slices[s] = hashtable[data_hash] + hashed_slice = hashtable[data_hash] + slices[s] = hashed_slice + + _verify_new_chunk_reuse( + raw_data=ds, + new_data=data, + data_hash=data_hash, + hashed_slice=hashed_slice, + chunk_being_written=data_s, + slices_to_write=slices_to_write, + ) + else: slices[s] = raw_slice hashtable[data_hash] = raw_slice slices_to_write[raw_slice] = s - # Check that the data from the slice in the hashtable matches the - # data we are attempting to write - if raw_slice in slices_to_write: - reused_s = data[slices_to_write[raw_slice].raw] - else: - reused_slice = Tuple( - raw_slice, - *[slice(None, None) for _ in data.shape[1:]] - ) - reused_s = ds[reused_slice.raw] - - assert_array_equal( - reused_s, - data_s, - err_msg=( - f"Hash {data_hash} of existing data chunk {reused_s} " - f"matches the hash of new data chunk {data_s}, but data " - "does not." - ) - ) - ds.resize((old_shape[0] + len(slices_to_write)*chunk_size,) + chunks[1:]) for raw_slice, s in slices_to_write.items(): data_s = data[s.raw] @@ -166,6 +158,60 @@ def write_dataset(f, name, data, chunks=None, dtype=None, compression=None, ds[idx.raw] = data[s.raw] return slices +def _verify_new_chunk_reuse( + raw_data: np.ndarray, + new_data: np.ndarray, + data_hash: bytes, + hashed_slice: Tuple, + chunk_being_written: Tuple, + slices_to_write: Dict[Tuple, np.ndarray], +): + """Check that the data corresponding to the slice in the hashtable matches the data + that is going to be written. + + Raises a ValueError if the data reference by the hashed slice doesn't match the + underlying raw data. + + Parameters + ---------- + raw_data : np.ndarray + Raw data that already exists in the file + new_data : np.ndarray + New data that we are writing + data_hash : bytes + Hash of the new data chunk + hashed_slice : Tuple + Slice that is stored in the hash table for the given data_hash + chunk_being_written : np.ndarray + New data chunk to be written + slices_to_write : Tuple + Dict of slices which will be written + """ + if hashed_slice in slices_to_write: + # The hash table contains a slice we will write but haven't yet; grab the + # chunk from the new data being written + reused_chunk = new_data[slices_to_write[hashed_slice].raw] + else: + # The hash table contains a slice that was written in a previous + # write operation; grab that chunk from the existing raw data + reused_slice = Tuple( + hashed_slice, + *[slice(0, size) for size in new_data.shape[1:]] + ) + reused_chunk = raw_data[reused_slice.raw] + + assert_array_equal( + reused_chunk, + chunk_being_written, + err_msg=( + f"Hash {data_hash} of existing data chunk {reused_chunk} " + f"matches the hash of new data chunk {chunk_being_written}, " + "but data does not." + ) + ) + + + def write_dataset_chunks(f, name, data_dict, shape=None): """ data_dict should be a dictionary mapping chunk_size index to either an @@ -175,42 +221,62 @@ def write_dataset_chunks(f, name, data_dict, shape=None): if name not in f['_version_data']: raise NotImplementedError("Use write_dataset() if the dataset does not yet exist") - ds = f['_version_data'][name]['raw_data'] - chunks = tuple(ds.attrs['chunks']) + raw_data = f['_version_data'][name]['raw_data'] + chunks = tuple(raw_data.attrs['chunks']) chunk_size = chunks[0] if shape is None: shape = tuple(max(c.args[i].stop for c in data_dict) for i in range(len(chunks))) - # all_chunks = list(ChunkSize(chunks).indices(shape)) - # for c in data_dict: - # if c not in all_chunks: - # raise ValueError(f"data_dict contains extra chunks ({c})") with Hashtable(f, name) as hashtable: slices = {i: None for i in data_dict} + + # Mapping from slices in the dataset after this write is complete to chunks of + # the new data which will be written data_to_write = {} - for chunk, data_s in data_dict.items(): - if not isinstance(data_s, (slice, tuple, Tuple, Slice)) and data_s.dtype != ds.dtype: - raise ValueError(f"dtypes do not match ({data_s.dtype} != {ds.dtype})") - idx = hashtable.largest_index + # Mapping from slices in the dataset after this write is complete to ndarray + # chunks of the new data which will be written + slices_to_write = {} + for chunk, data_s in data_dict.items(): if isinstance(data_s, (slice, tuple, Tuple, Slice)): slices[chunk] = ndindex(data_s) else: + if data_s.dtype != raw_data.dtype: + raise ValueError( + f"dtypes do not match ({data_s.dtype} != {raw_data.dtype})" + ) + + idx = hashtable.largest_index raw_slice = Slice(idx*chunk_size, idx*chunk_size + data_s.shape[0]) data_hash = hashtable.hash(data_s) - raw_slice2 = hashtable.setdefault(data_hash, raw_slice) - if raw_slice2 == raw_slice: + + if data_hash in hashtable: + hashed_slice = hashtable[data_hash] + slices[chunk] = hashed_slice + + _verify_new_chunk_reuse( + raw_data=raw_data, + new_data=data_s, + data_hash=data_hash, + hashed_slice=hashed_slice, + chunk_being_written=data_s, + slices_to_write=slices_to_write, + ) + + else: + slices[chunk] = raw_slice + hashtable[data_hash] = raw_slice data_to_write[raw_slice] = data_s - slices[chunk] = raw_slice2 + slices_to_write[raw_slice] = chunk assert None not in slices.values() - old_shape = ds.shape - ds.resize((old_shape[0] + len(data_to_write)*chunk_size,) + chunks[1:]) + old_shape = raw_data.shape + raw_data.resize((old_shape[0] + len(data_to_write)*chunk_size,) + chunks[1:]) for raw_slice, data_s in data_to_write.items(): c = (raw_slice.raw,) + tuple(slice(0, i) for i in data_s.shape[1:]) - ds[c] = data_s + raw_data[c] = data_s return slices def create_virtual_dataset(f, version_name, name, shape, slices, attrs=None, fillvalue=None):