Skip to content

Commit

Permalink
Add support for dask distributed scheduler in quantum detector reader
Browse files Browse the repository at this point in the history
  • Loading branch information
ericpre committed May 29, 2024
1 parent 31bd677 commit d1ee9e1
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 28 deletions.
2 changes: 1 addition & 1 deletion doc/user_guide/supported_formats/supported_formats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
| :ref:`Protochips logfile <protochips-format>` | csv & log | Yes | No | No | No |
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
| :ref:`Quantum Detector <quantumdetector-format>` | mib | Yes | No | Yes | No |
| :ref:`Quantum Detector <quantumdetector-format>` | mib | Yes | No | Yes | Yes |
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
| :ref:`Renishaw <renishaw-format>` | wdf | Yes | No | No | No |
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
Expand Down
34 changes: 27 additions & 7 deletions rsciio/quantumdetector/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@

from rsciio._docstrings import (
CHUNKS_READ_DOC,
DISTRIBUTED_DOC,
FILENAME_DOC,
LAZY_DOC,
MMAP_DOC,
NAVIGATION_SHAPE,
RETURNS_DOC,
)
from rsciio.utils.distributed import memmap_distributed

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -194,6 +196,7 @@ def load_mib_data(
navigation_shape=None,
first_frame=None,
last_frame=None,
distributed=False,
mib_prop=None,
return_headers=False,
print_info=False,
Expand All @@ -210,6 +213,7 @@ def load_mib_data(
%s
%s
%s
%s
mib_prop : ``MIBProperties``, default=None
The ``MIBProperties`` instance of the file. If None, it will be
parsed from the file.
Expand Down Expand Up @@ -302,15 +306,21 @@ def load_mib_data(
# if it is read from TCPIP interface it needs to drop first 15 bytes which
# describe the stream size. Also watch for the coma in front of the stream.
if isinstance(mib_prop.path, str):
data = np.memmap(
mib_prop.path,
dtype=merlin_frame_dtype,
memmap_kwargs = dict(
filename=mib_prop.path,
# take into account first_frame
offset=mib_prop.offset + merlin_frame_dtype.itemsize * first_frame,
# need to use np.prod(navigation_shape) to crop number line
shape=np.prod(navigation_shape),
mode=mmap_mode,
dtype=merlin_frame_dtype,
)
if distributed:
data = memmap_distributed(chunks=chunks, key="data", **memmap_kwargs)
if not lazy:
data = data.compute()
# get_file_handle(data).close()
else:
data = np.memmap(mode=mmap_mode, **memmap_kwargs)
elif isinstance(path, bytes):
data = np.frombuffer(
path,
Expand All @@ -322,10 +332,11 @@ def load_mib_data(
else: # pragma: no cover
raise TypeError("`path` must be a str or a buffer.")

headers = data["header"]
data = data["data"]
if not distributed:
headers = data["header"]
data = data["data"]
if not return_mmap:
if lazy:
if not distributed and lazy:
if isinstance(chunks, tuple) and len(chunks) > 2:
# Since the data is reshaped later on, we set only the
# signal dimension chunks here
Expand All @@ -344,6 +355,10 @@ def load_mib_data(
data = data.rechunk(chunks)

if return_headers:
if distributed:
raise ValueError(
"Retuning headers is not supported with `distributed=True`."
)
return data, headers
else:
return data
Expand All @@ -356,6 +371,7 @@ def load_mib_data(
MMAP_DOC,
NAVIGATION_SHAPE,
_FIRST_LAST_FRAME,
DISTRIBUTED_DOC,
)


Expand Down Expand Up @@ -489,6 +505,7 @@ def file_reader(
navigation_shape=None,
first_frame=None,
last_frame=None,
distributed=False,
print_info=False,
):
"""
Expand All @@ -505,6 +522,7 @@ def file_reader(
%s
%s
%s
%s
print_info : bool
Display information about the mib file.
Expand Down Expand Up @@ -589,6 +607,7 @@ def file_reader(
navigation_shape=navigation_shape,
first_frame=first_frame,
last_frame=last_frame,
distributed=distributed,
mib_prop=mib_prop,
print_info=print_info,
return_mmap=False,
Expand Down Expand Up @@ -653,5 +672,6 @@ def file_reader(
MMAP_DOC,
NAVIGATION_SHAPE,
_FIRST_LAST_FRAME,
DISTRIBUTED_DOC,
RETURNS_DOC,
)
18 changes: 18 additions & 0 deletions rsciio/tests/test_quantumdetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,21 @@ def test_frames_in_acquisition_zero():

s = hs.load(f"{fname}.mib")
assert s.axes_manager.navigation_shape == ()


@pytest.mark.parametrize("lazy", (True, False))
def test_distributed(lazy):
s = hs.load(
TEST_DATA_DIR_UNZIPPED / "001_4x2_6bit.mib",
distributed=False,
lazy=lazy,
)
s2 = hs.load(
TEST_DATA_DIR_UNZIPPED / "001_4x2_6bit.mib",
distributed=True,
lazy=lazy,
)
if lazy:
s.compute()
s2.compute()
np.testing.assert_array_equal(s.data, s2.data)
2 changes: 1 addition & 1 deletion rsciio/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_get_date_time_from_metadata():

@pytest.mark.parametrize(
"shape",
((10, 20, 30, 512, 512),(20, 30, 512, 512), (10, 512, 512), (512, 512))
((10, 20, 30, 512, 512), (20, 30, 512, 512), (10, 512, 512), (512, 512))
)
def test_get_chunk_slice(shape):
chunk_arr, chunk = get_chunk_slice(shape=shape, chunks=-1) # 1 chunk
Expand Down
71 changes: 52 additions & 19 deletions rsciio/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with RosettaSciIO. If not, see <https://www.gnu.org/licenses/#GPL>.

import os

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -60,22 +61,19 @@ def get_chunk_slice(
)
chunks_shape = tuple([len(c) for c in chunks])
slices = np.empty(
shape=chunks_shape
+ (
len(chunks_shape),
2,
),
shape=chunks_shape + (len(chunks_shape), 2),
dtype=int,
)
for ind in np.ndindex(chunks_shape):
current_chunk = [chunk[i] for i, chunk in zip(ind, chunks)]
starts = [int(np.sum(chunk[:i])) for i, chunk in zip(ind, chunks)]
stops = [s + c for s, c in zip(starts, current_chunk)]
slices[ind] = [[start, stop] for start, stop in zip(starts, stops)]

return da.from_array(slices, chunks=(1,) * len(shape) + slices.shape[-2:]), chunks


def slice_memmap(slices, file, dtypes, shape, **kwargs):
def slice_memmap(slices, file, dtypes, shape, key=None, **kwargs):
"""
Slice a memory mapped file using a tuple of slices.
Expand All @@ -96,6 +94,8 @@ def slice_memmap(slices, file, dtypes, shape, **kwargs):
Data type of the data for :class:`numpy.memmap` function.
shape : tuple
Shape of the entire dataset. Passed to the :class:`numpy.memmap` function.
key : None, str
For structured dtype only. Specify the key of the structured dtype to use.
**kwargs : dict
Additional keyword arguments to pass to the :class:`numpy.memmap` function.
Expand All @@ -104,31 +104,36 @@ def slice_memmap(slices, file, dtypes, shape, **kwargs):
numpy.ndarray
Array of the data from the memory mapped file sliced using the provided slice.
"""
sl = np.squeeze(slices)[()]
slices_ = np.squeeze(slices)[()]
data = np.memmap(file, dtypes, shape=shape, **kwargs)
slics = tuple([slice(s[0], s[1]) for s in sl])
return data[slics]
if key is not None:
data = data[key]
slices_ = tuple([slice(s[0], s[1]) for s in slices_])
return data[slices_]


def memmap_distributed(
file,
filename,
dtype,
offset=0,
shape=None,
order="C",
chunks="auto",
block_size_limit=None,
key=None,
):
"""
Drop in replacement for py:func:`numpy.memmap` allowing for distributed loading of data.
Drop in replacement for py:func:`numpy.memmap` allowing for distributed
loading of data.
This always loads the data using dask which can be beneficial in many cases, but
may not be ideal in others. The ``chunks`` and ``block_size_limit`` are for describing an ideal chunk shape and size
as defined using the :py:func:`dask.array.core.normalize_chunks` function.
This always loads the data using dask which can be beneficial in many
cases, but may not be ideal in others. The ``chunks`` and ``block_size_limit``
are for describing an ideal chunk shape and size as defined using the
:func:`dask.array.core.normalize_chunks` function.
Parameters
----------
file : str
filename : str
Path to the file.
dtype : numpy.dtype
Data type of the data for memmap function.
Expand All @@ -142,25 +147,52 @@ def memmap_distributed(
Chunk shape. The default is "auto".
block_size_limit : int, optional
Maximum size of a block in bytes. The default is None.
key : None, str
For structured dtype only. Specify the key of the structured dtype to use.
Returns
-------
dask.array.Array
Dask array of the data from the memmaped file and with the specified chunks.
Notes
-----
Currently :func:`dask.array.map_blocks` does not allow for multiple outputs.
As a result, in case of structured dtype, the key of the structured dtype need
to be specified.
For example: with dtype = (("data", int, (128, 128)), ("sec", "<u4", 512)),
"data" or "sec" will need to be specified.
"""
# if np.dtype(dtype).fields and chunks == "auto":
# raise ValueError("Structured dtype can't be used with `chunks='auto'`.")

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.

if dtype.names is not None:
# Structured dtype
array_dtype = dtype[key].base
sub_array_shape = dtype[key].shape
else:
array_dtype = dtype.base
sub_array_shape = dtype.shape

if shape is None:
unit_size = np.dtype(dtype).itemsize
shape = int(os.path.getsize(filename) / unit_size)
if not isinstance(shape, tuple):
shape = (shape,)

# Separates slices into appropriately sized chunks.
chunked_slices, data_chunks = get_chunk_slice(
shape=shape,
shape=shape + sub_array_shape,
chunks=chunks,
block_size_limit=block_size_limit,
dtype=dtype,
dtype=array_dtype,
)
num_dim = len(shape)
data = da.map_blocks(
slice_memmap,
chunked_slices,
file=file,
dtype=dtype,
file=filename,
dtype=array_dtype,
shape=shape,
order=order,
mode="r",
Expand All @@ -171,5 +203,6 @@ def memmap_distributed(
num_dim,
num_dim + 1,
), # Dask 2021.10.0 minimum to use negative indexing
key=key,
)
return data

0 comments on commit d1ee9e1

Please sign in to comment.