Skip to content

Commit

Permalink
Update io.py with auto-chunk sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
djps authored Oct 19, 2023
1 parent 4f4db2d commit 3ee0bee
Showing 1 changed file with 47 additions and 31 deletions.
78 changes: 47 additions & 31 deletions kwave/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_h5_literals():
return literals


def write_matrix(filename, matrix: np.ndarray, matrix_name, compression_level=None):
def write_matrix(filename, matrix: np.ndarray, matrix_name, compression_level=None, auto_chunk=False):
# get literals
h5_literals = get_h5_literals()

Expand All @@ -78,36 +78,45 @@ def write_matrix(filename, matrix: np.ndarray, matrix_name, compression_level=No

# check size of matrix and set chunk size and compression level
if dims == 3:
# set chunk size to Nx * Ny
chunk_size = [Nx, Ny, 1]
if (auto_chunk):
chunk_size = True
else:
# set chunk size to Nx * Ny
chunk_size = [Nx, Ny, 1]
elif dims == 2:
# set chunk size to Nx
chunk_size = [Nx, 1, 1]
if (auto_chunk):
chunk_size = True
else:
# set chunk size to Nx
chunk_size = [Nx, 1, 1]
elif dims <= 1:
# check that the matrix size is greater than 1 MB
one_mb = (1024 ** 2) / 8
if matrix.size > one_mb:
# set chunk size to 1 MB
if Nx > Ny:
chunk_size = [one_mb, 1, 1]
elif Ny > Nz:
chunk_size = [1, one_mb, 1]
else:
chunk_size = [1, 1, one_mb]
if (auto_chunk):
chunk_size = True
else:

# set no compression
compression_level = 0

# set chunk size to grid size
if matrix.size == 1:
chunk_size = (1, 1, 1)
elif Nx > Ny:
chunk_size = (Nx, 1, 1)
elif Ny > Nz:
chunk_size = (1, Ny, 1)
# check that the matrix size is greater than 1 MB
one_mb = (1024 ** 2) / 8
if matrix.size > one_mb:
# set chunk size to 1 MB
if Nx > Ny:
chunk_size = [one_mb, 1, 1]
elif Ny > Nz:
chunk_size = [1, one_mb, 1]
else:
chunk_size = [1, 1, one_mb]
else:
chunk_size = (1, 1, Nz)

# set no compression
compression_level = 0

# set chunk size to grid size
if matrix.size == 1:
chunk_size = (1, 1, 1)
elif Nx > Ny:
chunk_size = (Nx, 1, 1)
elif Ny > Nz:
chunk_size = (1, Ny, 1)
else:
chunk_size = (1, 1, Nz)
else:
# throw error for unknown matrix size
raise ValueError('Input matrix must have 1, 2 or 3 dimensions.')
Expand Down Expand Up @@ -179,10 +188,17 @@ def write_matrix(filename, matrix: np.ndarray, matrix_name, compression_level=No
raise NotImplementedError('Currently there is no support for saving 2D complex matrices.')

# allocate a holder for the new matrix within the file
opts = {
'dtype': data_type_matlab,
'chunks': tuple(chunk_size)
}
if isinstance(chunk_size, bool):
opts = {
'dtype': data_type_matlab,
'chunks': True
}
else:
opts = {
'dtype': data_type_matlab,
'chunks': tuple(chunk_size)
}

if compression_level != 0:
# use compression
opts['compression'] = compression_level
Expand Down

0 comments on commit 3ee0bee

Please sign in to comment.