Skip to content

Commit

Permalink
add pipeline processing generator-protocols
Browse files Browse the repository at this point in the history
This commit adds pipeline processing generator
protocols. The protocols can be initialized
with a list of processors that they will apply
to the data that they receive.

There is currently no equivalent for
non-generator protocols. The reason for that
is that those protocols expect every data
element that they handle to be `bytes`.
That does not fit well with processors
like `splitlines_processor` or
`decode_processor`.

We could implement
a new non-generator protocol "family"
that collects whatever comes from a
processor pipeline as the values of
`stdout`- and `stderr`-keys of the
result dictionary. But it is currently
not clear when and where that would
be used.
  • Loading branch information
christian-monch committed Oct 24, 2023
1 parent 25bf3cf commit 9720976
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 0 deletions.
82 changes: 82 additions & 0 deletions datalad_next/runners/protocols.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from __future__ import annotations

from typing import Optional

from . import (
GeneratorMixIn,
NoCapture,
StdOutCapture,
StdOutErrCapture,
)
from .data_processor_pipeline import DataProcessorPipeline


#
Expand All @@ -29,3 +35,79 @@ def pipe_data_received(self, fd: int, data: bytes):

def timeout(self, fd):
    """Raise a ``TimeoutError`` that names the given file descriptor.

    Presumably invoked by the runner when waiting on ``fd`` timed out --
    the caller is not visible here, confirm against the runner code.

    Raises
    ------
    TimeoutError
        Always; the message contains ``fd``.
    """
    message = f"Runner timeout {fd}"
    raise TimeoutError(message)


class StdOutCaptureProcessingGeneratorProtocol(StdOutCaptureGeneratorProtocol):
    """Generator protocol that runs stdout data through a processor pipeline.

    When the protocol is created with a non-empty list of processors, every
    chunk received on stdout is handed to a ``DataProcessorPipeline`` built
    from those processors, and each item the pipeline returns is passed to
    ``send_result``. When stdout is closed, the pipeline is finalized and
    whatever it still returns is passed to ``send_result`` as well.

    Without processors (``None`` or an empty list) the received chunks are
    passed to ``send_result`` unmodified.
    """
    def __init__(self,
                 done_future=None,
                 processors: list | None = None
                 ) -> None:
        """Set up the base protocol and, if requested, the pipeline.

        Parameters
        ----------
        done_future
            Forwarded unchanged to ``StdOutCaptureGeneratorProtocol.__init__``.
        processors
            Processors for the stdout pipeline. A falsy value (``None`` or
            an empty list) disables processing.
        """
        StdOutCaptureGeneratorProtocol.__init__(self, done_future, None)
        # Truthiness check on purpose: an empty list counts as "no pipeline".
        if processors:
            self.processor_pipeline = DataProcessorPipeline(processors)
        else:
            self.processor_pipeline = None

    def pipe_data_received(self, fd: int, data: bytes):
        """Forward stdout data, processed by the pipeline if there is one."""
        assert fd == 1
        pipeline = self.processor_pipeline
        if not pipeline:
            # No pipeline configured: hand the raw chunk through.
            self.send_result(data)
            return
        for item in pipeline.process(data):
            self.send_result(item)

    def pipe_connection_lost(self, fd: int, exc: Optional[BaseException]) -> None:
        """Finalize the pipeline and forward everything it still returns."""
        assert fd == 1
        pipeline = self.processor_pipeline
        if not pipeline:
            return
        for item in pipeline.finalize():
            self.send_result(item)


class StdOutErrCaptureProcessingGeneratorProtocol(StdOutErrCapture, GeneratorMixIn):
    """Generator protocol with separate pipelines for stdout and stderr.

    The protocol can be given one list of processors for stdout data and
    another one for stderr data. Data read from either stream is processed
    by the pipeline of that stream, if one was configured.

    Every result is passed to ``send_result`` as a 2-tuple
    ``(fd, item)``: ``fd`` is the file descriptor on which the data
    arrived (``1`` for stdout, ``2`` for stderr) and ``item`` is either an
    item returned by the respective pipeline or, if the stream has no
    pipeline, the unmodified data.
    """
    def __init__(self,
                 done_future=None,
                 stdout_processors: list | None = None,
                 stderr_processors: list | None = None,
                 ) -> None:
        """Set up both base classes and the per-stream pipelines.

        Parameters
        ----------
        done_future
            Forwarded unchanged to ``StdOutErrCapture.__init__``.
        stdout_processors
            Processors for data arriving on file descriptor 1, or ``None``
            for no stdout pipeline.
        stderr_processors
            Processors for data arriving on file descriptor 2, or ``None``
            for no stderr pipeline.
        """
        StdOutErrCapture.__init__(self, done_future, None)
        GeneratorMixIn.__init__(self)
        # Maps a file descriptor to its pipeline. Only ``None`` disables a
        # pipeline here; an empty list still creates one.
        pipelines = {}
        for stream_fd, stream_processors in (
            (1, stdout_processors),
            (2, stderr_processors),
        ):
            if stream_processors is None:
                continue
            pipelines[stream_fd] = DataProcessorPipeline(stream_processors)
        self.processor_pipelines = pipelines

    def pipe_data_received(self, fd: int, data: bytes):
        """Forward ``(fd, item)`` tuples for data received on ``fd``."""
        assert fd in (1, 2)
        if fd not in self.processor_pipelines:
            # No pipeline for this stream: hand the raw chunk through.
            self.send_result((fd, data))
            return
        for item in self.processor_pipelines[fd].process(data):
            self.send_result((fd, item))

    def pipe_connection_lost(self, fd: int, exc: Optional[BaseException]) -> None:
        """Finalize the pipeline of ``fd`` and forward what it still returns."""
        assert fd in (1, 2)
        if fd not in self.processor_pipelines:
            return
        for item in self.processor_pipelines[fd].finalize():
            self.send_result((fd, item))

    def timeout(self, fd):
        """Raise a ``TimeoutError`` that names the given file descriptor."""
        message = f"Runner timeout {fd}"
        raise TimeoutError(message)
