diff --git a/datalad_next/iter_collections/gitworktree.py b/datalad_next/iter_collections/gitworktree.py index a2dee1b5..304cefbe 100644 --- a/datalad_next/iter_collections/gitworktree.py +++ b/datalad_next/iter_collections/gitworktree.py @@ -22,10 +22,13 @@ from datalad_next.runners import ( DEVNULL, - LineSplitter, - ThreadedRunner, - StdOutCaptureGeneratorProtocol, + StdOutCaptureProcessingGeneratorProtocol, ) +from datalad_next.runners.data_processors import ( + decode_processor, + splitlines_processor, +) +from datalad_next.runners.run import run from .utils import ( FileSystemItem, @@ -250,23 +253,27 @@ def _lsfiles_line2props( def _git_ls_files(path, *args): - # we use a plain runner to avoid the overhead of a GitRepo instance - runner = ThreadedRunner( - cmd=[ - 'git', 'ls-files', - # we rely on zero-byte splitting below - '-z', - # otherwise take whatever is coming in - *args, - ], - protocol_class=StdOutCaptureGeneratorProtocol, - stdin=DEVNULL, - # run in the directory we want info on - cwd=path, - ) - line_splitter = LineSplitter('\0', keep_ends=False) - # for each command output chunk received by the runner - for content in runner.run(): - # for each zerobyte-delimited "line" in the output - for line in line_splitter.process(content.decode('utf-8')): - yield line + with run( + cmd=[ + 'git', 'ls-files', + # we rely on zero-byte splitting below + '-z', + # otherwise take whatever is coming in + *args, + ], + protocol_class=StdOutCaptureProcessingGeneratorProtocol, + stdin=DEVNULL, + cwd=path, + protocol_kwargs=dict( + processors=[ + decode_processor('utf-8'), + splitlines_processor(separator='\0', keep_ends=False) + ] + ) + ) as r: + # This code uses the data processor chain to process data. This fixes + # a problem with the previous version of the code, where `decode` was + # used on every data chunk that was sent tp `pipe_data_received`. But + # data is chunked up randomly and might be split in the middle of a + # character encoding, leading to weird errors. + yield from r diff --git a/datalad_next/runners/__init__.py b/datalad_next/runners/__init__.py index cca244f9..9a38204f 100644 --- a/datalad_next/runners/__init__.py +++ b/datalad_next/runners/__init__.py @@ -61,6 +61,8 @@ from .protocols import ( NoCaptureGeneratorProtocol, StdOutCaptureGeneratorProtocol, + StdOutCaptureProcessingGeneratorProtocol, + StdOutErrCaptureProcessingGeneratorProtocol, ) # exceptions from datalad.runner.exception import ( diff --git a/datalad_next/runners/batch.py b/datalad_next/runners/batch.py new file mode 100644 index 00000000..807b2341 --- /dev/null +++ b/datalad_next/runners/batch.py @@ -0,0 +1,217 @@ +"""Helpers to execute batch commands + +Some of the functionality provided by this module depends on specific +"generator" flavors of runner protocols, and additional dedicated +low-level tooling: + +.. currentmodule:: datalad_next.runners.batch +.. autosummary:: + :toctree: generated + + StdOutCaptureGeneratorProtocol + GeneratorAnnexJsonProtocol + _ResultGenerator +""" +from __future__ import annotations + +from contextlib import contextmanager +from pathlib import Path +from queue import Queue +from typing import ( + Any, + Callable, + Generator, +) + +from datalad.runner.nonasyncrunner import _ResultGenerator +from datalad.support.annexrepo import GeneratorAnnexJsonProtocol + +from . 
import Protocol
+from .protocols import StdOutCaptureGeneratorProtocol
+from .run import run
+
+
+class BatchProcess:
+    """Representation of a (running) batch process
+
+    An instance of this class is produced by any of the context manager variants
+    in this module. It is a convenience wrapper around an instance of a
+    :class:`_ResultGenerator` that is produced by :meth:`ThreadedRunner.run`.
+
+    A batch process instance is used by passing ``bytes`` input to its
+    ``__call__()`` method, and receiving the batch output as return value.
+
+    While the process is still running, its ``return_code`` property will
+    be ``None``. After the process has terminated, the property will contain
+    the respective exit status.
+    """
+    def __init__(self, rgen: _ResultGenerator):
+        self._rgen = rgen
+        self._stdin_queue = rgen.runner.stdin_queue
+
+    def __call__(self, data: bytes | None) -> Any:
+        self._stdin_queue.put(data)
+        try:
+            return next(self._rgen)
+        except StopIteration:
+            return None
+
+    def close_stdin(self):
+        return self(None)
+
+    @property
+    def return_code(self) -> None | int:
+        return self._rgen.return_code
+
+
+@contextmanager
+def batchcommand(
+        cmd: list,
+        protocol_class: type[Protocol],
+        cwd: Path | None = None,
+        closing_action: Callable | None = None,
+        terminate_time: int | None = None,
+        kill_time: int | None = None,
+        **protocol_kwargs
+) -> Generator[BatchProcess, None, None]:
+    """Generic context manager for batch processes
+
+    ``cmd`` is an ``argv``-list specification of a command. It is executed via
+    a :func:`~datalad_next.runners.run.run` context manager. This context
+    manager is parameterized with ``protocol_class`` (which can take any
+    implementation of a DataLad runner protocol), and optional keyword arguments
+    that are passed to the protocol class.
+
+    On leaving the context, the manager will perform a "closing_action". By
+    default, this is to close ``stdin`` of the underlying process. This will
+    typically cause the underlying process to exit. A caller can specify an
+    alternative function via ``closing_action``. If ``closing_action`` is set,
+    the function will be called with two arguments. The first argument is the
+    :class:`BatchProcess`-instance, the second argument is the stdin-queue of
+    the subprocess.
+    A custom ``closing_action`` might, for example, send some kind of exit
+    command to the subprocess, and then close stdin. This hook exists because
+    the control flow might enter the exit-handler through different mechanisms.
+    One mechanism would be an un-caught exception.
+
+    If ``terminate_time`` is given, the context handler will send a
+    terminate-signal to the process, if it is still running ``terminate_time``
+    seconds after the context is left. If ``kill_time`` is given, the context
+    handler will send a kill-signal to the process, if it is still running
+    ``(terminate_time or 0) + kill_time`` seconds after the context is left.
+
+    If neither ``terminate_time`` nor ``kill_time`` are set and the process
+    is not triggered to exit, e.g. because its stdin is not closed or because
+    it requires different actions to trigger its exit, ``batchcommand`` will
+    wait forever after the context is exited. Note that the context might also
+    be exited in an unexpected way by an ``Exception``.
+
+    While this generic context manager can be used directly, it can be
+    more convenient to use any of the more specialized implementations
+    that employ a specific protocol (e.g., :func:`stdout_batchcommand`,
+    :func:`annexjson_batchcommand`).
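+
+    A minimal usage sketch (mirroring the tests added in this patch; the
+    git-annex invocation and key are illustrative, and ``dataset_path``
+    stands for an existing annex repository)::
+
+        with batchcommand(
+                ['git', 'annex', 'examinekey',
+                 '--format', '${bytesize}\n', '--batch'],
+                protocol_class=StdOutCaptureGeneratorProtocol,
+                cwd=dataset_path,
+        ) as bp:
+            size = bp(b'MD5E-s21032--2f4e22eb05d58c21663794876dc701aa\n')
+            # b'21032\n' -- the process keeps running between calls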
+ + Parameters + ---------- + cmd : list[str] + A list containing the command and its arguments (argv-like). + cwd : Path | None + If not ``None``, determines a new working directory for the command. + closing_action: Callable | None + If not ``None``, contains a callable that will be called when the context + is left. The callable is invoked with two arguments, the + :class:`BatchProcessor`-instance and the stdin-queue. + if ``closing_action`` is ``None``, :func:`batchcommand`will close stdin + of the subprocess by calling the method :meth:`BatchProcess.close_stdin`. + terminate_time: int | None + The number of timeouts after which a terminate-signal will be sent to + the process, if it is still running. If no timeouts were provided in the + ``timeout``-argument, the timeout is set to ``1.0``. + kill_time: int | None + See documentation of :func:`datalad_next.runners.run.run`. + protocol_kwargs: dict + If ``terminate_time`` is given, a kill-signal will be sent to the + subprocess after kill-signal after ``terminate_time + kill_time`` + timeouts. If ``terminate_time`` is not set, a kill-signal will be sent + after ``kill_time`` timeouts. + It is a good idea to set ``kill_time`` and ``terminate_time`` in order + to let the process exit gracefully, if it is capable to do so. + + Yields + ------- + BatchProcess + A :class:`BatchProcess`-instance that can be used to interact with the + cmd + + """ + input_queue = Queue() + try: + run_context_manager = run( + cmd=cmd, + protocol_class=protocol_class, + stdin=input_queue, + cwd=cwd, + terminate_time=terminate_time, + kill_time=kill_time, + **protocol_kwargs + ) + with run_context_manager as result_generator: + batch_process = BatchProcess(result_generator) + yield batch_process + if closing_action: + closing_action(batch_process, input_queue) + else: + batch_process.close_stdin() + finally: + del input_queue + + +def stdout_batchcommand( + cmd: list, + cwd: Path | None = None, + closing_action: Callable | None = None, + terminate_time: int | None = None, + kill_time: int | None = None, +) -> Generator[BatchProcess, None, None]: + """Context manager for commands that produce arbitrary output on ``stdout`` + + Internally this calls :func:`batchcommand` with the + :class:`StdOutCaptureGeneratorProtocol` protocol implementation. See the + documentation of :func:`batchcommand` for a description of the parameters. + """ + return batchcommand( + cmd, + protocol_class=StdOutCaptureGeneratorProtocol, + cwd=cwd, + closing_action=closing_action, + terminate_time=terminate_time, + kill_time=kill_time, + ) + + +def annexjson_batchcommand( + cmd: list, + cwd: Path | None = None, + closing_action: Callable | None = None, + terminate_time: int | None = None, + kill_time: int | None = None, +) -> Generator[BatchProcess, None, None]: + """ + Context manager for git-annex commands that support ``--batch --json`` + + The given ``cmd``-list must be complete, i.e., include + ``git annex ... --batch --json``, and any additional flags that may be + needed. + + Internally this calls :func:`batchcommand` with the + :class:`GeneratorAnnexJsonProtocol` protocol implementation. See the + documentation of :func:`batchcommand` for a description of the parameters. 
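+
+    A minimal usage sketch (the key value is illustrative and taken from the
+    tests added in this patch; ``dataset_path`` stands for an existing annex
+    repository)::
+
+        with annexjson_batchcommand(
+                ['git', 'annex', 'examinekey', '--json', '--batch'],
+                cwd=dataset_path,
+        ) as bp:
+            record = bp(b'MD5E-s21032--2f4e22eb05d58c21663794876dc701aa\n')
+            # record is a decoded JSON object, e.g. record['bytesize'] == '21032'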
+ """ + return batchcommand( + cmd, + protocol_class=GeneratorAnnexJsonProtocol, + cwd=cwd, + closing_action=closing_action, + terminate_time=terminate_time, + kill_time=kill_time, + ) diff --git a/datalad_next/runners/data_processor_pipeline.py b/datalad_next/runners/data_processor_pipeline.py new file mode 100644 index 00000000..d7e4dfe5 --- /dev/null +++ b/datalad_next/runners/data_processor_pipeline.py @@ -0,0 +1,166 @@ +""" +This module contains the implementation of a data processing pipeline driver. +The data processing pipeline takes chunks of bytes as input and feeds them +into a list of data processors, i.e. the data processing pipeline. + +Data processing can be performed via calls to +:meth:`ProcessorPipeline.process` and :meth:`ProcessorPipeline.finalize`. +Alternatively, it can be performed over data chunks that are yielded by a +generator via the method :meth:`ProcessorPipeline.process_from`, which +creates a new generator that will yield the results of the data processing +pipeline. + +Typical data processors would be: + +- decode a stream of bytes +- split a stream of characters at line-ends +- convert a line of text into JSON + +Data processors have a common interface and can be chained. For example, +one can pass data chunks, where each chunk is a byte-string, into a chain +of two data processors: a decode-processor that converts bytes into strings, +and a linesplit-processor that converts character-streams into lines. The result +of the chain would be lines of text. + +Data processors are callables that have the following signature:: + + def process(data: list[T], final: bool) -> tuple[list[N] | None, list[T]]: + ... + +where N is the type that is returned by processor. The return value is a +consisting of optional results, i.e. list[N] | None, and a number of input +elements that were not processed and should be presented again, when more +data arrives from the "preceding" element. + +Data processors might need to buffer some data before yielding their result. The +"driver" of the data processing chains supports the buffering of input data for +individual processors. Therefore, data processors do not need to store +state themselves and can be quite simple. + +The module currently supports the following data processors: + + - ``jsonline_processor`` + - ``decode_processor`` + - ``splitlines_processor`` + - ``pattern_processor` + + +""" +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Generator +from typing import ( + Any, + Callable, + Iterable, + List, + Union, +) + + +StrList = List[str] +BytesList = List[bytes] +StrOrBytes = Union[str, bytes] +StrOrBytesList = List[StrOrBytes] + + +class DataProcessorPipeline: + """ + Hold a list of data processors and pushes data through them. + + Calls the processors in the specified order and feeds the output + of a preceding processor into the following processor. If a processor + has unprocessed data, either because it did not have enough data to + successfully process it, or because not all data was processed, it returns + the unprocessed data to the `process`-method and will receive it together + with newly arriving data in the "next round". 
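+
+    A minimal sketch of the interface, using processors from
+    ``datalad_next.runners.data_processors``::
+
+        pipeline = DataProcessorPipeline(
+            [decode_processor('utf-8'), splitlines_processor()]
+        )
+        lines = pipeline.process(b'abc\ndef')   # ['abc\n'], 'def' is buffered
+        lines += pipeline.finalize()            # flushes the remaining 'def'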
+ """ + def __init__(self, + processors: list[Callable] + ) -> None: + self.processors = processors + self.waiting_data: dict[Callable, list] = defaultdict(list) + self.remaining = None + self.finalized = False + + def process(self, data: bytes) -> list[Any]: + output = [data] + for processor in self.processors: + if self.waiting_data[processor]: + output = self.waiting_data[processor] + output + output, self.waiting_data[processor] = processor(output) + if not output: + # If this processor does not output anything then the next + # one has only the input that he already has buffered. We can + # therefore end here. + break + return output + + def finalize(self) -> list[Any]: + assert self.finalized is False, f'finalize() called repeatedly on {self}' + self.finalized = True + output = [] + for processor in self.processors: + if self.waiting_data[processor]: + output = self.waiting_data[processor] + output + # We used to do the following + if not output: + continue + # This cannot be done anymore because some processors store internal + # data, e.g. SplitLinesProcessor. Those would not require an input + # to generate an output on the final round. + output, self.waiting_data[processor] = processor(output, True) + return output + + def process_from(self, data_source: Iterable) -> Generator: + """ pass output from a generator through this pipeline and yield output + + This method takes an existing byte-yielding generator, uses it as input + and executes the specified processors over it. The result of the first + processor is fed into the second processor and so on. The result of the + last processor is yielded by the function. + + Parameters + ---------- + data_source : Iterable + An iterable object or generator that will deliver a byte stream in a + number of chunks + + Yields + ------- + Any + Individual responses that were created by the last processor + """ + for data_chunk in data_source: + result = self.process(data_chunk) + if result: + yield from result + result = self.finalize() + if result: + yield from result + + +def process_from(data_source: Iterable, + processors: list[Callable] + ) -> Generator: + """ A convenience wrapper around the ProcessorPipeline.process_from-method + + Parameters + ---------- + data_source : Iterable + An iterable object or generator that will deliver a byte stream in a + number of chunks + + processors : List[Callable] + The list of processors that process the incoming data. The processors are + receiving the data in the order `processors[0], processor[1], ..., + processor[-1]` + + Yields + ------- + Any + Individual responses that were created by the last processor + """ + processor_pipeline = DataProcessorPipeline(processors) + yield from processor_pipeline.process_from(data_source=data_source) diff --git a/datalad_next/runners/data_processors/__init__.py b/datalad_next/runners/data_processors/__init__.py new file mode 100644 index 00000000..3d5214c7 --- /dev/null +++ b/datalad_next/runners/data_processors/__init__.py @@ -0,0 +1,19 @@ +""" This module contains data processors for the data pipeline processor + +Available data processors: + +.. currentmodule:: datalad_next.runners.data_processors +.. 
autosummary:: + :toctree: generated + + decode + jsonline + pattern + splitlines + +""" + +from .decode import decode_processor +from .jsonline import jsonline_processor +from .pattern import pattern_processor +from .splitlines import splitlines_processor diff --git a/datalad_next/runners/data_processors/decode.py b/datalad_next/runners/data_processors/decode.py new file mode 100644 index 00000000..4bb8ed06 --- /dev/null +++ b/datalad_next/runners/data_processors/decode.py @@ -0,0 +1,85 @@ +""" Data processor that decodes bytes into strings """ + +from __future__ import annotations + +from typing import Callable + +from ..data_processor_pipeline import ( + BytesList, + StrList, +) + + +__all__ = ['decode_processor'] + + +def decode_processor(encoding: str = 'utf-8') -> Callable: + """ create a data processor that decodes a byte-stream + + The created data processor will decode byte-streams, even if the encoding + is split at chunk borders. + If an encoding error occurs on the final data chunk, the un-decodable bytes + will be replaced with their escaped hex-values, i.e. ``\\xHH``, + for hex-value HH. + + Parameters + ---------- + encoding: str + The name of encoding that should be decoded. + + Returns + ------- + Callable + A data processor that can be used in a processing pipeline to decode + chunks of bytes. The result are chunks of strings. + """ + return _DecodeProcessor(encoding) + + +class _DecodeProcessor: + """ Decode a byte-stream, even if the encoding is split at chunk borders + + Instances of this class can be used as data processors. + """ + def __init__(self, encoding: str = 'utf-8') -> None: + """ + + Parameters + ---------- + encoding: str + The type of encoding that should be decoded. + """ + self.encoding = encoding + + def __call__(self, data_chunks: BytesList, + final: bool = False + ) -> tuple[StrList, BytesList]: + """ The data processor interface + + This allows instances of :class:``DecodeProcessor`` to be used as + data processor in pipeline definitions. + + Parameters + ---------- + data_chunks: list[bytes] + a list of bytes (data chunks) that should be decoded + final : bool + the data chunks are the final data chunks of the source. If an + encoding error happens, the offending bytes will be replaced with + their escaped hex-values, i.e. ``\\xHH``, for hex-value HH. + + Returns + ------- + list[str] + the decoded data chunks, possibly joined + """ + try: + text = (b''.join(data_chunks)).decode(self.encoding) + except UnicodeDecodeError: + if final: + text = (b''.join(data_chunks)).decode( + self.encoding, + errors='backslashreplace') + else: + return [], data_chunks + return [text], [] diff --git a/datalad_next/runners/data_processors/jsonline.py b/datalad_next/runners/data_processors/jsonline.py new file mode 100644 index 00000000..43aadbba --- /dev/null +++ b/datalad_next/runners/data_processors/jsonline.py @@ -0,0 +1,46 @@ +""" Data processor that generates JSON objects from lines of bytes or strings """ + +from __future__ import annotations + +import json +from typing import Any + +from ..data_processor_pipeline import StrOrBytesList + + +def jsonline_processor(lines: StrOrBytesList, + _: bool = False + ) -> tuple[list[tuple[bool, Any]], StrOrBytesList]: + """ A data processor that converts lines into JSON objects, if possible. + + Parameters + ---------- + lines: StrOrBytesList + A list containing strings or byte-strings that that hold JSON-serialized + data. 
+ + _: bool + The ``final`` parameter is ignored because lines are assumed to be + complete and the conversion takes place for every line. Consequently, + no remaining input data exists, and there is no need for "flushing" in + a final round. + + Returns + ------- + tuple[list[Tuple[bool, StrOrBytes]], StrOrByteList] + The result, i.e. the first element of the result tuple, is a list that + contains one tuple for each element of ``lines``. The first element of the + tuple is a bool that indicates whether the line could be converted. If it + was successfully converted the value is ``True``. The second element is the + Python structure that resulted from the conversion if the first element + was ``True``. If the first element is ``False``, the second element contains + the input that could not be converted. + """ + result = [] + for line in lines: + assert len(line.splitlines()) == 1 + try: + result.append((True, json.loads(line))) + except json.decoder.JSONDecodeError: + result.append((False, lines)) + return result, [] diff --git a/datalad_next/runners/data_processors/pattern.py b/datalad_next/runners/data_processors/pattern.py new file mode 100644 index 00000000..df9aba34 --- /dev/null +++ b/datalad_next/runners/data_processors/pattern.py @@ -0,0 +1,100 @@ +""" Data processor that ensure that a pattern odes not cross data chunk borders """ + +from __future__ import annotations + +from functools import partial +from typing import Callable + +from ..data_processor_pipeline import ( + StrOrBytes, + StrOrBytesList, +) + + +__all__ = ['pattern_processor'] + + +def pattern_processor(pattern: StrOrBytes) -> Callable: + """ Create a pattern processor for the given ``pattern``. + + A pattern processor re-assembles data chunks in such a way, that a single + data chunk could contain the complete pattern and will contain the complete + pattern, if the complete pattern start in the data chunk. It guarantees: + + 1. All chunks have at minimum the size of the pattern + 2. If a complete pattern exists, it will be contained completely within a + single chunk, i.e. it will NOT be the case that a prefix of the pattern + is at the end of a chunk, and the rest of the pattern in the beginning + of the next chunk + + The pattern might be present multiple times in a data chunk. + """ + assert len(pattern) > 0 + return partial(_pattern_processor, pattern) + + +def _pattern_processor(pattern: StrOrBytes, + data_chunks: StrOrBytesList, + final: bool = False, + ) -> tuple[StrOrBytesList, StrOrBytesList]: + """ Ensure that ``pattern`` appears only completely contained within a chunk + + This processor ensures that a given data pattern (if it exists in the data + chunks) is either completely contained in a chunk or not in the chunk. That + means the processor ensures that all data chunks have at least the length of + the data pattern and that they do not end with a prefix of the data pattern. + + As a result, a simple ``pattern in data_chunk`` test is sufficient to + determine whether a pattern appears in the data stream. + + To use this function as a data processor, use partial to "fix" the first + parameter. + + Parameters + ---------- + pattern: str | bytes + The pattern that should be contained in the chunks + data_chunks: list[str | bytes] + a list of strings or bytes + final : bool + the data chunks are the final data chunks of the source. A line is + terminated by end of data. 
+ + Returns + ------- + list[str | bytes] + data chunks that have at least the size of the pattern and do not end + with a prefix of the pattern. Note that a data chunk might contain the + pattern multiple times. + """ + + def ends_with_pattern_prefix(data: StrOrBytes, pattern: StrOrBytes) -> bool: + """ Check whether the chunk ends with a prefix of the pattern """ + for index in range(len(pattern) - 1, 0, -1): + if data[-index:] == pattern[:index]: + return True + return False + + # Copy the list, because we might modify it and the caller might not expect that. + data_chunks = data_chunks[:] + + # Join data chunks until they are sufficiently long to contain the pattern, + # i.e. have a least the size: `len(pattern)`. Continue joining, if the chunk + # ends with a prefix of the pattern. + current_index = 0 + while current_index < len(data_chunks) - 1: + current_chunk = data_chunks[current_index] + while (len(data_chunks[current_index:]) > 1 + and (len(current_chunk) < len(pattern) + or ends_with_pattern_prefix(current_chunk, pattern))): + data_chunks[current_index] += data_chunks[current_index + 1] + del data_chunks[current_index + 1] + current_chunk = data_chunks[current_index] + current_index += 1 + + # At this point we have joined whatever we can join. We still have to check + # whether the last chunk ends with a pattern-prefix. + if not final: + if ends_with_pattern_prefix(data_chunks[-1], pattern): + return data_chunks[:-1], data_chunks[-1:] + return data_chunks, [] diff --git a/datalad_next/runners/data_processors/splitlines.py b/datalad_next/runners/data_processors/splitlines.py new file mode 100644 index 00000000..f8818a89 --- /dev/null +++ b/datalad_next/runners/data_processors/splitlines.py @@ -0,0 +1,107 @@ +""" Data processor that splits the input into individual lines """ + +from __future__ import annotations + +from functools import partial +from typing import Callable + +from ..data_processor_pipeline import ( + StrOrBytes, + StrOrBytesList, +) + + +__all__ = ['splitlines_processor'] + + +def splitlines_processor( + separator: StrOrBytes | None = None, + keep_ends: bool = True +) -> Callable[[StrOrBytesList, bool], tuple[StrOrBytesList, StrOrBytesList]]: + """ Generate a data processor that splits character- or byte-strings into lines + + This function returns a data processor, that splits lines either on a given + separator, if 'separator' is not ``None``, or on one of the known line endings, + if 'separator' is ``None``. If ``separator`` is ``None``, the line endings are + determined by python. + + Parameters + ---------- + separator: Optional[str] + If not None, the provided separator will be used to split lines. + keep_ends: bool + If True, the separator will be contained in the returned lines. + + Returns + ------- + Callable + A data processor that takes a list of strings or bytes, and returns + a list of strings or bytes, where every element is a single line (as + defined by the ``separator`` and ``keep_ends`` argument). + """ + return partial(_splitlines_processor, separator, keep_ends) + + +# We don't use LineSplitter here because it has two "problems". Firstly, it does +# not support `bytes`. Secondly, it can not be properly re-used because it does +# not delete its internal storage when calling `LineSplitter.finish_processing`. +# The issue https://github.com/datalad/datalad/issues/7519 has been created to +# fix the problem upstream. Until then we use this code. 
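+#
+# A small illustration of the processor contract implemented below (the chunk
+# values are made up):
+#
+#     process = splitlines_processor()
+#     process(['ab', 'c\nde'], final=False)  # -> (['abc\n'], ['de'])
+#     process(['de'], final=True)            # -> (['de'], [])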
+def _splitlines_processor(separator: StrOrBytes | None, + keep_ends: bool, + data_chunks: StrOrBytesList, + final: bool = False + ) -> tuple[StrOrBytesList, StrOrBytesList]: + """ Implementation of character-strings or byte-strings line splitting + + This function implements the line-splitting data processor and is used + by :func:`splitlines_processor` below. + + To use this function as a data processor, use partial to "fix" the first + two parameter. + + Parameters + ---------- + separator: Optional[str] + If not None, the provided separator will be used to split lines. + keep_ends: bool + If True, the separator will be contained in the returned lines. + data_chunks: list[str | bytes] + a list of strings or bytes + final : bool + the data chunks are the final data chunks of the source. A line is + terminated by end of data. + + Returns + ------- + list[str | bytes] + if the input data chunks contained bytes the result will be a list of + byte-strings that end with byte-size line-delimiters. If the input data + chunks contained strings, the result will be a list strings that end with + string delimiters (see Python-documentation for a definition of string + line delimiters). + """ + # We use `data_chunks[0][0:0]` to get an empty value the proper type, i.e. + # either the string `''` or the byte-string `b''`. + empty = data_chunks[0][0:0] + text = empty.join(data_chunks) + if separator is None: + # Use the builtin line split-wisdom of Python + parts_with_ends = text.splitlines(keepends=True) + parts_without_ends = text.splitlines(keepends=False) + lines = parts_with_ends if keep_ends else parts_without_ends + if parts_with_ends[-1] == parts_without_ends[-1] and not final: + return lines[:-1], [parts_with_ends[-1]] + return lines, [] + else: + detected_lines = text.split(separator) + remaining = detected_lines[-1] if text.endswith(separator) else None + del detected_lines[-1] + if keep_ends: + result = [line + separator for line in detected_lines], [remaining] if remaining else [] + else: + result = detected_lines, [remaining] if remaining else [] + if final: + result[0].extend(result[1]) + result = result[0], [] + return result diff --git a/datalad_next/runners/protocols.py b/datalad_next/runners/protocols.py index b366a173..be3111a4 100644 --- a/datalad_next/runners/protocols.py +++ b/datalad_next/runners/protocols.py @@ -1,8 +1,14 @@ +from __future__ import annotations + +from typing import Optional + from . import ( GeneratorMixIn, NoCapture, StdOutCapture, + StdOutErrCapture, ) +from .data_processor_pipeline import DataProcessorPipeline # @@ -29,3 +35,79 @@ def pipe_data_received(self, fd: int, data: bytes): def timeout(self, fd): raise TimeoutError(f"Runner timeout {fd}") + + +class StdOutCaptureProcessingGeneratorProtocol(StdOutCaptureGeneratorProtocol): + """ A generator protocol that applies a processor pipeline to stdout data + + This protocol can be initialized with a list of processors. Data that is + read from stdout will be processed by the processors and the result of the + last processor will be sent to the result generator, which will then + yield it. 
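+
+    A usage sketch, following the ``git ls-files`` adaptation at the top of
+    this patch::
+
+        with run(
+            ['git', 'ls-files', '-z'],
+            protocol_class=StdOutCaptureProcessingGeneratorProtocol,
+            protocol_kwargs=dict(
+                processors=[
+                    decode_processor('utf-8'),
+                    splitlines_processor(separator='\0', keep_ends=False),
+                ]
+            ),
+        ) as r:
+            for item in r:
+                ...  # each item is one decoded, zero-byte-delimited path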
+ """ + def __init__(self, + done_future=None, + processors: list | None = None + ) -> None: + StdOutCaptureGeneratorProtocol.__init__(self, done_future, None) + self.processor_pipeline = ( + DataProcessorPipeline(processors) + if processors + else None + ) + + def pipe_data_received(self, fd: int, data: bytes): + assert fd == 1 + if self.processor_pipeline: + for processed_data in self.processor_pipeline.process(data): + self.send_result(processed_data) + return + self.send_result(data) + + def pipe_connection_lost(self, fd: int, exc: Optional[BaseException]) -> None: + assert fd == 1 + if self.processor_pipeline: + for processed_data in self.processor_pipeline.finalize(): + self.send_result(processed_data) + + +class StdOutErrCaptureProcessingGeneratorProtocol(StdOutErrCapture, GeneratorMixIn): + """ A generator protocol that applies processor-pipeline to stdout- and stderr-data + + This protocol can be initialized with a list of processors for stdout-data, + and with a list of processors for stderr-data. Data that is read from stdout + or stderr will be processed by the respective processors. The protocol will + send 2-tuples to the result generator. Each tuple consists of the file + descriptor on which data arrived and the output of the last processor of the + respective pipeline. The result generator. which will then yield the + results. + """ + def __init__(self, + done_future=None, + stdout_processors: list | None = None, + stderr_processors: list | None = None, + ) -> None: + StdOutErrCapture.__init__(self, done_future, None) + GeneratorMixIn.__init__(self) + self.processor_pipelines = { + fd: DataProcessorPipeline(processors) + for fd, processors in ((1, stdout_processors), (2, stderr_processors)) + if processors is not None + } + + def pipe_data_received(self, fd: int, data: bytes): + assert fd in (1, 2) + if fd in self.processor_pipelines: + for processed_data in self.processor_pipelines[fd].process(data): + self.send_result((fd, processed_data)) + return + self.send_result((fd, data)) + + def pipe_connection_lost(self, fd: int, exc: Optional[BaseException]) -> None: + assert fd in (1, 2) + if fd in self.processor_pipelines: + for processed_data in self.processor_pipelines[fd].finalize(): + self.send_result((fd, processed_data)) + + def timeout(self, fd): + raise TimeoutError(f"Runner timeout {fd}") diff --git a/datalad_next/runners/run.py b/datalad_next/runners/run.py new file mode 100644 index 00000000..8504d3ef --- /dev/null +++ b/datalad_next/runners/run.py @@ -0,0 +1,258 @@ +""" +This module provides a run-context manager that executes a subprocess and +can guarantee that the subprocess is terminated when the context is left. +""" +from __future__ import annotations + +import subprocess +from collections.abc import Generator +from contextlib import contextmanager +from pathlib import Path +from queue import Queue +from subprocess import DEVNULL +from typing import ( + Any, + Callable, + IO, +) + +from . import ( + GeneratorMixIn, + Protocol, + ThreadedRunner, +) + + +def _create_kill_wrapper(cls: type[Protocol]) -> type[Protocol]: + """ Extend ``cls`` to supports the "kill-interface" + + This function creates a subclass of `cls` that contains the components + of the "kill-interface". The two major components are a method called + `arm`, and logic inside the timeout handler that can trigger a termination + or kill signal to the subprocess, if the termination time or kill time has + come. 
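+
+    :func:`run` applies the wrapper internally, roughly like this (a sketch;
+    the concrete values are illustrative)::
+
+        wrapped_class = _create_kill_wrapper(StdOutCaptureGeneratorProtocol)
+        runner = ThreadedRunner(
+            cmd=cmd,
+            protocol_class=wrapped_class,
+            protocol_kwargs=dict(
+                dl_kill_wrapper_kwargs=dict(
+                    introduced_timeout=True,
+                    terminate_time=3,
+                    kill_time=2,
+                    armed=False,
+                ),
+            ),
+            timeout=1.0,
+            stdin=DEVNULL,
+        )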
+ + Parameters + ---------- + cls : type[Protocol] + A protocol class that should be extended by the kill-interface + + Returns + ------- + KillWrapper + A protocol class that inherits `cls` and implements the kill logic + that is used by the run-context-manager to forcefully terminate + subprocesses. + """ + + class KillWrapper(cls): + def __init__(self, *args, **kwargs): + kill_wrapper_kwargs = kwargs.pop('dl_kill_wrapper_kwargs') + self.armed = kill_wrapper_kwargs.pop('armed') + self.introduced_timeout = kill_wrapper_kwargs.pop('introduced_timeout') + self.terminate_time = kill_wrapper_kwargs.pop('terminate_time') + kill_time = kill_wrapper_kwargs.pop('kill_time') + self.kill_time = ( + ((self.terminate_time or 0) + kill_time) + if kill_time is not None + else kill_time + ) + + self.process: subprocess.Popen | None = None + self.return_code: int | None = None + self.kill_counter: int = 0 + + super().__init__(*args, **kwargs) + + def arm(self) -> None: + self.kill_counter = 0 + self.armed = True + + def connection_made(self, process: subprocess.Popen) -> None: + self.process = process + super().connection_made(process) + + def timeout(self, fd: int | None) -> bool: + if self.armed and fd is None: + self.kill_counter += 1 + if self.kill_time and self.kill_counter >= self.kill_time: + self.process.kill() + self.kill_time = None + if self.terminate_time and self.kill_counter >= self.terminate_time: + self.process.terminate() + self.terminate_time = None + + # If we set the timeout argument due to a not-None kill_time + # or a not-None terminate_time, and due to a None timeout parameter, + # we leave the timeout handler here. + if self.introduced_timeout: + return False + + # If the timeout was set by the user of the context, we execute + # the timeout handler of the superclass. + return super().timeout(fd) + + def process_exited(self) -> None: + self.return_code = self.process.poll() + + return KillWrapper + + +@contextmanager +def run( + cmd: list, + protocol_class: type[Protocol], + *, + cwd: Path | None = None, + stdin: int | IO | bytes | Queue[bytes | None] | None = None, + timeout: float | None = None, + closing_action: Callable | None = None, + terminate_time: int | None = None, + kill_time: int | None = None, + protocol_kwargs: dict | None = None, +) -> Any | Generator: + """ A context manager for subprocesses + + The run-context manager will start a subprocess via ``ThreadedRunner``, + provide the result of the subprocess invocation, i.e. the result of + ``ThreadedRunner.run()`` in the ``as``-variable, and + clean up the subprocess when the context is left. + + The run-context manager supports the guaranteed exit of a subprocess through + either: + + a) natural exit of the subprocess + b) termination of the subprocess via SIGTERM, if ``terminate_time`` is + specified + c) termination of the subprocess via SIGKILL, if ``kill_time`` is specified + + If the process terminates the run-context manager will ensure that its exit + status is read, in order to prevent zombie processes. + + If neither ``terminate_time`` nor ``kill_time`` are specified, and the + subprocess does not exit by itself, for example, because it waits for some + input, the ``__exit__``-method of the run-context manager will never return. + In other words the thread will seem to "hang" when leaving the run-context. + The only way to ensure that the context is eventually left is to provide a + ``kill_time``. 
It is a good idea to provide a ``terminate_time`` in + addition, to allow the subprocess a graceful exit (see ``kill_time``- and + ``terminate_time``-argument descriptions below). + + Generator- and non-generator-protocols are both supported by the + context manager. Depending on the type of the provided protocol the + interpretation of ``terminate_time`` and ``kill_time`` are different. + + If a non-generator-protocol is used, counting of the ``terminate_time`` and + the ``kill_time`` starts when the subprocess is started. + + If a generator-protocol is used, counting of the ``terminate_time`` and + the ``kill_time`` starts when the run-context is left. + + Parameters + ---------- + cmd : list[str] + The command list that is passed to ``ThreadedRunner.run()`` + protocol_class : type[Protocol] + The protocol class that should be used to process subprocess events. + cwd: Path, optional + If provided, defines the current work directory for the subprocess + stdin: int | IO | bytes | Queue[bytes | None], optional + Input source or data for stdin of the subprocess. See the constructor + of :class:`ThreadedRunner` for a detailed description + timeout: float, optional + If provided, defines the time after which the timeout-callback of the + protocol will be enabled. See the constructor + of :class:`ThreadedRunner` for a detailed description + closing_action: Callable, optional + This argument is only used in generator-mode. + + If given, should be a callable that takes two arguments. The first + argument will be the :class:`ThreadedRunner`-instance that executes the + subprocess. The second argument will be the result generator, that was + returned by :meth:`ThreadedRunner.run`. + + The context + manager will call this function when the context is left. Because the + code executed in the context can be left unexpectedly there might be + actions that a user wants to perform in order to instruct the subprocess + to terminate. For example, if the subprocess is a git-annex command in + batch-mode, the user might want to close stdin of the subprocess. When + the ``closing_action`` returns, the run-context-manager will be "armed", + that means terminate- and kill-time counting begins. + terminate_time: int, optional + The number of timeouts after which a terminate-signal will be sent to + the process, if it is still running. If no timeouts were provided in the + ``timeout``-argument, the timeout is set to ``1.0``. + kill_time: int, optional + If ``terminate_time`` is given, a kill-signal will be sent to the + subprocess after kill-signal after ``terminate_time + kill_time`` + timeouts. If ``terminate_time`` is not set, a kill-signal will be sent + after ``kill_time`` timeouts. + It is a good idea to set ``kill_time`` and ``terminate_time`` in order + to let the process exit gracefully, if it is capable to do so. + protocol_kwargs : dict + A dictionary with Keyword arguments that will be used when + instantiating the protocol class. + + Yields + ------- + Any | Generator + The result of the invocation of :meth:`ThreadedRunner.run` is returned. + """ + introduced_timeout = False + if timeout is None: + introduced_timeout = True + timeout = 1.0 + + armed = False if issubclass(protocol_class, GeneratorMixIn) else True + + # Create the wrapping class. This is done mainly to ensure that the + # termination-related functionality is present in the protocol class that + # is used, independent of the actual protocol class that the user passes as + # argument. 
+ # A side effect of this approach is, that the number of protocol class + # definitions is reduced, because the user does not need to define + # terminate-capable protocols for every protocol they want to use. + kill_protocol_class = _create_kill_wrapper(protocol_class) + + runner = ThreadedRunner( + cmd=cmd, + protocol_class=kill_protocol_class, + stdin=DEVNULL if stdin is None else stdin, + protocol_kwargs=dict( + **(protocol_kwargs or {}), + dl_kill_wrapper_kwargs=dict( + introduced_timeout=introduced_timeout, + terminate_time=terminate_time, + kill_time=kill_time, + armed=armed, + ) + ), + timeout=timeout, + exception_on_error=False, + cwd=cwd, + ) + result = runner.run() + # We distinguish between a non-generator run, i,e. a blocking run and + # a generator run. + if not issubclass(protocol_class, GeneratorMixIn): + yield result + else: + try: + yield result + finally: + # If the user provided a closing action, call it with the runner + # and the result generator as arguments. + if closing_action is not None: + closing_action(runner, result) + # Arm the protocol, that will enable terminate signaling or kill + # signaling, if terminate_time or kill_time are not None. + result.runner.protocol.arm() + # Exhaust the generator. Because we have set a timeout, this will + # lead to invocation of the timeout method of the wrapper, which + # will take care of termination or killing. And it will fetch + # the result code of the terminated process. + # NOTE: if terminate_time and kill_time are both None, this might + # loop forever. + for _ in result: + pass diff --git a/datalad_next/runners/tests/resources/shell_like_prog.py b/datalad_next/runners/tests/resources/shell_like_prog.py new file mode 100644 index 00000000..f37ab9e5 --- /dev/null +++ b/datalad_next/runners/tests/resources/shell_like_prog.py @@ -0,0 +1,34 @@ +""" +This program emulates a shell like behavior. + +Shell-like behavior means that a single line of input (From +stdin) leads to a random number of output lines, followed by +a known "marker". + +The program reads a single lines from stdin and uses the +stripped content as "marker". It then emits a random number +of output-lines. + +If the number of random output lines is odd, the last random output +lines is terminated with a newline, then the marker and a newline is +written out. + +If number of random output lines is even, the last random output +line is terminated by the marker and a newline. + +The program randomly outputs one additional, newline-terminated, +line after the marker +""" +import random +import sys +import time + + +marker = sys.stdin.readline().strip() +output_line_count = 1 + random.randrange(8) +last_end = '\n' if output_line_count % 2 == 1 else '' +for i in range(output_line_count): + print(time.time(), end='\n' if i < output_line_count - 1 else last_end) +print(marker) +if random.randrange(2) == 1: + print(f'random additional output {time.time()}') diff --git a/datalad_next/runners/tests/test_batch.py b/datalad_next/runners/tests/test_batch.py new file mode 100644 index 00000000..d99b5a06 --- /dev/null +++ b/datalad_next/runners/tests/test_batch.py @@ -0,0 +1,228 @@ +import os +import signal +import sys +import time +from queue import Queue + +from .. 
import ( + GeneratorMixIn, + StdOutErrCapture, +) +from ..batch import ( + BatchProcess, + annexjson_batchcommand, + batchcommand, + stdout_batchcommand, +) + + +class PythonProtocol(StdOutErrCapture, GeneratorMixIn): + """Parses interactive python output and enqueues complete output strings + + This is an example for a protocol that processes results of an a priori + unknown structure and length. + Instances of this class interpret the stdout- and stderr-output of a + python interpreter. They assemble decoded stdout content until the python + interpreter sends ``'>>> '`` on stderr. Then the assembled output is + returned as result. + + This requires to start the python interpreter in unbuffered mode! If not, + the ``stderr``-output can be processed too early, i.e. before all + ``stdout``-output is processed. This is due to the fact that the runner is + thread-based. The runner does not necessarily preserve the wall-clock-order + of events that arrive from different streams. + """ + def __init__(self): + StdOutErrCapture.__init__(self) + GeneratorMixIn.__init__(self) + self.stdout = '' + self.stderr = b'' + self.prompt_count = -1 + + def pipe_data_received(self, fd: int, data: bytes) -> None: + if fd == 1: + # We known that no multibyte encoded strings are used in the + # examples. Therefore, we don't have to care about encodings that + # are split between consecutive data chunks, and we can always + # successfully decode `data`. + self.stdout += data.decode() + elif fd == 2: + self.stderr += data + if len(self.stderr) >= 4 and b'>>> ' in self.stderr: + self.prompt_count += 1 + self.stderr = b'' + if self.stdout and self.prompt_count > 0: + self.send_result(self.stdout) + self.stdout = '' + self.prompt_count -= 1 + + +def test_batch_simple(existing_dataset): + # first with a simplistic protocol to test the basic mechanics + with stdout_batchcommand( + ['git', 'annex', 'examinekey', + # the \n in the format is needed to produce an output that hits + # the output queue after each input line + '--format', '${bytesize}\n', + '--batch'], + cwd=existing_dataset.pathobj, + ) as bp: + res = bp(b'MD5E-s21032--2f4e22eb05d58c21663794876dc701aa\n') + assert res.rstrip(b'\r\n') == b'21032' + # to subprocess is still running + assert bp.return_code is None + # another batch + res = bp(b'MD5E-s999--2f4e22eb05d58c21663794876dc701aa\n') + assert res.rstrip(b'\r\n') == b'999' + assert bp.return_code is None + # we can bring the process down with stupid input, because it is + # inside our context handlers, it will not raise CommandError. 
check exit instead + res = bp(b'stupid\n') + # process exit is detectable + assert res is None + assert bp.return_code == 1 + # continued use raises the same exception + # (but stacktrace is obvs different) + res = bp(b'MD5E-s999--2f4e22eb05d58c21663794876dc701aa\n') + assert res is None + assert bp.return_code == 1 + + # now with a more complex protocol (decodes JSON-lines output) + with annexjson_batchcommand( + ['git', 'annex', 'examinekey', '--json', '--batch'], + cwd=existing_dataset.pathobj, + ) as bp: + # output is a decoded JSON object + res = bp(b'MD5E-s21032--2f4e22eb05d58c21663794876dc701aa\n') + assert res['backend'] == "MD5E" + assert res['bytesize'] == "21032" + assert res['key'] == "MD5E-s21032--2f4e22eb05d58c21663794876dc701aa" + res = bp(b'MD5E-s999--2f4e22eb05d58c21663794876dc701aa\n') + assert res['bytesize'] == "999" + res = bp(b'stupid\n') + assert res is None + assert bp.return_code == 1 + + +def test_batch_killing(existing_dataset): + # to test killing we have to circumvent the automatic stdin-closing by + # BatchCommand. We do that by setting `closing_action` to an empty function. + with stdout_batchcommand( + ['git', 'annex', 'examinekey', + # the \n in the format is needed to produce an output that hits + # the output queue after each input line + '--format', '${bytesize}\n', + '--batch'], + cwd=existing_dataset.pathobj, + closing_action=lambda a, b: True, + terminate_time=2, + kill_time=2, + ) as bp: + leave_time = time.time() + + leave_time = time.time() - leave_time + # at this point the process should have been terminated after about 3 + # seconds, because git-annex behaves well and terminates when it receives + # a terminate signal + assert 1.5 < leave_time < 2.5 + assert bp.return_code not in (0, None) + if os.name == 'posix': + assert bp.return_code == -signal.SIGTERM + + +def test_annexjsonbatch_killing(existing_dataset): + # to test killing we have to circumvent the automatic stdin-closing by + # BatchCommand. We do that by setting `closing_action` to an empty function. + with annexjson_batchcommand( + ['git', 'annex', 'examinekey', '--json', '--batch'], + cwd=existing_dataset.pathobj, + closing_action=lambda a, b: True, + terminate_time=2, + kill_time=2, + ) as bp: + leave_time = time.time() + + leave_time = time.time() - leave_time + # at this point the process should have been terminated after about 2 + # seconds, because git-annex behaves well and terminates when it receives + # a terminate signal + assert 1.5 < leave_time < 2.5 + assert bp.return_code not in (0, None) + if os.name == 'posix': + assert bp.return_code == -signal.SIGTERM + + +def test_plain_batch_python_multiline(): + + def close_stdin(batch_process: BatchProcess, + stdin_queue: Queue): + assert isinstance(batch_process, BatchProcess) + assert isinstance(stdin_queue, Queue) + batch_process.close_stdin() + stdin_queue.put(None) + + prog = ''' +import time +def x(count): + for i in range(count): + print(i, flush=True) + time.sleep(.2) +''' + # We set a terminate and kill time here, because otherwise an exception + # that is raised in the `batchcommand`-context will get the test to hang. + # The reason is that the exception triggers an exit from the context, + # but the python process will never stop since we did neither close its + # stdin nor did we call the `exit()`-function. 
+ with batchcommand([sys.executable, '-i', '-u', '-c', prog], + protocol_class=PythonProtocol, + terminate_time=3, + kill_time=2, + ) as python_interactive: + + # multiline output should be handled by the protocol, + for count in (5, 20): + response = python_interactive(f'x({count})\n'.encode()) + assert len(response.splitlines()) == count + python_interactive.close_stdin() + assert python_interactive.return_code == 0 + + # Test with unclosed stdin + with batchcommand([sys.executable, '-i', '-u', '-c', prog], + protocol_class=PythonProtocol, + terminate_time=3, + kill_time=2, + ) as python_interactive: + for count in (5, 20): + response = python_interactive(f'x({count})\n'.encode()) + assert len(response.splitlines()) == count + # Do not close stdin here, we let BatchCommand do that. + assert python_interactive.return_code == 0 + + # Test with closing action + with batchcommand([sys.executable, '-i', '-u', '-c', prog], + protocol_class=PythonProtocol, + terminate_time=3, + kill_time=2, + closing_action=close_stdin + ) as python_interactive: + for count in (5, 20): + response = python_interactive(f'x({count})\n'.encode()) + assert len(response.splitlines()) == count + # Do not close stdin here, we let the closing_action handle that. + assert python_interactive.return_code == 0 + + # Test with a "bad" closing action to ensure that only the closing action + # is called and not the internal stdin-closing of `BatchCommand`. + with batchcommand([sys.executable, '-i', '-u', '-c', prog], + protocol_class=PythonProtocol, + terminate_time=3, + kill_time=2, + closing_action=lambda a, b: None, + ) as python_interactive: + for count in (5, 20): + response = python_interactive(f'x({count})\n'.encode()) + assert len(response.splitlines()) == count + # Do not close stdin here, we let the closing_action handle that. 
+ assert python_interactive.return_code not in (0, None) + if os.name == 'posix': + assert python_interactive.return_code == -signal.SIGTERM diff --git a/datalad_next/runners/tests/test_data_processors.py b/datalad_next/runners/tests/test_data_processors.py new file mode 100644 index 00000000..468c3116 --- /dev/null +++ b/datalad_next/runners/tests/test_data_processors.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import json +from itertools import chain + +from ..data_processor_pipeline import ( + DataProcessorPipeline, + process_from, +) +from ..data_processors.decode import decode_processor +from ..data_processors.jsonline import jsonline_processor +from ..data_processors.pattern import pattern_processor +from ..data_processors.splitlines import splitlines_processor + + +decode_utf8_processor = decode_processor() + +text = '''This is the first line of text +the second line of text, followed by an empty line + +4th line of text with some non-ASCII characters: äöß + + +{"key0": "some text \\u1822"} + +7th line with interesting characters: € 😃👽 +an a non-terminated line''' + +text_lines = text.splitlines(keepends=True) +text_data_chunks = [ + text.encode()[i:i+100] + for i in range(0, len(text.encode()) + 100, 100) +] + + +json_result = [ + (True, {'key1': 'simple'}), + (True, {'key2': 'abc\naabböäß'}), + (True, {'key3': {'key3.1': 1.2}}), +] + +json_text = '\n'.join([json.dumps(o[1]) for o in json_result]) +json_data_chunks = [ + json_text.encode()[i:i+100] + for i in range(0, len(json_text.encode()) + 100, 100) +] + + +def test_decoding_splitting(): + result = [ + line + for line in process_from( + data_source=text_data_chunks, + processors=[ + decode_utf8_processor, + splitlines_processor() + ] + ) + ] + assert result == text_lines + + +def test_json_lines(): + result = [ + json_info + for json_info in process_from( + data_source=json_data_chunks, + processors=[ + decode_utf8_processor, + splitlines_processor(), + jsonline_processor + ] + ) + ] + assert result == json_result + + +def test_faulty_json_lines(): + result = [ + json_info[1] + for json_info in process_from( + data_source=text_data_chunks, + processors=[ + decode_utf8_processor, + splitlines_processor(), + jsonline_processor + ] + ) + if json_info[0] is True + ] + assert len(result) == 1 + assert result[0] == {'key0': 'some text \u1822'} + + +def test_pattern_border_processor(): + from ..data_processors import pattern_processor + + def perform_test(data_chunks: list[str | bytes], + pattern: str | bytes, + expected_non_final: tuple[list[str | bytes], list[str | bytes]], + expected_final: tuple[list[str | bytes], list[str | bytes]]): + + copied_data_chunks = data_chunks[:] + for final, result in ((True, expected_final), (False, expected_non_final)): + r = pattern_processor(pattern)(data_chunks, final=final) + assert tuple(r) == result, f'failed with final {final}' + # Check that the original list was not modified + assert copied_data_chunks == data_chunks + + perform_test( + data_chunks=['a', 'b', 'c', 'd', 'e'], + pattern='abc', + expected_non_final=(['abc', 'de'], []), + expected_final=(['abc', 'de'], []), + ) + + perform_test( + data_chunks=['a', 'b', 'c', 'a', 'b', 'c'], + pattern='abc', + expected_non_final=(['abc', 'abc'], []), + expected_final=(['abc', 'abc'], []), + ) + + # Ensure that unaligned pattern prefixes are not keeping data chunks short + perform_test( + data_chunks=['a', 'b', 'c', 'dddbbb', 'a', 'b', 'x'], + pattern='abc', + expected_non_final=(['abc', 'dddbbb', 'abx'], []), + 
expected_final=(['abc', 'dddbbb', 'abx'], []), + ) + + # Expect that a trailing minimum length-chunk that ends with a pattern + # prefix is not returned as data, but as remainder, if it is not the final + # chunk + perform_test( + data_chunks=['a', 'b', 'c', 'd', 'a'], + pattern='abc', + expected_non_final=(['abc'], ['da']), + expected_final=(['abc', 'da'], []), + ) + + # Expect the last chunk to be returned as data, if final is True, although + # it ends with a pattern prefix. If final is false, the last chunk will be + # returned as a remainder, because it ends with a pattern prefix. + perform_test( + data_chunks=['a', 'b', 'c', 'dddbbb', 'a'], + pattern='abc', + expected_non_final=(['abc', 'dddbbb'], ['a']), + expected_final=(['abc', 'dddbbb', 'a'], []) + ) + + + perform_test( + data_chunks=['a', 'b', 'c', '9', 'a'], + pattern='abc', + expected_non_final=(['abc'], ['9a']), + expected_final=(['abc', '9a'], []) + ) + + +def test_processor_removal(): + + stream = iter([b'\1', b'\2', b'\3', b'9\1', b'content']) + + pattern = b'\1\2\3' + pipeline = DataProcessorPipeline([pattern_processor(pattern)]) + filtered_stream = pipeline.process_from(stream) + + # The first chunk should start with the pattern, i.e. b'\1\2\3' + chunk = next(filtered_stream) + assert chunk[:len(pattern)] == pattern + + # Remove the filter again. The chunk is extended to contain all + # data that was buffered in the pipeline. + buffered_chunks = pipeline.finalize() + chunk = b''.join([chunk[len(pattern):]] + buffered_chunks) + + # The length is transferred now and terminated by b'\x01'. + while b'\x01' not in chunk: + chunk += next(stream) + + marker_index = chunk.index(b'\x01') + expected_size = int(chunk[:marker_index]) + assert expected_size == 9 + chunk = chunk[marker_index + 1:] + + source = chain([chunk], stream) if chunk else stream + assert b''.join(source) == b'content' + + +def test_split_decoding(): + encoded = 'ö'.encode('utf-8') + part_1, part_2 = encoded[:1], encoded[1:] + + # check that incomplete encodings are caught + decoded, remaining = decode_utf8_processor([part_1]) + assert decoded == [] + assert remaining == [part_1] + + # vreify that the omplete encoding decodes correctly + decoded, remaining = decode_utf8_processor([part_1, part_2]) + assert decoded == ['ö'] + assert remaining == [] + + +def test_pipeline_finishing(): + encoded = 'ö'.encode('utf-8') + part_1, part_2 = encoded[:1], encoded[1:] + + pipeline = DataProcessorPipeline([decode_utf8_processor]) + res = pipeline.process(part_1) + assert res == [] + res = pipeline.finalize() + assert res == ['\\xc3'] diff --git a/datalad_next/runners/tests/test_protocols.py b/datalad_next/runners/tests/test_protocols.py new file mode 100644 index 00000000..3e2cf9f4 --- /dev/null +++ b/datalad_next/runners/tests/test_protocols.py @@ -0,0 +1,81 @@ +import sys + +from ..data_processors import ( + decode_processor, + splitlines_processor, +) +from ..protocols import ( + StdOutCaptureProcessingGeneratorProtocol, + StdOutErrCaptureProcessingGeneratorProtocol, +) +from ..run import run + + +def test_stdout_pipeline_protocols_simple(): + # verify that the pipeline is used and finalized + processors = [splitlines_processor()] + protocol = StdOutCaptureProcessingGeneratorProtocol(processors=processors) + + data = b'abc\ndef\nghi' + protocol.pipe_data_received(1, data) + protocol.pipe_connection_lost(1, None) + + assert tuple(protocol.result_queue) == (b'abc\n', b'def\n', b'ghi') + + +def test_stdout_pipeline_protocol(): + with run( + [sys.executable, '-u', '-c', 
'print("abc\\ndef\\nghi", end="")'], + protocol_class=StdOutCaptureProcessingGeneratorProtocol, + protocol_kwargs=dict( + processors=[decode_processor(), splitlines_processor()] + ) + ) as r: + # There is no way to get un-decoded byte content with the non-generator + # protocols. + assert tuple(r) == ('abc\n', 'def\n', 'ghi') + + +def test_stdout_stderr_pipeline_protocol_simple(): + protocol = StdOutErrCaptureProcessingGeneratorProtocol( + stdout_processors=[decode_processor(), splitlines_processor()], + stderr_processors=[splitlines_processor()] + ) + + protocol.pipe_data_received(1, b'abc\ndef\nghi') + assert tuple(protocol.result_queue) == ((1, 'abc\n'), (1, 'def\n')) + protocol.result_queue.clear() + + # Check that the processing pipeline is finalized + protocol.pipe_connection_lost(1, None) + assert tuple(protocol.result_queue) == ((1, 'ghi'),) + protocol.result_queue.clear() + + protocol.pipe_data_received(2, b'rst\nuvw\nxyz') + assert tuple(protocol.result_queue) == ((2, b'rst\n'), (2, b'uvw\n')) + protocol.result_queue.clear() + + # Check that the processing pipeline is finalized + protocol.pipe_connection_lost(2, None) + assert tuple(protocol.result_queue) == ((2, b'xyz'),) + + +def test_stdout_stderr_pipeline_protocol(): + with run( + [ + sys.executable, '-u', '-c', + 'import sys\n' + 'print("abc\\ndef\\nghi", end="")\n' + 'print("rst\\nuvw\\nxyz", end="", file=sys.stderr)\n' + ], + protocol_class=StdOutErrCaptureProcessingGeneratorProtocol, + protocol_kwargs=dict( + stdout_processors=[decode_processor(), splitlines_processor()], + stderr_processors=[splitlines_processor()] + ) + ) as r: + result = tuple(r) + + assert len(result) == 6 + assert ''.join(x[1] for x in result if x[0] == 1) == 'abc\ndef\nghi' + assert b''.join(x[1] for x in result if x[0] == 2) == b'rst\nuvw\nxyz' diff --git a/datalad_next/runners/tests/test_run.py b/datalad_next/runners/tests/test_run.py new file mode 100644 index 00000000..50e32245 --- /dev/null +++ b/datalad_next/runners/tests/test_run.py @@ -0,0 +1,399 @@ +from __future__ import annotations + +import os +import signal +import subprocess +import sys +import time +from pathlib import Path +from queue import Queue +from random import randint +from typing import Generator + +import pytest + +from datalad.utils import ( + on_osx, + on_windows, +) +from datalad.tests.utils_pytest import skip_if + +from .. 
import ( + NoCapture, + StdErrCapture, + StdOutCapture, + StdOutCaptureGeneratorProtocol, + StdOutErrCapture, + ThreadedRunner, +) +from ..run import run +from ..data_processors import ( + decode_processor, + splitlines_processor, +) +from ..data_processor_pipeline import process_from + +resources_dir = Path(__file__).parent / 'resources' + + +interruptible_prog = ''' +import time + +i = 0 +while True: + print(i, flush=True) + i += 1 + time.sleep(1) +''' + +uninterruptible_prog = ''' +import signal + +signal.signal(signal.SIGTERM, signal.SIG_IGN) +signal.signal(signal.SIGINT, signal.SIG_IGN) +''' + interruptible_prog + +stdin_reading_prog = ''' +import sys + +while True: + data = sys.stdin.readline() + if data == '': + exit(0) + print(f'entered: {data.strip()}', flush=True) +''' + +stdin_closing_prog = ''' +import sys +import time + +sys.stdin.close() +while True: + print(f'stdin is closed {time.time()}', flush=True) + time.sleep(.1) +''' + + +def test_sig_kill(): + with run(cmd=[sys.executable, '-c', uninterruptible_prog], + protocol_class=StdOutCaptureGeneratorProtocol, + terminate_time=1, + kill_time=1) as r: + # Fetch one data chunk to ensure that the process is running + data = next(r) + assert data[:1] == b'0' + + # Ensure that the return code was read and is not zero + assert r.return_code not in (0, None) + if os.name == 'posix': + assert r.return_code == -signal.SIGKILL + + +def test_sig_terminate(): + with run(cmd=[sys.executable, '-c', interruptible_prog], + protocol_class=StdOutCaptureGeneratorProtocol, + terminate_time=1, + kill_time=1) as r: + # Fetch one data chunk to ensure that the process is running + data = next(r) + assert data[:1] == b'0' + + # Ensure that the return code was read + assert r.return_code is not None + if os.name == 'posix': + assert r.return_code == -signal.SIGTERM + + +def test_external_close(): + stdin_queue = Queue() + with run([sys.executable, '-c', stdin_reading_prog], + protocol_class=StdOutCaptureGeneratorProtocol, + stdin=stdin_queue) as r: + while True: + stdin_queue.put(f'{time.time()}{os.linesep}'.encode()) + try: + result = next(r) + except StopIteration: + break + r.runner.process.stdin.close() + + assert r.return_code == 0 + + +@skip_if(on_osx or on_windows) # On MacOS and Windows a write will block +def test_internal_close_file(): + # This test demonstrates pipe-writing behavior if the receiving side, + # i.e. the sub-process, does not read from the pipe. It is not specifically + # a test for the context-manager. + with run([sys.executable, '-c', stdin_closing_prog], + protocol_class=StdOutCaptureGeneratorProtocol, + stdin=subprocess.PIPE, + timeout=2.0, + terminate_time=1, + kill_time=1) as r: + + os.set_blocking(r.runner.process.stdin.fileno(), False) + total = 0 + while True: + try: + written = r.runner.process.stdin.write(b'a' * 8000) + if written is None: + print(f'Write failed after {total} bytes', flush=True) + # There are no proper STDIN-timeouts because we handle that + # ourselves. So for the purpose of this test, we can not + # rely on timeout. That means, we kill the process here + # and the let __exit__-method pick up the peaces, i.e. the + # return code. 
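+                    # (the explicit kill() only shortens the wait; the
+                    # terminate_time/kill_time passed to run() above would
+                    # reap the process on context exit anyway)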
+ r.runner.process.kill() + break + except BrokenPipeError: + print(f'Wrote less than {total + 8000} bytes', flush=True) + break + total += written + assert r.return_code not in (0, None) + + +def _check_signal_blocking(program: str): + with run(cmd=[sys.executable, '-c', program], + protocol_class=StdOutCapture, + terminate_time=1, + kill_time=1) as r: + pass + + # Check the content + assert all([ + index == int(item) + for index, item in enumerate(r['stdout'].splitlines()) + ]) + + # Ensure that the return code was read + return_code = r['code'] + assert return_code is not None + return return_code + + +def test_kill_blocking(): + return_code = _check_signal_blocking(uninterruptible_prog) + if os.name == 'posix': + assert return_code == -signal.SIGKILL + + +def test_terminate_blocking(): + return_code = _check_signal_blocking(interruptible_prog) + if os.name == 'posix': + assert return_code == -signal.SIGTERM + + +def test_batch_1(): + stdin_queue = Queue() + with run([sys.executable, '-c', stdin_reading_prog], + protocol_class=StdOutCaptureGeneratorProtocol, + stdin=stdin_queue, + terminate_time=20, + kill_time=5) as r: + + # Create a line-splitting result generator + line_results = process_from( + data_source=r, + processors=[decode_processor(), splitlines_processor()] + ) + + for i in range(10): + message = f'{time.time()}{os.linesep}' + stdin_queue.put(message.encode()) + response = next(line_results) + assert response == 'entered: ' + message + time.sleep(0.1) + stdin_queue.put(None) + assert r.return_code == 0 + + +def test_shell_like(): + stdin_queue = Queue() + with run([sys.executable, str(resources_dir / 'shell_like_prog.py')], + protocol_class=StdOutCaptureGeneratorProtocol, + stdin=stdin_queue) as r: + + # Create a line-splitting result generator + line_results = process_from( + data_source=r, + processors=[decode_processor(), splitlines_processor()] + ) + + # Create a random marker and send it to the subprocess + marker = f'mark-{randint(1000000, 2000000)}{os.linesep}' + stdin_queue.put(marker.encode()) + + # Read until the marker comes back + for line_index, line in enumerate(line_results): + if line[-len(marker):] == marker: + unterminated_line = line[:-len(marker)] + break + if unterminated_line: + assert line_index % 2 == 1 + + assert r.return_code == 0 + + +def test_run_timeout(): + with pytest.raises(TimeoutError): + with run([ + sys.executable, '-c', + 'import time; time.sleep(3)'], + StdOutCaptureGeneratorProtocol, + timeout=1 + ) as res: + # must poll, or timeouts are not checked + list(res) + + +def test_run_kill_on_exit(): + with run([ + sys.executable, '-c', + 'import time; print("mike", flush=True); time.sleep(10)'], + StdOutCaptureGeneratorProtocol, + terminate_time=1, + kill_time=1, + ) as res: + assert next(res).rstrip(b'\r\n') == b'mike' + # here the process must be killed be the exit of the contextmanager + if os.name == 'posix': + # on posix platforms a negative return code of -X indicates + # a "killed by signal X" + assert res.return_code < 0 + # on any system the process must be dead now (indicated by a return code) + assert res.return_code is not None + + +def test_run_instant_kill(): + with run([ + sys.executable, '-c', + 'import time; time.sleep(3)'], + StdOutCaptureGeneratorProtocol, + terminate_time=1, + kill_time=1, + ) as sp: + # we let it terminate instantly + pass + if os.name == 'posix': + assert sp.return_code < 0 + assert sp.return_code is not None + + +def test_run_cwd(tmp_path): + with run([ + sys.executable, '-c', + 'from pathlib import 
Path; print(Path.cwd(), end="")'],
+        StdOutCapture,
+        cwd=tmp_path,
+    ) as res:
+        assert res['stdout'] == str(tmp_path)
+
+
+def test_run_input_bytes():
+    with run([
+        sys.executable, '-c',
+        'import sys;'
+        'print(sys.stdin.read(), end="")'],
+        StdOutCapture,
+        # it only takes bytes
+        stdin=b'mybytes\nline',
+    ) as res:
+        # note that bytes went in, but str comes out -- it is up to
+        # the protocol.
+        # use splitlines to compensate for platform line ending
+        # differences
+        assert res['stdout'].splitlines() == ['mybytes', 'line']
+
+
+def test_run_input_queue():
+    stdin_queue = Queue()
+    with run([
+        sys.executable, '-c',
+        'from fileinput import input; import sys;'
+        '[print(line, end="", flush=True) if line.strip() else sys.exit(0)'
+        ' for line in input()]'],
+        StdOutCaptureGeneratorProtocol,
+        stdin=stdin_queue,
+    ) as sp:
+        stdin_queue.put(f'one\n'.encode())
+        response = next(sp)
+        assert response.rstrip() == b'one'
+        stdin_queue.put(f'two\n'.encode())
+        response = next(sp)
+        assert response.rstrip() == b'two'
+        # an empty line should cause process exit
+        stdin_queue.put(os.linesep.encode())
+        # we can wait for that even before the context manager
+        # does its thing and tears it down
+        sp.runner.process.wait()
+
+
+def test_run_nongenerator():
+    # when executed with a non-generator protocol, the process
+    # runs and returns whatever the specified protocol returns
+    # from _prepare_result.
+    # Below we test the core protocols -- they all happen to
+    # report a return `code`, `stdout`, `stderr` -- but this is
+    # in no way a given for any other protocol.
+    with run([sys.executable, '--version'], NoCapture) as res:
+        assert res['code'] == 0
+    with run([sys.executable, '-c', 'import sys; sys.exit(134)'],
+             NoCapture) as res:
+        assert res['code'] == 134
+    with run([
+        sys.executable, '-c',
+        'import sys; print("print", end="", file=sys.stdout)'],
+        StdOutCapture,
+    ) as res:
+        assert res['code'] == 0
+        assert res['stdout'] == 'print'
+    with run([
+        sys.executable, '-c',
+        'import sys; print("print", end="", file=sys.stderr)'],
+        StdErrCapture,
+    ) as res:
+        assert res['code'] == 0
+        assert res['stderr'] == 'print'
+    with run([
+        sys.executable, '-c',
+        'import sys; print("outy", end="", file=sys.stdout); '
+        'print("error", end="", file=sys.stderr)'],
+        StdOutErrCapture,
+    ) as res:
+        assert res['code'] == 0
+        assert res['stdout'] == 'outy'
+        assert res['stderr'] == 'error'
+
+
+def test_closing_action():
+    # Check exit condition without closing stdin and without a closing action.
+    # The process should be terminated by a SIGTERM signal.
+    stdin_queue = Queue()
+    with run([sys.executable, '-c', stdin_reading_prog],
+             protocol_class=StdOutCaptureGeneratorProtocol,
+             stdin=stdin_queue,
+             terminate_time=2) as r:
+        stdin_queue.put(f'{time.time()}{os.linesep}'.encode())
+        # Leave the context without closing stdin
+    assert r.return_code not in (0, None)
+    if os.name == 'posix':
+        assert r.return_code == -signal.SIGTERM
+
+    # Check exit condition with a closing action that closes stdin. The
+    # process should then exit by itself with return code 0.
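+    # The closing action is called on context exit with the ThreadedRunner
+    # and the result generator; putting `None` on the stdin queue closes the
+    # subprocess' stdin, so it exits by itself and no terminate/kill signal
+    # is needed.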
+ def closing_action(runner: ThreadedRunner, result_generator: Generator): + assert isinstance(runner, ThreadedRunner) + assert isinstance(result_generator, Generator) + runner.stdin_queue.put(None) + + stdin_queue = Queue() + with run([sys.executable, '-c', stdin_reading_prog], + protocol_class=StdOutCaptureGeneratorProtocol, + stdin=stdin_queue, + closing_action=closing_action, + terminate_time=2) as r: + stdin_queue.put(f'{time.time()}{os.linesep}'.encode()) + # Leave the context without closing + # If the closing action was activated, we expect a zero exit code + assert r.return_code == 0 diff --git a/datalad_next/url_operations/ssh.py b/datalad_next/url_operations/ssh.py index 1e9a3aac..1e3972aa 100644 --- a/datalad_next/url_operations/ssh.py +++ b/datalad_next/url_operations/ssh.py @@ -4,7 +4,6 @@ from __future__ import annotations import logging -import subprocess import sys from itertools import chain from pathlib import ( @@ -16,7 +15,6 @@ Queue, ) from typing import ( - Any, Dict, Generator, IO, @@ -24,14 +22,13 @@ from urllib.parse import urlparse from datalad_next.runners import ( - GeneratorMixIn, NoCaptureGeneratorProtocol, - Protocol as RunnerProtocol, StdOutCaptureGeneratorProtocol, - ThreadedRunner, - CommandError, ) +from datalad_next.runners.data_processors import pattern_processor +from datalad_next.runners.data_processor_pipeline import DataProcessorPipeline +from datalad_next.runners.run import run from datalad_next.utils.consts import COPY_BUFSIZE from . import ( @@ -71,6 +68,18 @@ class SshUrlOperations(UrlOperations): "|| exit 244" _cat_cmd = "cat '{fpath}'" + def _check_return_code(self, url, stream): + # At this point the subprocess has either exited, was terminated, or + # was killed. + if stream.return_code == 244: + # this is the special code for a file-not-found + raise UrlOperationsResourceUnknown(url) + elif stream.return_code != 0: + raise UrlOperationsRemoteError( + url, + message=f'ssh process returned {stream.return_code}' + ) + def stat(self, url: str, *, @@ -81,63 +90,67 @@ def stat(self, See :meth:`datalad_next.url_operations.UrlOperations.stat` for parameter documentation and exception behavior. """ - try: - props = self._stat( - url, - cmd=SshUrlOperations._stat_cmd, - ) - except CommandError as e: - if e.code == 244: - # this is the special code for a file-not-found - raise UrlOperationsResourceUnknown(url) from e - else: - raise UrlOperationsRemoteError(url, message=str(e)) from e + ssh_cat = _SshCat(url) + cmd = ssh_cat.get_cmd(SshUrlOperations._stat_cmd) + with run(cmd, protocol_class=StdOutCaptureGeneratorProtocol) as stream: + props = self._get_props(url, stream) + # At this point the subprocess has either exited, was terminated, or + # was killed. + self._check_return_code(url, stream) return {k: v for k, v in props.items() if not k.startswith('_')} - def _stat(self, url: str, cmd: str) -> Dict: - # any stream must start with this magic marker, or we do not - # recognize what is happening - # after this marker, the server will send the size of the - # to-be-downloaded file in bytes, followed by another magic - # b'\1', and the file content after that - need_magic = b'\1\2\3' - expected_size_str = b'' - expected_size = None + def _get_props(self, url, stream: Generator) -> Dict: + # The try clause enables us to execute the code after the context + # handler if the iterator stops unexpectedly. That would, for + # example be the case, if the ssh-subprocess terminates prematurely, + # for example, due to a missing file. 
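+        # Concretely: `next()` on the stream below raises StopIteration once
+        # the ssh subprocess has closed stdout, and the `except` clause at
+        # the end of this method turns that into a return-code check.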
+        # (An alternative way to detect and handle the exit would be to
+        # implement some handling in the protocol.connection_lost callback
+        # and send the result to the generator, e.g. via:
+        # protocol.send(('process-exit', self.process.poll()))
+        try:
+            # any stream must start with this magic marker, or we do not
+            # recognize what is happening
+            # after this marker, the server will send the size of the
+            # to-be-downloaded file in bytes, followed by another magic
+            # b'\1', and the file content after that
+            magic_marker = b'\1\2\3'
+
+            # Create a pipeline object that contains a single data
+            # processor, i.e. the "pattern_processor". It guarantees that
+            # each chunk has at least the size of the pattern and that no
+            # chunk ends with a pattern prefix (except for the last chunk).
+            # (We could have used the convenience wrapper "process_from",
+            # but we want to remove the filter again below. This requires us
+            # to have a DataProcessorPipeline object.)
+            pipeline = DataProcessorPipeline([pattern_processor(magic_marker)])
+            filtered_stream = pipeline.process_from(stream)
+
+            # The first chunk should start with the magic marker, i.e. b'\1\2\3'
+            chunk = next(filtered_stream)
+            if chunk[:len(magic_marker)] != magic_marker:
+                raise RuntimeError("Protocol error: report header not received")
+
+            # Remove the filter again. The chunk is extended to contain all
+            # data that was buffered in the pipeline.
+            chunk = b''.join([chunk[len(magic_marker):]] + pipeline.finalize())
+
+            # The length is transferred now and terminated by b'\x01'.
+            while b'\x01' not in chunk:
+                chunk += next(stream)
+
+            marker_index = chunk.index(b'\x01')
+            expected_size = int(chunk[:marker_index])
+            chunk = chunk[marker_index + 1:]
+            props = {
+                'content-length': expected_size,
+                '_stream': chain([chunk], stream) if chunk else stream
+            }
+            return props
-        ssh_cat = _SshCat(url)
-        stream = ssh_cat.run(cmd, protocol=StdOutCaptureGeneratorProtocol)
-        for chunk in stream:
-            if need_magic:
-                expected_magic = need_magic[:min(len(need_magic),
-                                                 len(chunk))]
-                incoming_magic = chunk[:len(need_magic)]
-                # does the incoming data have the remaining magic bytes?
- if incoming_magic != expected_magic: - raise RuntimeError( - "Protocol error: report header not received") - # reduce (still missing) magic, if any - need_magic = need_magic[len(expected_magic):] - # strip magic from input - chunk = chunk[len(expected_magic):] - if chunk and expected_size is None: - # we have incoming data left and - # we have not yet consumed the size info - size_data = chunk.split(b'\1', maxsplit=1) - expected_size_str += size_data[0] - if len(size_data) > 1: - # this is not only size info, but we found the start of - # the data - expected_size = int(expected_size_str) - chunk = size_data[1] - if expected_size: - props = { - 'content-length': expected_size, - '_stream': chain([chunk], stream) if chunk else stream, - } - return props - # there should be no data left to process, or something went wrong - assert not chunk + except StopIteration: + self._check_return_code(url, stream) def download(self, from_url: str, @@ -164,13 +177,17 @@ def download(self, dst_fp = None - try: - props = self._stat( - from_url, - cmd=f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}', - ) - stream = props.pop('_stream') + ssh_cat = _SshCat(from_url) + cmd = ssh_cat.get_cmd(f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}') + with run(cmd, protocol_class=StdOutCaptureGeneratorProtocol) as stream: + + props = self._get_props(from_url, stream) expected_size = props['content-length'] + # The stream might have changed due to not yet processed, but + # fetched data, that is now chained in front of it. Therefore we get + # the updated stream from the props + stream = props.pop('_stream') + dst_fp = sys.stdout.buffer if to_path is None \ else open(to_path, 'wb') # Localize variable access to minimize overhead @@ -190,19 +207,18 @@ def download(self, self._progress_report_update( progress_id, ('Downloaded chunk',), len(chunk)) props.update(hasher.get_hexdigest()) - return props - except CommandError as e: - if e.code == 244: - # this is the special code for a file-not-found - raise UrlOperationsResourceUnknown(from_url) from e - else: - # wrap this into the datalad-standard, but keep the - # original exception linked - raise UrlOperationsRemoteError(from_url, message=str(e)) from e - finally: - if dst_fp and to_path is not None: - dst_fp.close() - self._progress_report_stop(progress_id, ('Finished download',)) + + # At this point the subprocess has either exited, was terminated, or + # was killed. + if stream.return_code == 244: + # this is the special code for a file-not-found + raise UrlOperationsResourceUnknown(from_url) + elif stream.return_code != 0: + raise UrlOperationsRemoteError( + from_url, + message=f'ssh process returned {stream.return_code}' + ) + return props def upload(self, from_path: Path | None, @@ -253,64 +269,71 @@ def _perform_upload(self, # we limit the queue to few items in order to `make queue.put()` # block relatively quickly, and thereby have the progress report - # actually track the upload, and not just the feeding of the - # queue + # actually track the upload, i.e. the feeding of the stdin pipe + # of the ssh-process, and not just the feeding of the + # queue. 
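+        # (with COPY_BUFSIZE-sized chunks this bounds the data buffered
+        # ahead of the ssh process to roughly two chunks)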
upload_queue = Queue(maxsize=2) - ssh_cat = _SshCat(to_url) - ssh_runner_generator = ssh_cat.run( + cmd = _SshCat(to_url).get_cmd( # leave special exit code when writing fails, but not the # general SSH access - "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244", - protocol=NoCaptureGeneratorProtocol, - stdin=upload_queue, - timeout=timeout, - ) - - # file is open, we can start progress tracking - progress_id = self._get_progress_id(source_name, to_url) - self._progress_report_start( - progress_id, - ('Upload %s to %s', source_name, to_url), - 'uploading', - expected_size, + "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244" ) - try: + with run(cmd, NoCaptureGeneratorProtocol, stdin=upload_queue, timeout=timeout) as ssh: + # file is open, we can start progress tracking + progress_id = self._get_progress_id(source_name, to_url) + self._progress_report_start( + progress_id, + ('Upload %s to %s', source_name, to_url), + 'uploading', + expected_size, + ) upload_size = 0 - while ssh_runner_generator.runner.process.poll() is None: + while True: chunk = src_fp.read(COPY_BUFSIZE) + # Leave the write-loop at eof if chunk == b'': break + + # If the ssh-subprocess exited, leave the write loop, the + # result will be interpreted below + if ssh.runner.process.poll() is not None: + break + chunk_size = len(chunk) # compute hash simultaneously hasher.update(chunk) + # we are just putting stuff in the queue, and rely on # its maxsize to cause it to block the next call to # have the progress reports be anyhow valid - upload_queue.put(chunk, timeout=timeout) + try: + upload_queue.put(chunk, timeout=timeout) + except Full: + raise TimeoutError + self._progress_report_update( progress_id, ('Uploaded chunk',), chunk_size) upload_size += chunk_size - # we're done, close queue - upload_queue.put(None, timeout=timeout) - - # Exhaust the generator, that might raise CommandError - # or TimeoutError, if timeout was not `None`. - tuple(ssh_runner_generator) - except CommandError as e: - if e.code == 244: - raise UrlOperationsResourceUnknown(to_url) from e - else: - raise UrlOperationsRemoteError(to_url, message=str(e)) from e - except (TimeoutError, Full): - ssh_runner_generator.runner.process.kill() - raise TimeoutError - finally: - self._progress_report_stop(progress_id, ('Finished upload',)) - - assert ssh_runner_generator.return_code == 0, "Unexpected ssh " \ - f"return value: {ssh_runner_generator.return_code}" + # we're done, close queue + try: + upload_queue.put(None, timeout=timeout) + except Full: + # Everything is done. If we leave the context the subprocess + # will be treated as specified in the context initialization, + # either wait for it, terminate, or kill it. + pass + + # At this point the subprocess has terminated by itself or was killed. + if ssh.return_code == 244: + raise UrlOperationsResourceUnknown(to_url) + elif ssh.return_code != 0: + raise UrlOperationsRemoteError( + to_url, + message=f'ssh exited with return value: {ssh.return_code}') + + assert ssh.return_code == 0, f"Unexpected ssh return value: {ssh.return_code}" return { **hasher.get_hexdigest(), # return how much was copied. 
we could compare with @@ -328,11 +351,7 @@ def __init__(self, url: str, *additional_ssh_args): assert self._parsed.path self.ssh_args: list[str] = list(additional_ssh_args) - def run(self, - payload_cmd: str, - protocol: type[RunnerProtocol], - stdin: Queue | None = None, - timeout: float | None = None) -> Any | Generator: + def get_cmd(self, payload_cmd: str) -> list[str]: fpath = self._parsed.path cmd = ['ssh'] cmd.extend(self.ssh_args) @@ -344,9 +363,4 @@ def run(self, fpath=fpath, ), ]) - return ThreadedRunner( - cmd=cmd, - protocol_class=protocol, - stdin=subprocess.DEVNULL if stdin is None else stdin, - timeout=timeout, - ).run() + return cmd diff --git a/docs/source/pyutils.rst b/docs/source/pyutils.rst index b97a248d..722fa69e 100644 --- a/docs/source/pyutils.rst +++ b/docs/source/pyutils.rst @@ -24,6 +24,7 @@ packages. exceptions iter_collections runners + runners.data_processors tests.fixtures types uis
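
For orientation, the pattern exercised by the tests above can be summarized in a
short, illustrative sketch: run a command via the ``run()`` context manager, let
the protocol's processor chain decode and line-split the output, and iterate the
resulting generator. The sketch only uses names that appear in this changeset;
the command and the ``terminate_time``/``kill_time`` values are arbitrary
placeholders chosen for the example.

    import sys

    from datalad_next.runners import StdOutCaptureProcessingGeneratorProtocol
    from datalad_next.runners.data_processors import (
        decode_processor,
        splitlines_processor,
    )
    from datalad_next.runners.run import run

    with run(
        [sys.executable, '-c', 'print("one"); print("two")'],
        protocol_class=StdOutCaptureProcessingGeneratorProtocol,
        protocol_kwargs=dict(
            processors=[decode_processor(), splitlines_processor()]
        ),
        # make sure the child cannot outlive the context by much
        terminate_time=5,
        kill_time=5,
    ) as lines:
        for line in lines:
            # yields decoded, newline-terminated strings,
            # e.g. 'one\n' and 'two\n' (modulo platform line endings)
            print(repr(line))

Because the decode processor holds back incomplete multi-byte sequences (see
``test_split_decoding`` above), the line splitter only ever sees complete
characters.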