diff --git a/datalad_next/itertools/__init__.py b/datalad_next/itertools/__init__.py
index 339f5984..15aaf103 100644
--- a/datalad_next/itertools/__init__.py
+++ b/datalad_next/itertools/__init__.py
@@ -4,6 +4,7 @@
 .. autosummary::
    :toctree: generated

+   align_pattern
    decode_bytes
    itemize
    load_json
@@ -13,6 +14,7 @@
 """

+from .align_pattern import align_pattern
 from .decode_bytes import decode_bytes
 from .itemize import itemize
 from .load_json import (
diff --git a/datalad_next/itertools/align_pattern.py b/datalad_next/itertools/align_pattern.py
new file mode 100644
index 00000000..3c89d616
--- /dev/null
+++ b/datalad_next/itertools/align_pattern.py
@@ -0,0 +1,102 @@
+""" Function to ensure that a pattern is completely contained in single chunks
+"""
+
+from __future__ import annotations
+
+from typing import (
+    Generator,
+    Iterable,
+)
+
+
+def align_pattern(iterable: Iterable[str | bytes | bytearray],
+                  pattern: str | bytes | bytearray
+                  ) -> Generator[str | bytes | bytearray, None, None]:
+    """ Yield data chunks that contain a complete pattern, if it is present
+
+    ``align_pattern`` makes it easy to find a pattern (``str``, ``bytes``,
+    or ``bytearray``) in data chunks. It joins data chunks in such a way
+    that a simple containment check (e.g. ``pattern in chunk``) on the chunks
+    that ``align_pattern`` yields suffices to determine whether the pattern
+    is present in the stream yielded by the underlying iterable.
+
+    To achieve this, ``align_pattern`` joins consecutive chunks to ensure
+    that the following two assertions hold:
+
+    1. Each chunk that is yielded by ``align_pattern`` has at least the length
+       of the pattern (unless the underlying iterable is exhausted before the
+       length of the pattern is reached).
+
+    2. The pattern is not split between two chunks, i.e. no chunk that is
+       yielded by ``align_pattern`` ends with a prefix of the pattern (unless
+       it is the last chunk that the underlying iterable yields).
+
+    The pattern might be present multiple times in a yielded data chunk.
+
+    Note: the ``pattern`` is compared verbatim to the content in the data
+    chunks, i.e. no parsing of the ``pattern`` is performed and no regular
+    expressions or wildcards are supported.
+
+    .. code-block:: python
+
+        >>> from datalad_next.itertools import align_pattern
+        >>> tuple(align_pattern([b'abcd', b'e', b'fghi'], pattern=b'def'))
+        (b'abcdefghi',)
+        >>> # The pattern can be present multiple times in a yielded chunk
+        >>> tuple(align_pattern([b'abcd', b'e', b'fdefghi'], pattern=b'def'))
+        (b'abcdefdefghi',)
+
+    Use this function if you want to locate a pattern in an input stream. It
+    allows you to use a simple ``in``-check to determine whether the pattern
+    is present in the yielded result chunks.
+
+    The function always yields everything it has fetched from the underlying
+    iterable, i.e. after a yield it does not cache any data from the
+    underlying iterable. That means, once the functionality of
+    ``align_pattern`` is no longer required, the underlying iterator can be
+    used directly again, after ``align_pattern`` has yielded a data chunk.
+    This allows more efficient processing of the data that remains in the
+    underlying iterable.
+
+    Parameters
+    ----------
+    iterable: Iterable
+        An iterable that yields data chunks.
+    pattern: str | bytes | bytearray
+        The pattern that should be contained in the chunks. Its type must be
+        compatible with the type of the elements in ``iterable``.
+ + Yields + ------- + str | bytes | bytearray + data chunks that have at least the size of the pattern and do not end + with a prefix of the pattern. Note that a data chunk might contain the + pattern multiple times. + """ + + def ends_with_pattern_prefix(data: str | bytes | bytearray, + pattern: str | bytes | bytearray, + ) -> bool: + """ Check whether the chunk ends with a prefix of the pattern """ + for index in range(len(pattern) - 1, 0, -1): + if data[-index:] == pattern[:index]: + return True + return False + + # Join data chunks until they are sufficiently long to contain the pattern, + # i.e. have at least size: `len(pattern)`. Continue joining, if the chunk + # ends with a prefix of the pattern. + current_chunk = None + for data_chunk in iterable: + # get the type of current_chunk from the type of this data_chunk + if current_chunk is None: + current_chunk = data_chunk + else: + current_chunk += data_chunk + if len(current_chunk) >= len(pattern) \ + and not ends_with_pattern_prefix(current_chunk, pattern): + yield current_chunk + current_chunk = None + + if current_chunk is not None: + yield current_chunk diff --git a/datalad_next/itertools/tests/test_align_pattern.py b/datalad_next/itertools/tests/test_align_pattern.py new file mode 100644 index 00000000..6da8bc87 --- /dev/null +++ b/datalad_next/itertools/tests/test_align_pattern.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import pytest + +from ..align_pattern import align_pattern + + +@pytest.mark.parametrize('data_chunks,pattern,expected', [ + (['a', 'b', 'c', 'd', 'e'], 'abc', ['abc', 'de']), + (['a', 'b', 'c', 'a', 'b', 'c'], 'abc', ['abc', 'abc']), + # Ensure that unaligned pattern prefixes are not keeping data chunks short. + (['a', 'b', 'c', 'dddbbb', 'a', 'b', 'x'], 'abc', ['abc', 'dddbbb', 'abx']), + # Expect that a trailing minimum length-chunk that ends with a pattern + # prefix is not returned as data, but as remainder, if it is not the final + # chunk. + (['a', 'b', 'c', 'd', 'a'], 'abc', ['abc', 'da']), + # Expect the last chunk to be returned as data, if final is True, although + # it ends with a pattern prefix. If final is false, the last chunk will be + # returned as a remainder, because it ends with a pattern prefix. + (['a', 'b', 'c', 'dddbbb', 'a'], 'abc', ['abc', 'dddbbb', 'a']), + (['a', 'b', 'c', '9', 'a'], 'abc', ['abc', '9a']), +]) +def test_pattern_processor(data_chunks, pattern, expected): + assert expected == list(align_pattern(data_chunks, pattern=pattern)) diff --git a/datalad_next/runners/__init__.py b/datalad_next/runners/__init__.py index 7a96f952..e2ee54e3 100644 --- a/datalad_next/runners/__init__.py +++ b/datalad_next/runners/__init__.py @@ -42,7 +42,10 @@ StdOutErrCapture """ -from .iter_subproc import iter_subproc +from .iter_subproc import ( + iter_subproc, + IterableSubprocessError, +) # runners from datalad.runner import ( diff --git a/datalad_next/runners/iter_subproc.py b/datalad_next/runners/iter_subproc.py index 5d5bf932..6901abcd 100644 --- a/datalad_next/runners/iter_subproc.py +++ b/datalad_next/runners/iter_subproc.py @@ -4,8 +4,12 @@ List, ) -from datalad_next.iterable_subprocess.iterable_subprocess \ - import iterable_subprocess +from datalad_next.iterable_subprocess.iterable_subprocess import ( + iterable_subprocess, + # not needed here, but we want to provide all critical pieces from + # the same place. 
This is the key exception type + IterableSubprocessError, +) from datalad_next.consts import COPY_BUFSIZE __all__ = ['iter_subproc'] diff --git a/datalad_next/url_operations/__init__.py b/datalad_next/url_operations/__init__.py index 931ab2a8..a66775c4 100644 --- a/datalad_next/url_operations/__init__.py +++ b/datalad_next/url_operations/__init__.py @@ -15,10 +15,14 @@ from __future__ import annotations import logging +from functools import partial +from more_itertools import side_effect from pathlib import Path from typing import ( Any, Dict, + Generator, + Iterable, ) import datalad @@ -338,6 +342,37 @@ def _progress_report_stop(self, pid: str, log_msg: tuple): def _get_hasher(self, hash: list[str] | None) -> NoOpHash | MultiHash: return MultiHash(hash) if hash is not None else NoOpHash() + def _with_progress(self, + stream: Iterable[Any], + *, + progress_id: str, + label: str, + expected_size: int | None, + start_log_msg: tuple, + end_log_msg: tuple, + update_log_msg: tuple + ) -> Generator[Any, None, None]: + yield from side_effect( + lambda chunk: self._progress_report_update( + progress_id, + update_log_msg, + len(chunk) + ), + stream, + before=partial( + self._progress_report_start, + progress_id, + start_log_msg, + label, + expected_size + ), + after=partial( + self._progress_report_stop, + progress_id, + end_log_msg + ) + ) + # # Exceptions to be used by all handlers diff --git a/datalad_next/url_operations/ssh.py b/datalad_next/url_operations/ssh.py index 1e9a3aac..d52477cb 100644 --- a/datalad_next/url_operations/ssh.py +++ b/datalad_next/url_operations/ssh.py @@ -4,8 +4,8 @@ from __future__ import annotations import logging -import subprocess import sys +from functools import partial from itertools import chain from pathlib import ( Path, @@ -16,23 +16,20 @@ Queue, ) from typing import ( - Any, Dict, Generator, IO, + cast, ) from urllib.parse import urlparse +from datalad_next.consts import COPY_BUFSIZE from datalad_next.runners import ( - GeneratorMixIn, - NoCaptureGeneratorProtocol, - Protocol as RunnerProtocol, - StdOutCaptureGeneratorProtocol, - ThreadedRunner, - CommandError, + iter_subproc, + IterableSubprocessError, ) +from datalad_next.itertools import align_pattern -from datalad_next.utils.consts import COPY_BUFSIZE from . import ( UrlOperations, @@ -71,6 +68,18 @@ class SshUrlOperations(UrlOperations): "|| exit 244" _cat_cmd = "cat '{fpath}'" + def _check_return_code(self, return_code: int, url: str): + # At this point the subprocess has either exited, was terminated, or + # was killed. + if return_code == 244: + # this is the special code for a file-not-found + raise UrlOperationsResourceUnknown(url) + elif return_code != 0: + raise UrlOperationsRemoteError( + url, + message=f'ssh process returned {return_code}' + ) + def stat(self, url: str, *, @@ -81,63 +90,63 @@ def stat(self, See :meth:`datalad_next.url_operations.UrlOperations.stat` for parameter documentation and exception behavior. """ + ssh_cat = _SshCommandBuilder(url) + cmd = ssh_cat.get_cmd(SshUrlOperations._stat_cmd) try: - props = self._stat( - url, - cmd=SshUrlOperations._stat_cmd, - ) - except CommandError as e: - if e.code == 244: - # this is the special code for a file-not-found - raise UrlOperationsResourceUnknown(url) from e - else: - raise UrlOperationsRemoteError(url, message=str(e)) from e - + with iter_subproc(cmd) as stream: + # any exception that is raised in this context and not caught + # will prevent the creation of `IterableSubprocessError`. 
But + # we rely on the return code of the ssh-process to signal + # specific errors. Therefore, we catch the expected + # `StopIteration` here. + try: + props = self._get_props(url, stream) + except StopIteration: + pass + except IterableSubprocessError as e: + self._check_return_code(e.returncode, url) return {k: v for k, v in props.items() if not k.startswith('_')} - def _stat(self, url: str, cmd: str) -> Dict: - # any stream must start with this magic marker, or we do not + def _get_props(self, url, stream: Generator) -> dict: + # Any stream must start with this magic marker, or we do not # recognize what is happening # after this marker, the server will send the size of the # to-be-downloaded file in bytes, followed by another magic - # b'\1', and the file content after that - need_magic = b'\1\2\3' - expected_size_str = b'' - expected_size = None - - ssh_cat = _SshCat(url) - stream = ssh_cat.run(cmd, protocol=StdOutCaptureGeneratorProtocol) - for chunk in stream: - if need_magic: - expected_magic = need_magic[:min(len(need_magic), - len(chunk))] - incoming_magic = chunk[:len(need_magic)] - # does the incoming data have the remaining magic bytes? - if incoming_magic != expected_magic: - raise RuntimeError( - "Protocol error: report header not received") - # reduce (still missing) magic, if any - need_magic = need_magic[len(expected_magic):] - # strip magic from input - chunk = chunk[len(expected_magic):] - if chunk and expected_size is None: - # we have incoming data left and - # we have not yet consumed the size info - size_data = chunk.split(b'\1', maxsplit=1) - expected_size_str += size_data[0] - if len(size_data) > 1: - # this is not only size info, but we found the start of - # the data - expected_size = int(expected_size_str) - chunk = size_data[1] - if expected_size: - props = { - 'content-length': expected_size, - '_stream': chain([chunk], stream) if chunk else stream, - } - return props - # there should be no data left to process, or something went wrong - assert not chunk + # b'\1', and the file content after that. + magic_marker = b'\1\2\3' + + # use the `align_pattern` iterable to guarantees, that the magic + # marker is always contained in a complete chunk. + aligned_stream = align_pattern(stream, magic_marker) + + # Because the stream should start with the pattern, the first chunk of + # the aligned stream must contain it. + # We know that the stream will deliver bytes, cast the result + # accordingly. + chunk = cast(bytes, next(aligned_stream)) + if chunk[:len(magic_marker)] != magic_marker: + raise RuntimeError("Protocol error: report header not received") + chunk = chunk[len(magic_marker):] + + # We are done with the aligned stream, use the original stream again. + # This is possible because `align_pattern` does not cache any data + # after a `yield`. + del aligned_stream + + # The length is transferred now and terminated by b'\x01'. 
+ while b'\x01' not in chunk: + chunk += next(stream) + + marker_index = chunk.index(b'\x01') + expected_size = int(chunk[:marker_index]) + chunk = chunk[marker_index + 1:] + props = { + 'content-length': expected_size, + # go back to the original iterator, no need to keep looking for + # a pattern + '_stream': chain([chunk], stream) if chunk else stream + } + return props def download(self, from_url: str, @@ -147,7 +156,7 @@ def download(self, # obtain escalated/different privileges on a system # to gain file access credential: str | None = None, - hash: str | None = None, + hash: list[str] | None = None, timeout: float | None = None) -> Dict: """Download a file by streaming it through an SSH connection. @@ -160,49 +169,60 @@ def download(self, # this is pretty much shutil.copyfileobj() with the necessary # wrapping to perform hashing and progress reporting hasher = self._get_hasher(hash) - progress_id = self._get_progress_id(from_url, to_path) + progress_id = self._get_progress_id(from_url, str(to_path)) dst_fp = None + ssh_cat = _SshCommandBuilder(from_url) + cmd = ssh_cat.get_cmd(f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}') try: - props = self._stat( - from_url, - cmd=f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}', - ) - stream = props.pop('_stream') - expected_size = props['content-length'] - dst_fp = sys.stdout.buffer if to_path is None \ - else open(to_path, 'wb') - # Localize variable access to minimize overhead - dst_fp_write = dst_fp.write - # download can start - self._progress_report_start( - progress_id, - ('Download %s to %s', from_url, to_path), - 'downloading', - expected_size, - ) - for chunk in stream: - # write data - dst_fp_write(chunk) - # compute hash simultaneously - hasher.update(chunk) - self._progress_report_update( - progress_id, ('Downloaded chunk',), len(chunk)) - props.update(hasher.get_hexdigest()) - return props - except CommandError as e: - if e.code == 244: - # this is the special code for a file-not-found - raise UrlOperationsResourceUnknown(from_url) from e - else: - # wrap this into the datalad-standard, but keep the - # original exception linked - raise UrlOperationsRemoteError(from_url, message=str(e)) from e + with iter_subproc(cmd) as stream: + # any exception that is raised in this context and not caught + # will prevent the creation of `IterableSubprocessError`. But + # we rely on the return code of the ssh-process to signal + # specific errors. Therefore, we catch the expected + # `StopIteration` here. + try: + props = self._get_props(from_url, stream) + expected_size = props['content-length'] + # The stream might have changed due to not yet processed, but + # fetched data, that is now chained in front of it. 
Therefore we + # get the updated stream from the props + download_stream = props.pop('_stream') + + dst_fp = sys.stdout.buffer \ + if to_path is None \ + else open(to_path, 'wb') + + # Localize variable access to minimize overhead + dst_fp_write = dst_fp.write + + # download can start + for chunk in self._with_progress( + download_stream, + progress_id=progress_id, + label='downloading', + expected_size=expected_size, + start_log_msg=('Download %s to %s', from_url, to_path), + end_log_msg=('Finished download',), + update_log_msg=('Downloaded chunk',) + ): + # write data + dst_fp_write(chunk) + # compute hash simultaneously + hasher.update(chunk) + except StopIteration: + pass + except IterableSubprocessError as e: + self._check_return_code(e.returncode, from_url) finally: if dst_fp and to_path is not None: dst_fp.close() - self._progress_report_stop(progress_id, ('Finished download',)) + + return { + **props, + **hasher.get_hexdigest(), + } def upload(self, from_path: Path | None, @@ -234,7 +254,7 @@ def upload(self, with from_path.open("rb") as src_fp: return self._perform_upload( src_fp=src_fp, - source_name=from_path, + source_name=str(from_path), to_url=to_url, hash_names=hash, expected_size=from_path.stat().st_size, @@ -247,69 +267,64 @@ def _perform_upload(self, to_url: str, hash_names: list[str] | None, expected_size: int | None, - timeout: int | None) -> dict: + timeout: float | None) -> dict: hasher = self._get_hasher(hash_names) + # we use a queue to implement timeouts. # we limit the queue to few items in order to `make queue.put()` # block relatively quickly, and thereby have the progress report - # actually track the upload, and not just the feeding of the - # queue - upload_queue = Queue(maxsize=2) - - ssh_cat = _SshCat(to_url) - ssh_runner_generator = ssh_cat.run( + # actually track the upload, i.e. the feeding of the stdin pipe + # of the ssh-process, and not just the feeding of the + # queue. + # If we did not support timeouts, we could just use the following + # as `input`-iterable for `iter_subproc`: + # + # `iter(partial(src_fp.read, COPY_BUFSIZE), b'') + # + upload_queue: Queue = Queue(maxsize=2) + + cmd = _SshCommandBuilder(to_url).get_cmd( # leave special exit code when writing fails, but not the # general SSH access - "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244", - protocol=NoCaptureGeneratorProtocol, - stdin=upload_queue, - timeout=timeout, + "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244" ) - # file is open, we can start progress tracking progress_id = self._get_progress_id(source_name, to_url) - self._progress_report_start( - progress_id, - ('Upload %s to %s', source_name, to_url), - 'uploading', - expected_size, - ) try: - upload_size = 0 - while ssh_runner_generator.runner.process.poll() is None: - chunk = src_fp.read(COPY_BUFSIZE) - if chunk == b'': - break - chunk_size = len(chunk) - # compute hash simultaneously - hasher.update(chunk) - # we are just putting stuff in the queue, and rely on - # its maxsize to cause it to block the next call to - # have the progress reports be anyhow valid - upload_queue.put(chunk, timeout=timeout) - self._progress_report_update( - progress_id, ('Uploaded chunk',), chunk_size) - upload_size += chunk_size - # we're done, close queue - upload_queue.put(None, timeout=timeout) - - # Exhaust the generator, that might raise CommandError - # or TimeoutError, if timeout was not `None`. 
- tuple(ssh_runner_generator) - except CommandError as e: - if e.code == 244: - raise UrlOperationsResourceUnknown(to_url) from e - else: - raise UrlOperationsRemoteError(to_url, message=str(e)) from e - except (TimeoutError, Full): - ssh_runner_generator.runner.process.kill() - raise TimeoutError - finally: - self._progress_report_stop(progress_id, ('Finished upload',)) - - assert ssh_runner_generator.return_code == 0, "Unexpected ssh " \ - f"return value: {ssh_runner_generator.return_code}" + with iter_subproc( + cmd, + input=self._with_progress( + iter(upload_queue.get, None), + progress_id=progress_id, + label='uploading', + expected_size=expected_size, + start_log_msg=('Upload %s to %s', source_name, to_url), + end_log_msg=('Finished upload',), + update_log_msg=('Uploaded chunk',) + ) + ): + upload_size = 0 + for chunk in iter(partial(src_fp.read, COPY_BUFSIZE), b''): + + # we are just putting stuff in the queue, and rely on + # its maxsize to cause it to block the next call to + # have the progress reports be anyhow valid, we also + # rely on put-timeouts to implement timeout. + upload_queue.put(chunk, timeout=timeout) + + # compute hash simultaneously + hasher.update(chunk) + upload_size += len(chunk) + + upload_queue.put(None, timeout=timeout) + + except IterableSubprocessError as e: + self._check_return_code(e.returncode, to_url) + except Full: + if chunk != b'': + # we had a timeout while uploading + raise TimeoutError return { **hasher.get_hexdigest(), @@ -320,7 +335,7 @@ def _perform_upload(self, } -class _SshCat: +class _SshCommandBuilder: def __init__(self, url: str, *additional_ssh_args): self._parsed = urlparse(url) # make sure the essential pieces exist @@ -328,25 +343,16 @@ def __init__(self, url: str, *additional_ssh_args): assert self._parsed.path self.ssh_args: list[str] = list(additional_ssh_args) - def run(self, - payload_cmd: str, - protocol: type[RunnerProtocol], - stdin: Queue | None = None, - timeout: float | None = None) -> Any | Generator: + def get_cmd(self, payload_cmd: str) -> list[str]: fpath = self._parsed.path cmd = ['ssh'] cmd.extend(self.ssh_args) cmd.extend([ '-e', 'none', - self._parsed.hostname, + self._parsed.hostname or '', payload_cmd.format( fdir=str(PurePosixPath(fpath).parent), fpath=fpath, ), ]) - return ThreadedRunner( - cmd=cmd, - protocol_class=protocol, - stdin=subprocess.DEVNULL if stdin is None else stdin, - timeout=timeout, - ).run() + return cmd diff --git a/datalad_next/url_operations/tests/test_ssh.py b/datalad_next/url_operations/tests/test_ssh.py index c2293783..7d3db874 100644 --- a/datalad_next/url_operations/tests/test_ssh.py +++ b/datalad_next/url_operations/tests/test_ssh.py @@ -1,3 +1,4 @@ +import contextlib import io import pytest @@ -79,7 +80,7 @@ def test_ssh_url_upload(tmp_path, monkeypatch): # this may seem strange for SSH, but FILE does it too. # likewise an HTTP upload is also not required to establish # server-side preconditions first. 
- # this functionality is not about about exposing a full + # this functionality is not about exposing a full # remote FS abstraction -- just upload ops.upload(payload_path, upload_url) assert upload_path.read_text() == payload @@ -114,13 +115,13 @@ def test_ssh_url_upload_timeout(tmp_path, monkeypatch): upload_url = f'ssh://localhost/not_used' ssh_url_ops = SshUrlOperations() - def mocked_popen(*args, **kwargs): - from subprocess import Popen - args = (['sleep', '3'],) + args[1:] - return Popen(*args, **kwargs) + @contextlib.contextmanager + def mocked_iter_subproc(*args, **kwargs): + yield None with monkeypatch.context() as mp_ctx: - import datalad - mp_ctx.setattr(datalad.runner.nonasyncrunner, "Popen", mocked_popen) + import datalad_next.url_operations.ssh + mp_ctx.setattr(datalad_next.url_operations.ssh, 'iter_subproc', mocked_iter_subproc) + mp_ctx.setattr(datalad_next.url_operations.ssh, 'COPY_BUFSIZE', 1) with pytest.raises(TimeoutError): ssh_url_ops.upload(payload_path, upload_url, timeout=1)
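Usage sketch: the changes above replace the ThreadedRunner-based SSH handling with a composition of ``iter_subproc`` and ``align_pattern``. The snippet below is a minimal sketch of that composition, not code taken from the diff: the ``bash -c`` command is a hypothetical stand-in for the ssh invocation built by ``_SshCommandBuilder``, and only the ``b'\1\2\3'`` report header is borrowed from the code above. It re-chunks the subprocess output until the marker is guaranteed to sit inside a single chunk, switches back to the raw stream afterwards, and maps a non-zero exit to ``IterableSubprocessError`` once the context exits.

from itertools import chain

from datalad_next.itertools import align_pattern
from datalad_next.runners import (
    IterableSubprocessError,
    iter_subproc,
)

# hypothetical stand-in for the real ssh command
cmd = ['bash', '-c', r"printf '\1\2\3'; echo hello"]
magic_marker = b'\1\2\3'

try:
    with iter_subproc(cmd) as stream:
        # re-chunk the output so the marker cannot be split across chunks
        aligned = align_pattern(stream, magic_marker)
        first_chunk = next(aligned)
        if not first_chunk.startswith(magic_marker):
            raise RuntimeError('marker not found at the start of the stream')
        # align_pattern caches nothing after a yield, so it can be dropped
        # here and the cheaper, unmodified subprocess stream used instead
        payload = chain([first_chunk[len(magic_marker):]], stream)
        assert b''.join(payload) == b'hello\n'
except IterableSubprocessError as e:
    # a non-zero subprocess exit surfaces here, when the context exits
    print(f'command failed with return code {e.returncode}')

This mirrors what ``SshUrlOperations._get_props()`` does with the remote stat report, including the fall-back to the original stream via ``chain()``.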
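The new ``UrlOperations._with_progress()`` helper is a thin wrapper around ``more_itertools.side_effect``: a per-chunk callback updates the progress report, while ``before`` and ``after`` callbacks open and close it. A minimal sketch of that pattern, with a hypothetical ``track()`` helper and print-based reporting in place of the progress-log machinery:

from functools import partial

from more_itertools import side_effect


def track(label, chunks):
    # report once before iteration, once per chunk, and once after the
    # iterable is exhausted
    return side_effect(
        lambda chunk: print(f'{label}: +{len(chunk)} bytes'),
        chunks,
        before=partial(print, f'{label}: start'),
        after=partial(print, f'{label}: done'),
    )


for _ in track('download', [b'abc', b'de']):
    pass

Because the wrapping is lazy, the same helper serves both the download stream and the queue-fed upload iterable in ``ssh.py``.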
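On the upload side, the subprocess' stdin is fed from a bounded ``Queue``: the small ``maxsize`` makes ``put()`` block quickly, so progress reports track the actual feeding of the pipe, and the ``put()`` timeout doubles as the upload timeout. The following is a minimal sketch of that mechanism, assuming a hypothetical local file, destination command, and timeout value, with progress reporting omitted:

from functools import partial
from pathlib import Path
from queue import (
    Full,
    Queue,
)

from datalad_next.consts import COPY_BUFSIZE
from datalad_next.runners import iter_subproc

# hypothetical source file, remote command, and timeout
src_path = Path('some-local-file')
cmd = ['ssh', '-e', 'none', 'example.com', "cat > 'some-remote-file'"]
timeout = 10.0

upload_queue: Queue = Queue(maxsize=2)

try:
    with src_path.open('rb') as src_fp:
        # the queue is drained by the stdin-writer thread of iter_subproc
        with iter_subproc(cmd, input=iter(upload_queue.get, None)):
            for chunk in iter(partial(src_fp.read, COPY_BUFSIZE), b''):
                upload_queue.put(chunk, timeout=timeout)
            # the sentinel ends the input iterable
            upload_queue.put(None, timeout=timeout)
except Full:
    # the subprocess did not consume stdin in time
    raise TimeoutError

``_perform_upload()`` uses the same construction, but wraps the queue iterator in ``_with_progress()`` and translates a non-zero ssh exit code via ``IterableSubprocessError`` and ``_check_return_code()``.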