Finished WriteOperation
peytondmurray committed Mar 2, 2024
1 parent 7a9f33b commit 4771c04
Showing 1 changed file with 127 additions and 30 deletions.
157 changes: 127 additions & 30 deletions versioned_hdf5/backend.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import numpy as np
from h5py import Dataset, File, VirtualLayout, VirtualSource, h5s
@@ -682,11 +682,23 @@ class SetOperation(WriteOperation):
"""Operation which indexes the dataset to write data."""

def __init__(self, index: Tuple, arr: np.ndarray):
"""Initialize a SetOperation.
Parameters
----------
index : Tuple
Virtual dataset index where ``arr`` is to be written
arr : np.ndarray
Array to write to the dataset
"""
self.index = index
self.arr = arr

def __repr__(self):
return f"SetOperation:\n Index {self.index}: Data {self.arr}"

def apply(self, f: File, name: str, version: str) -> Dict[Tuple, Tuple]:
"""Write data the stored data to the dataset.
"""Write the stored data to the dataset in chunks.
Parameters
----------
@@ -703,15 +715,43 @@ def apply(self, f: File, name: str, version: str) -> Dict[Tuple, Tuple]:
Mapping between {slices in virtual dataset: slices in raw dataset}
which were written by this function.
"""
return write_to_dataset(f, version, name, self.key, self.value)
# If the shape of the array doesn't match the shape of the
# index to assign the array to, broadcast it first.
index = ndindex(self.index)
index_shape = tuple(len(dim) for dim in index.args)
if self.arr.shape != index_shape:
arr = np.broadcast_to(self.arr, index_shape)
else:
arr = self.arr

data_dict = {}
raw_data = f["_version_data"][name]["raw_data"]
chunk_size = tuple(raw_data.attrs["chunks"])[0]

for virtual_chunk, arr_chunk in zip(
    partition(index, chunk_size), partition(arr, chunk_size), strict=True
):
    data_dict[virtual_chunk] = arr[arr_chunk.raw]

return write_dataset_chunks(f, name, data_dict)
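
For illustration, a minimal sketch of how a SetOperation might be used; the file handle, dataset name, and version name below are hypothetical and not part of this commit:

    import numpy as np

    # Assumes `f` is an open h5py File with the versioned-hdf5
    # `_version_data` layout, containing a dataset named "values" and a
    # version "r1" (all hypothetical names).
    op = SetOperation(index=(slice(0, 10),), arr=np.arange(10))
    slices = op.apply(f, name="values", version="r1")
    # `slices` maps the chunks of the virtual dataset that were written
    # to the corresponding chunks of the raw dataset.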


class AppendOperation(WriteOperation):
"""Operation which appends data to a dataset."""

def __init__(self, value: np.ndarray):
"""Initialize a WriteOperation.
Parameters
----------
value : np.ndarray
Array to append to the dataset
"""
self.value = value

def __repr__(self):
return f"WriteOperation:\n {self.value}"

def apply(self, f: File, name: str, version: str) -> Dict[Tuple, Tuple]:
"""Append data the stored data to the dataset.
@@ -920,33 +960,47 @@ def append_to_dataset(
return slices


# def write_to_dataset(
# f: File, version_name: str, name: str, vslice: Tuple, data: np.ndarray | Tuple
# ) -> Dict[Tuple, Tuple]:
# """Write data to a dataset.
#
# Parameters
# ----------
# f : File
# File where data should be written
# version_name : str
# Version name for which data is to be written
# name : str
# Name of the dataset being modified
# vslice : Tuple
# Slice of the virtual dataset where data is to be written
# data : np.ndarray | Tuple
# Data to be written. If it is a Tuple, this is a slice of the raw dataset.
#
# Returns
# -------
# Dict[Tuple, Tuple]
# Mapping between {slices in virtual dataset: slices in raw dataset} which were
# written by this function.
# """
# slices = {}
#
# return slices
def write_to_dataset(
    f: File, version_name: str, name: str, vslice: Tuple, data: Union[np.ndarray, Tuple]
) -> Dict[Tuple, Tuple]:
"""Write data to a dataset.

Parameters
----------
f : File
    File where data should be written
version_name : str
    Version name for which data is to be written
name : str
    Name of the dataset being modified
vslice : Tuple
    Slice of the virtual dataset where data is to be written
data : Union[np.ndarray, Tuple]
    Data to be written. If it is a Tuple, this is a slice of the raw dataset.

Returns
-------
Dict[Tuple, Tuple]
    Mapping between {slices in virtual dataset: slices in raw dataset} which were
    written by this function.
"""
raw_data = f["_version_data"][name]["raw_data"]

if raw_data.dtype != data.dtype:
    raise ValueError(
        f"dtype of raw data ({raw_data.dtype}) does not match dtype of data "
        f"to be written ({data.dtype})"
    )

# Get the slices from the previous version; they are reused here
prev_version_name = f["_version_data"]["versions"][version_name].attrs[
"prev_version"
]
prev_version = f["_version_data"]["versions"][prev_version_name][name]

# TODO: not yet implemented; nothing is written and an empty mapping is
# returned for now.
slices = {}

return slices
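
As a rough illustration of the mapping this function is meant to return (the concrete slice values below are hypothetical, not produced by the stub above):

    from ndindex import Slice, Tuple

    # A virtual-dataset slice [0:10] whose data lives at raw-dataset
    # slice [40:50] would appear in the returned mapping as:
    mapping = {Tuple(Slice(0, 10)): Tuple(Slice(40, 50))}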


def get_previous_version_slices(
@@ -1053,3 +1107,46 @@ def split_across_unused(
new_raw_last_chunk,
new_raw_last_chunk_data,
)


def partition(obj: Union[np.ndarray, Tuple], chunk_size: int) -> List[Tuple]:
"""Break an array or a Tuple of slices into chunks of the given chunk size.
Parameters
----------
obj : Union[np.ndarray, Tuple]
Array or Tuple index to partition
chunk_size : int
The size of each partitioned chunk
Returns
-------
List[Tuple]
A list of slices of arr that make up the chunks
"""
if isinstance(obj, np.ndarray):
index = Tuple(*[Slice(0, dim) for dim in obj.shape])
else:
index = obj

# This is the size of the index along the axis to be chunked
dim0_size = len(index.args[0])

# If everything fits into a single chunk, just return the whole index
if dim0_size <= chunk_size:
return [index]

chunks = []
# Loop through the part of the data that fits into filled chunks
for chunk_start in range(0, dim0_size - chunk_size, chunk_size):
chunks.append(
Tuple(Slice(chunk_start, chunk_start + chunk_size), *index.args[1:])
)

# Partition any additional elements into a final partly-full chunk
chunk_start += chunk_size
if chunk_start < dim0_size:
    chunks.append(Tuple(Slice(chunk_start, dim0_size), *index.args[1:]))

return chunks
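
A quick sketch of the partitioning behavior, with values chosen purely for illustration:

    import numpy as np
    from ndindex import Slice, Tuple

    # A (25, 4) index partitioned with chunk_size 10 along the first axis
    # yields two full chunks and one partly-full final chunk:
    index = Tuple(Slice(0, 25), Slice(0, 4))
    chunks = partition(index, 10)
    # -> [Tuple(Slice(0, 10), Slice(0, 4)),
    #     Tuple(Slice(10, 20), Slice(0, 4)),
    #     Tuple(Slice(20, 25), Slice(0, 4))]

    # Passing an array instead partitions its full index the same way:
    arr_chunks = partition(np.zeros((25, 4)), 10)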
