diff --git a/datalad_next/url_operations/ssh.py b/datalad_next/url_operations/ssh.py index 1e9a3aacf..aad2cfab0 100644 --- a/datalad_next/url_operations/ssh.py +++ b/datalad_next/url_operations/ssh.py @@ -4,8 +4,8 @@ from __future__ import annotations import logging -import subprocess import sys +from functools import partial from itertools import chain from pathlib import ( Path, @@ -16,23 +16,19 @@ Queue, ) from typing import ( - Any, Dict, Generator, IO, ) from urllib.parse import urlparse -from datalad_next.runners import ( - GeneratorMixIn, - NoCaptureGeneratorProtocol, - Protocol as RunnerProtocol, - StdOutCaptureGeneratorProtocol, - ThreadedRunner, - CommandError, -) +from more_itertools import side_effect + +from datalad_next.consts import COPY_BUFSIZE +from datalad_next.runners.iter_subproc import iter_subproc +from datalad_next.iterable_subprocess.iterable_subprocess import IterableSubprocessError +from datalad_next.itertools import align_pattern -from datalad_next.utils.consts import COPY_BUFSIZE from . import ( UrlOperations, @@ -71,6 +67,18 @@ class SshUrlOperations(UrlOperations): "|| exit 244" _cat_cmd = "cat '{fpath}'" + def _check_return_code(self, return_code: int, url: str): + # At this point the subprocess has either exited, was terminated, or + # was killed. + if return_code == 244: + # this is the special code for a file-not-found + raise UrlOperationsResourceUnknown(url) + elif return_code != 0: + raise UrlOperationsRemoteError( + url, + message=f'ssh process returned {return_code}' + ) + def stat(self, url: str, *, @@ -81,63 +89,56 @@ def stat(self, See :meth:`datalad_next.url_operations.UrlOperations.stat` for parameter documentation and exception behavior. """ + ssh_cat = _SshCommandBuilder(url) + cmd = ssh_cat.get_cmd(SshUrlOperations._stat_cmd) try: - props = self._stat( - url, - cmd=SshUrlOperations._stat_cmd, - ) - except CommandError as e: - if e.code == 244: - # this is the special code for a file-not-found - raise UrlOperationsResourceUnknown(url) from e - else: - raise UrlOperationsRemoteError(url, message=str(e)) from e - + with iter_subproc(cmd) as stream: + props = self._get_props(url, stream) + except IterableSubprocessError as e: + self._check_return_code(e.returncode, url) + except StopIteration: + # The stream was empty, this happens, if the ssh-shell does not + # return any output. Usually that indicates that the resource + # identified by `url` is not available. + raise UrlOperationsResourceUnknown(url) return {k: v for k, v in props.items() if not k.startswith('_')} - def _stat(self, url: str, cmd: str) -> Dict: - # any stream must start with this magic marker, or we do not + def _get_props(self, url, stream: Generator) -> dict | None: + # Any stream must start with this magic marker, or we do not # recognize what is happening # after this marker, the server will send the size of the # to-be-downloaded file in bytes, followed by another magic - # b'\1', and the file content after that - need_magic = b'\1\2\3' - expected_size_str = b'' - expected_size = None - - ssh_cat = _SshCat(url) - stream = ssh_cat.run(cmd, protocol=StdOutCaptureGeneratorProtocol) - for chunk in stream: - if need_magic: - expected_magic = need_magic[:min(len(need_magic), - len(chunk))] - incoming_magic = chunk[:len(need_magic)] - # does the incoming data have the remaining magic bytes? - if incoming_magic != expected_magic: - raise RuntimeError( - "Protocol error: report header not received") - # reduce (still missing) magic, if any - need_magic = need_magic[len(expected_magic):] - # strip magic from input - chunk = chunk[len(expected_magic):] - if chunk and expected_size is None: - # we have incoming data left and - # we have not yet consumed the size info - size_data = chunk.split(b'\1', maxsplit=1) - expected_size_str += size_data[0] - if len(size_data) > 1: - # this is not only size info, but we found the start of - # the data - expected_size = int(expected_size_str) - chunk = size_data[1] - if expected_size: - props = { - 'content-length': expected_size, - '_stream': chain([chunk], stream) if chunk else stream, - } - return props - # there should be no data left to process, or something went wrong - assert not chunk + # b'\1', and the file content after that. + magic_marker = b'\1\2\3' + + # use the `align_pattern` iterable to guarantees, that the magic + # marker is always contained in a complete chunk. + # TODO: the align-pattern iterator stays "on" the stream, which might + # slow down processing upstream. We could consider to change + # align_pattern to an iterator-class and support detaching. + aligned_stream = align_pattern(stream, magic_marker) + + # Because the stream should start with the pattern, the first chunk should contain it. + chunk = next(aligned_stream) + if chunk[:len(magic_marker)] != magic_marker: + raise RuntimeError("Protocol error: report header not received") + chunk = chunk[len(magic_marker):] + + # The length is transferred now and terminated by b'\x01'. + while b'\x01' not in chunk: + chunk += next(stream) + + marker_index = chunk.index(b'\x01') + expected_size = int(chunk[:marker_index]) + chunk = chunk[marker_index + 1:] + props = { + 'content-length': expected_size, + '_stream': + chain([chunk], aligned_stream) + if chunk + else aligned_stream + } + return props def download(self, from_url: str, @@ -160,49 +161,55 @@ def download(self, # this is pretty much shutil.copyfileobj() with the necessary # wrapping to perform hashing and progress reporting hasher = self._get_hasher(hash) - progress_id = self._get_progress_id(from_url, to_path) - - dst_fp = None + progress_id = self._get_progress_id(from_url, str(to_path)) + ssh_cat = _SshCommandBuilder(from_url) + cmd = ssh_cat.get_cmd(f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}') try: - props = self._stat( - from_url, - cmd=f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}', - ) - stream = props.pop('_stream') - expected_size = props['content-length'] - dst_fp = sys.stdout.buffer if to_path is None \ - else open(to_path, 'wb') - # Localize variable access to minimize overhead - dst_fp_write = dst_fp.write - # download can start - self._progress_report_start( - progress_id, - ('Download %s to %s', from_url, to_path), - 'downloading', - expected_size, - ) - for chunk in stream: - # write data - dst_fp_write(chunk) - # compute hash simultaneously - hasher.update(chunk) - self._progress_report_update( - progress_id, ('Downloaded chunk',), len(chunk)) - props.update(hasher.get_hexdigest()) - return props - except CommandError as e: - if e.code == 244: - # this is the special code for a file-not-found - raise UrlOperationsResourceUnknown(from_url) from e - else: - # wrap this into the datalad-standard, but keep the - # original exception linked - raise UrlOperationsRemoteError(from_url, message=str(e)) from e - finally: - if dst_fp and to_path is not None: - dst_fp.close() - self._progress_report_stop(progress_id, ('Finished download',)) + with iter_subproc(cmd) as stream: + props = self._get_props(from_url, stream) + expected_size = props['content-length'] + # The stream might have changed due to not yet processed, but + # fetched data, that is now chained in front of it. Therefore we + # get the updated stream from the props + download_stream = props.pop('_stream') + + dst_fp = sys.stdout.buffer \ + if to_path is None \ + else open(to_path, 'wb') + + # Localize variable access to minimize overhead + dst_fp_write = dst_fp.write + + # download can start + for chunk in side_effect( + lambda chunk: self._progress_report_update( + progress_id, + ('Downloaded chunk',), + len(chunk) + ), + download_stream, + before=partial( + self._progress_report_start, + progress_id, + ('Download %s to %s', from_url, to_path), + 'downloading', + expected_size + ) + ): + # write data + dst_fp_write(chunk) + # compute hash simultaneously + hasher.update(chunk) + + except IterableSubprocessError as e: + self._check_return_code(e.returncode, from_url) + except StopIteration: + raise UrlOperationsResourceUnknown(from_url) + return { + **props, + **hasher.get_hexdigest(), + } def upload(self, from_path: Path | None, @@ -234,7 +241,7 @@ def upload(self, with from_path.open("rb") as src_fp: return self._perform_upload( src_fp=src_fp, - source_name=from_path, + source_name=str(from_path), to_url=to_url, hash_names=hash, expected_size=from_path.stat().st_size, @@ -251,65 +258,64 @@ def _perform_upload(self, hasher = self._get_hasher(hash_names) + # we use a queue to implement timeouts. # we limit the queue to few items in order to `make queue.put()` # block relatively quickly, and thereby have the progress report - # actually track the upload, and not just the feeding of the - # queue + # actually track the upload, i.e. the feeding of the stdin pipe + # of the ssh-process, and not just the feeding of the + # queue. + # If we did not support timeouts, we could just use the following + # as `input`-iterable for `iter_subproc`: + # + # `iter(partial(src_fp.read, COPY_BUFSIZE), b'') + # upload_queue = Queue(maxsize=2) - ssh_cat = _SshCat(to_url) - ssh_runner_generator = ssh_cat.run( + cmd = _SshCommandBuilder(to_url).get_cmd( # leave special exit code when writing fails, but not the # general SSH access - "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244", - protocol=NoCaptureGeneratorProtocol, - stdin=upload_queue, - timeout=timeout, + "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244" ) - # file is open, we can start progress tracking progress_id = self._get_progress_id(source_name, to_url) - self._progress_report_start( - progress_id, - ('Upload %s to %s', source_name, to_url), - 'uploading', - expected_size, - ) try: - upload_size = 0 - while ssh_runner_generator.runner.process.poll() is None: - chunk = src_fp.read(COPY_BUFSIZE) - if chunk == b'': - break - chunk_size = len(chunk) - # compute hash simultaneously - hasher.update(chunk) - # we are just putting stuff in the queue, and rely on - # its maxsize to cause it to block the next call to - # have the progress reports be anyhow valid - upload_queue.put(chunk, timeout=timeout) - self._progress_report_update( - progress_id, ('Uploaded chunk',), chunk_size) - upload_size += chunk_size - # we're done, close queue - upload_queue.put(None, timeout=timeout) - - # Exhaust the generator, that might raise CommandError - # or TimeoutError, if timeout was not `None`. - tuple(ssh_runner_generator) - except CommandError as e: - if e.code == 244: - raise UrlOperationsResourceUnknown(to_url) from e - else: - raise UrlOperationsRemoteError(to_url, message=str(e)) from e - except (TimeoutError, Full): - ssh_runner_generator.runner.process.kill() - raise TimeoutError - finally: - self._progress_report_stop(progress_id, ('Finished upload',)) - - assert ssh_runner_generator.return_code == 0, "Unexpected ssh " \ - f"return value: {ssh_runner_generator.return_code}" + with iter_subproc( + cmd, + input=side_effect( + lambda chunk: self._progress_report_update( + progress_id, ('Uploaded chunk',), len(chunk) + ), + iter(upload_queue.get, None), + before=partial( + self._progress_report_start, + progress_id, + ('Upload %s to %s', source_name, to_url), + 'uploading', + expected_size + ) + ) + ): + upload_size = 0 + for chunk in iter(partial(src_fp.read, COPY_BUFSIZE), b''): + + # we are just putting stuff in the queue, and rely on + # its maxsize to cause it to block the next call to + # have the progress reports be anyhow valid, we also + # rely on put-timeouts to implement timeout. + upload_queue.put(chunk, timeout=timeout) + + # compute hash simultaneously + hasher.update(chunk) + upload_size += len(chunk) + + upload_queue.put(None, timeout=timeout) + + except IterableSubprocessError as e: + self._check_return_code(e.returncode, to_url) + except Full: + if chunk != b'': + # we had a timeout while uploading + raise TimeoutError return { **hasher.get_hexdigest(), @@ -320,7 +326,7 @@ def _perform_upload(self, } -class _SshCat: +class _SshCommandBuilder: def __init__(self, url: str, *additional_ssh_args): self._parsed = urlparse(url) # make sure the essential pieces exist @@ -328,11 +334,7 @@ def __init__(self, url: str, *additional_ssh_args): assert self._parsed.path self.ssh_args: list[str] = list(additional_ssh_args) - def run(self, - payload_cmd: str, - protocol: type[RunnerProtocol], - stdin: Queue | None = None, - timeout: float | None = None) -> Any | Generator: + def get_cmd(self, payload_cmd: str) -> list[str]: fpath = self._parsed.path cmd = ['ssh'] cmd.extend(self.ssh_args) @@ -344,9 +346,4 @@ def run(self, fpath=fpath, ), ]) - return ThreadedRunner( - cmd=cmd, - protocol_class=protocol, - stdin=subprocess.DEVNULL if stdin is None else stdin, - timeout=timeout, - ).run() + return cmd diff --git a/datalad_next/url_operations/tests/test_ssh.py b/datalad_next/url_operations/tests/test_ssh.py index c2293783f..7d3db8747 100644 --- a/datalad_next/url_operations/tests/test_ssh.py +++ b/datalad_next/url_operations/tests/test_ssh.py @@ -1,3 +1,4 @@ +import contextlib import io import pytest @@ -79,7 +80,7 @@ def test_ssh_url_upload(tmp_path, monkeypatch): # this may seem strange for SSH, but FILE does it too. # likewise an HTTP upload is also not required to establish # server-side preconditions first. - # this functionality is not about about exposing a full + # this functionality is not about exposing a full # remote FS abstraction -- just upload ops.upload(payload_path, upload_url) assert upload_path.read_text() == payload @@ -114,13 +115,13 @@ def test_ssh_url_upload_timeout(tmp_path, monkeypatch): upload_url = f'ssh://localhost/not_used' ssh_url_ops = SshUrlOperations() - def mocked_popen(*args, **kwargs): - from subprocess import Popen - args = (['sleep', '3'],) + args[1:] - return Popen(*args, **kwargs) + @contextlib.contextmanager + def mocked_iter_subproc(*args, **kwargs): + yield None with monkeypatch.context() as mp_ctx: - import datalad - mp_ctx.setattr(datalad.runner.nonasyncrunner, "Popen", mocked_popen) + import datalad_next.url_operations.ssh + mp_ctx.setattr(datalad_next.url_operations.ssh, 'iter_subproc', mocked_iter_subproc) + mp_ctx.setattr(datalad_next.url_operations.ssh, 'COPY_BUFSIZE', 1) with pytest.raises(TimeoutError): ssh_url_ops.upload(payload_path, upload_url, timeout=1)