81 changes: 81 additions & 0 deletions datalad_next/runners/tests/test_protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import sys

from ..data_processors import (
decode_processor,
splitlines_processor,
)
from ..protocols import (
StdOutCaptureProcessingGeneratorProtocol,
StdOutErrCaptureProcessingGeneratorProtocol,
)
from ..run import run


def test_stdout_pipeline_protocols_simple():
    """Check that the stdout pipeline is both applied and finalized."""
    protocol = StdOutCaptureProcessingGeneratorProtocol(
        processors=[splitlines_processor()],
    )

    # The last line has no terminator, so it is only expected to show up
    # after the connection-lost call.
    protocol.pipe_data_received(1, b'abc\ndef\nghi')
    protocol.pipe_connection_lost(1, None)

    expected = (b'abc\n', b'def\n', b'ghi')
    assert tuple(protocol.result_queue) == expected


def test_stdout_pipeline_protocol():
    """Run a subprocess and check decoded, line-split stdout results."""
    command = [sys.executable, '-u', '-c', 'print("abc\\ndef\\nghi", end="")']
    protocol_kwargs = {
        'processors': [decode_processor(), splitlines_processor()],
    }
    with run(
        command,
        protocol_class=StdOutCaptureProcessingGeneratorProtocol,
        protocol_kwargs=protocol_kwargs,
    ) as r:
        # Decoded (``str``) results are expected here; per the original
        # note, the non-generator protocols cannot be used for this.
        assert tuple(r) == ('abc\n', 'def\n', 'ghi')


def test_stdout_stderr_pipeline_protocol_simple():
    """Check processing and finalization of the stdout and stderr pipelines."""
    stdout_pipeline = [decode_processor(), splitlines_processor()]
    stderr_pipeline = [splitlines_processor()]
    protocol = StdOutErrCaptureProcessingGeneratorProtocol(
        stdout_processors=stdout_pipeline,
        stderr_processors=stderr_pipeline,
    )
    queue = protocol.result_queue

    # stdout: complete lines arrive right away, as decoded strings ...
    protocol.pipe_data_received(1, b'abc\ndef\nghi')
    assert tuple(queue) == ((1, 'abc\n'), (1, 'def\n'))
    queue.clear()

    # ... and the unterminated rest arrives when the pipeline is finalized.
    protocol.pipe_connection_lost(1, None)
    assert tuple(queue) == ((1, 'ghi'),)
    queue.clear()

    # stderr: same pattern, but the results stay ``bytes``.
    protocol.pipe_data_received(2, b'rst\nuvw\nxyz')
    assert tuple(queue) == ((2, b'rst\n'), (2, b'uvw\n'))
    queue.clear()

    # The stderr pipeline has to be finalized as well.
    protocol.pipe_connection_lost(2, None)
    assert tuple(queue) == ((2, b'xyz'),)


def test_stdout_stderr_pipeline_protocol():
    """Run a subprocess writing to stdout and stderr; check both streams."""
    script = (
        'import sys\n'
        'print("abc\\ndef\\nghi", end="")\n'
        'print("rst\\nuvw\\nxyz", end="", file=sys.stderr)\n'
    )
    protocol_kwargs = {
        'stdout_processors': [decode_processor(), splitlines_processor()],
        'stderr_processors': [splitlines_processor()],
    }
    with run(
        [sys.executable, '-u', '-c', script],
        protocol_class=StdOutErrCaptureProcessingGeneratorProtocol,
        protocol_kwargs=protocol_kwargs,
    ) as r:
        result = tuple(r)

    # Only the per-stream order is checked, not how the two streams
    # interleave.
    stdout_items = [item for fd, item in result if fd == 1]
    stderr_items = [item for fd, item in result if fd == 2]

    assert len(result) == 6
    assert ''.join(stdout_items) == 'abc\ndef\nghi'
    assert b''.join(stderr_items) == b'rst\nuvw\nxyz'

0 comments on commit 9720976

Please sign in to comment.