Add support for dask distributed scheduler in quantum detector reader #267

Merged 2 commits on Jun 5, 2024
2 changes: 1 addition & 1 deletion doc/user_guide/supported_formats/supported_formats.rst
@@ -56,7 +56,7 @@
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
| :ref:`Protochips logfile <protochips-format>` | csv & log | Yes | No | No | No |
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
| :ref:`Quantum Detector <quantumdetector-format>` | mib | Yes | No | Yes | No |
| :ref:`Quantum Detector <quantumdetector-format>` | mib | Yes | No | Yes | Yes |
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
| :ref:`Renishaw <renishaw-format>` | wdf | Yes | No | No | No |
+---------------------------------------------------------------------+-------------------------+--------+--------+--------+-------------+
34 changes: 27 additions & 7 deletions rsciio/quantumdetector/_api.py
@@ -29,12 +29,14 @@

from rsciio._docstrings import (
CHUNKS_READ_DOC,
DISTRIBUTED_DOC,
FILENAME_DOC,
LAZY_DOC,
MMAP_DOC,
NAVIGATION_SHAPE,
RETURNS_DOC,
)
from rsciio.utils.distributed import memmap_distributed

_logger = logging.getLogger(__name__)

@@ -194,6 +196,7 @@ def load_mib_data(
navigation_shape=None,
first_frame=None,
last_frame=None,
distributed=False,
mib_prop=None,
return_headers=False,
print_info=False,
@@ -210,6 +213,7 @@
%s
%s
%s
%s
mib_prop : ``MIBProperties``, default=None
The ``MIBProperties`` instance of the file. If None, it will be
parsed from the file.
@@ -302,15 +306,21 @@
# if it is read from the TCPIP interface, the first 15 bytes, which
# describe the stream size, need to be dropped. Also watch for the comma in front of the stream.
if isinstance(mib_prop.path, str):
data = np.memmap(
mib_prop.path,
dtype=merlin_frame_dtype,
memmap_kwargs = dict(
filename=mib_prop.path,
# take into account first_frame
offset=mib_prop.offset + merlin_frame_dtype.itemsize * first_frame,
# need to use np.prod(navigation_shape) to crop the number of frames
shape=np.prod(navigation_shape),
mode=mmap_mode,
dtype=merlin_frame_dtype,
)
if distributed:
data = memmap_distributed(chunks=chunks, key="data", **memmap_kwargs)
if not lazy:
data = data.compute()
# get_file_handle(data).close()
else:
data = np.memmap(mode=mmap_mode, **memmap_kwargs)
elif isinstance(path, bytes):
data = np.frombuffer(
path,
@@ -322,10 +332,11 @@
else: # pragma: no cover
raise TypeError("`path` must be a str or a buffer.")

headers = data["header"]
data = data["data"]
if not distributed:
headers = data["header"]
data = data["data"]
if not return_mmap:
if lazy:
if not distributed and lazy:
if isinstance(chunks, tuple) and len(chunks) > 2:
# Since the data is reshaped later on, we set only the
# signal dimension chunks here
@@ -344,6 +355,10 @@
data = data.rechunk(chunks)

if return_headers:
if distributed:
Review comment (Member):

You can still return the headers by setting key="header" for a second memmap_distributed call. It will add some time onto the saving of the dataset, as the entire dataset might get loaded into RAM with most of it thrown away.

Really, what we should do is add things to a to_store context manager and then call:

da.store(data, dset)

only once. That will merge task graphs as necessary and might reduce the time for saving certain signals. I've thought about it for things like saving lazy markers, or possibly creating an hs.save() function for handling multiple signals, if you wanted to save multiple parts of some analysis efficiently. This is a fairly abstract, higher-level concept, so maybe it would be seldom used.

Reply (Member Author):

Yes, this will most likely need to be done at some point! I opened #269 to track it / add more use cases.

raise ValueError(
"Returning headers is not supported with `distributed=True`."
)
return data, headers
else:
return data
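As a standalone check of the offset arithmetic used above — the file offset advances by one record size per skipped frame — here is a minimal sketch with a hypothetical frame record (the 8-byte header and 4×4 uint16 frame are made-up numbers, not the actual Merlin layout):

```python
import numpy as np

# Hypothetical record layout: small header followed by one frame.
frame_dtype = np.dtype([("header", np.uint8, (8,)), ("data", np.uint16, (4, 4))])

base_offset = 0   # file-level offset of the first record
first_frame = 3   # skip the first three frames

# Same arithmetic as in memmap_kwargs above:
offset = base_offset + frame_dtype.itemsize * first_frame

assert frame_dtype.itemsize == 8 + 2 * 4 * 4  # 40 bytes per record
assert offset == 120
```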
@@ -356,6 +371,7 @@ def load_mib_data(
MMAP_DOC,
NAVIGATION_SHAPE,
_FIRST_LAST_FRAME,
DISTRIBUTED_DOC,
)


@@ -489,6 +505,7 @@ def file_reader(
navigation_shape=None,
first_frame=None,
last_frame=None,
distributed=False,
print_info=False,
):
"""
@@ -505,6 +522,7 @@
%s
%s
%s
%s
print_info : bool
Display information about the mib file.

@@ -589,6 +607,7 @@
navigation_shape=navigation_shape,
first_frame=first_frame,
last_frame=last_frame,
distributed=distributed,
mib_prop=mib_prop,
print_info=print_info,
return_mmap=False,
@@ -653,5 +672,6 @@ def file_reader(
MMAP_DOC,
NAVIGATION_SHAPE,
_FIRST_LAST_FRAME,
DISTRIBUTED_DOC,
RETURNS_DOC,
)
18 changes: 18 additions & 0 deletions rsciio/tests/test_quantumdetector.py
@@ -394,3 +394,21 @@ def test_frames_in_acquisition_zero():

s = hs.load(f"{fname}.mib")
assert s.axes_manager.navigation_shape == ()


@pytest.mark.parametrize("lazy", (True, False))
def test_distributed(lazy):
s = hs.load(
TEST_DATA_DIR_UNZIPPED / "001_4x2_6bit.mib",
distributed=False,
lazy=lazy,
)
s2 = hs.load(
TEST_DATA_DIR_UNZIPPED / "001_4x2_6bit.mib",
distributed=True,
lazy=lazy,
)
if lazy:
s.compute()
s2.compute()
np.testing.assert_array_equal(s.data, s2.data)
2 changes: 1 addition & 1 deletion rsciio/tests/utils/test_utils.py
@@ -384,7 +384,7 @@ def test_get_date_time_from_metadata():

@pytest.mark.parametrize(
"shape",
((10, 20, 30, 512, 512),(20, 30, 512, 512), (10, 512, 512), (512, 512))
((10, 20, 30, 512, 512), (20, 30, 512, 512), (10, 512, 512), (512, 512))
)
def test_get_chunk_slice(shape):
chunk_arr, chunk = get_chunk_slice(shape=shape, chunks=-1) # 1 chunk
69 changes: 50 additions & 19 deletions rsciio/utils/distributed.py
@@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with RosettaSciIO. If not, see <https://www.gnu.org/licenses/#GPL>.

import os

import dask.array as da
import numpy as np
@@ -60,22 +61,19 @@ def get_chunk_slice(
)
chunks_shape = tuple([len(c) for c in chunks])
slices = np.empty(
shape=chunks_shape
+ (
len(chunks_shape),
2,
),
shape=chunks_shape + (len(chunks_shape), 2),
dtype=int,
)
for ind in np.ndindex(chunks_shape):
current_chunk = [chunk[i] for i, chunk in zip(ind, chunks)]
starts = [int(np.sum(chunk[:i])) for i, chunk in zip(ind, chunks)]
stops = [s + c for s, c in zip(starts, current_chunk)]
slices[ind] = [[start, stop] for start, stop in zip(starts, stops)]

return da.from_array(slices, chunks=(1,) * len(shape) + slices.shape[-2:]), chunks
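The slice bookkeeping in get_chunk_slice can be exercised on its own: for each chunk index, the array stores per-axis [start, stop] pairs. A sketch with a hypothetical 4×6 array chunked as (2, 2) × (3, 3):

```python
import numpy as np

chunks = ((2, 2), (3, 3))                     # chunk sizes along each axis
chunks_shape = tuple(len(c) for c in chunks)  # (2, 2): two chunks per axis

# One [start, stop] pair per axis, per chunk.
slices = np.empty(chunks_shape + (len(chunks_shape), 2), dtype=int)
for ind in np.ndindex(chunks_shape):
    current = [c[i] for i, c in zip(ind, chunks)]
    starts = [int(np.sum(c[:i])) for i, c in zip(ind, chunks)]
    stops = [s + n for s, n in zip(starts, current)]
    slices[ind] = [[st, sp] for st, sp in zip(starts, stops)]

# The first chunk covers rows 0:2 / columns 0:3; the last rows 2:4 / columns 3:6.
assert slices[0, 0].tolist() == [[0, 2], [0, 3]]
assert slices[1, 1].tolist() == [[2, 4], [3, 6]]
```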


def slice_memmap(slices, file, dtypes, shape, **kwargs):
def slice_memmap(slices, file, dtypes, shape, key=None, **kwargs):
"""
Slice a memory mapped file using a tuple of slices.

@@ -96,6 +94,8 @@ def slice_memmap(slices, file, dtypes, shape, **kwargs):
Data type of the data for :class:`numpy.memmap` function.
shape : tuple
Shape of the entire dataset. Passed to the :class:`numpy.memmap` function.
key : None, str
For structured dtype only. Specify the key of the structured dtype to use.
**kwargs : dict
Additional keyword arguments to pass to the :class:`numpy.memmap` function.

@@ -104,31 +104,36 @@
numpy.ndarray
Array of the data from the memory mapped file sliced using the provided slice.
"""
sl = np.squeeze(slices)[()]
slices_ = np.squeeze(slices)[()]
data = np.memmap(file, dtypes, shape=shape, **kwargs)
slics = tuple([slice(s[0], s[1]) for s in sl])
return data[slics]
if key is not None:
data = data[key]
slices_ = tuple([slice(s[0], s[1]) for s in slices_])
return data[slices_]
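The final step of slice_memmap — turning the per-axis [start, stop] pairs into slice objects — can be shown in isolation (hypothetical bounds on a small array):

```python
import numpy as np

arr = np.arange(24).reshape(4, 6)
bounds = np.array([[1, 3], [2, 5]])  # per-axis [start, stop] pairs

# Same conversion as in slice_memmap:
slices_ = tuple(slice(start, stop) for start, stop in bounds)

assert slices_ == (slice(1, 3), slice(2, 5))
np.testing.assert_array_equal(arr[slices_], arr[1:3, 2:5])
```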


def memmap_distributed(
file,
filename,
dtype,
offset=0,
shape=None,
order="C",
chunks="auto",
block_size_limit=None,
key=None,
):
"""
Drop in replacement for py:func:`numpy.memmap` allowing for distributed loading of data.
Drop-in replacement for :py:func:`numpy.memmap` allowing for distributed
loading of data.

This always loads the data using dask which can be beneficial in many cases, but
may not be ideal in others. The ``chunks`` and ``block_size_limit`` are for describing an ideal chunk shape and size
as defined using the :py:func:`dask.array.core.normalize_chunks` function.
This always loads the data using dask which can be beneficial in many
cases, but may not be ideal in others. The ``chunks`` and ``block_size_limit``
are for describing an ideal chunk shape and size as defined using the
:func:`dask.array.core.normalize_chunks` function.

Parameters
----------
file : str
filename : str
Path to the file.
dtype : numpy.dtype
Data type of the data for memmap function.
@@ -142,25 +147,50 @@
Chunk shape. The default is "auto".
block_size_limit : int, optional
Maximum size of a block in bytes. The default is None.
key : None, str
For structured dtype only. Specify the key of the structured dtype to use.

Returns
-------
dask.array.Array
Dask array of the data from the memmaped file and with the specified chunks.

Notes
-----
Currently :func:`dask.array.map_blocks` does not allow for multiple outputs.
As a result, in the case of a structured dtype, the key of the structured dtype
needs to be specified.
For example: with dtype = (("data", int, (128, 128)), ("sec", "<u4", 512)),
"data" or "sec" will need to be specified.
"""

if dtype.names is not None:
# Structured dtype
array_dtype = dtype[key].base
sub_array_shape = dtype[key].shape
else:
array_dtype = dtype.base
sub_array_shape = dtype.shape

if shape is None:
unit_size = np.dtype(dtype).itemsize
shape = int(os.path.getsize(filename) / unit_size)
if not isinstance(shape, tuple):
shape = (shape,)

# Separates slices into appropriately sized chunks.
chunked_slices, data_chunks = get_chunk_slice(
shape=shape,
shape=shape + sub_array_shape,
chunks=chunks,
block_size_limit=block_size_limit,
dtype=dtype,
dtype=array_dtype,
)
num_dim = len(shape)
data = da.map_blocks(
slice_memmap,
chunked_slices,
file=file,
dtype=dtype,
file=filename,
dtype=array_dtype,
shape=shape,
order=order,
mode="r",
@@ -171,5 +201,6 @@
num_dim,
num_dim + 1,
), # Dask 2021.10.0 minimum to use negative indexing
key=key,
)
return data
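Putting the structured-dtype pieces together — inferring the record count from the file size and selecting a field via key, as memmap_distributed does — a self-contained sketch (the record layout is hypothetical, not the actual Merlin format):

```python
import os
import tempfile

import numpy as np

# Hypothetical record: 8-byte header + 4x4 uint16 frame -> 40 bytes per record.
record = np.dtype([("header", np.uint8, (8,)), ("data", np.uint16, (4, 4))])

# Write three zeroed records to a temporary file.
with tempfile.NamedTemporaryFile(suffix=".mib", delete=False) as f:
    f.write(np.zeros(3, dtype=record).tobytes())
    path = f.name

# Shape inference as in memmap_distributed: record count = file size / itemsize.
n_records = os.path.getsize(path) // record.itemsize
assert n_records == 3

# Field selection via the structured dtype, as done with `key`:
assert record["data"].base == np.uint16   # per-item dtype
assert record["data"].shape == (4, 4)     # sub-array shape

# slice_memmap-style access: memmap the whole record, then pick one field.
mm = np.memmap(path, dtype=record, shape=(n_records,), mode="r")
frames = mm["data"]
assert frames.shape == (3, 4, 4)

del mm
os.unlink(path)
```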
1 change: 1 addition & 0 deletions upcoming_changes/267.enhancements.rst
@@ -0,0 +1 @@
:ref:`quantumdetector-format`: Add support for dask distributed scheduler.