From 428e0336d9cd93696ceff2a1f780a84518e8b7d4 Mon Sep 17 00:00:00 2001 From: Mike Nerone Date: Wed, 17 Jan 2024 00:39:33 -0600 Subject: [PATCH] Explicit __aiter__ where __anext__ is called, b/c "iterables" aren't necessarily iterators --- slurry/environments/_threading.py | 3 ++- slurry/sections/_filters.py | 2 +- slurry/sections/weld.py | 2 +- tests/fixtures.py | 7 +++++++ tests/test_filters.py | 9 ++++++++- tests/test_threading.py | 12 ++++++++++-- 6 files changed, 29 insertions(+), 6 deletions(-) diff --git a/slurry/environments/_threading.py b/slurry/environments/_threading.py index f947710..0d84f66 100644 --- a/slurry/environments/_threading.py +++ b/slurry/environments/_threading.py @@ -28,9 +28,10 @@ def sync_input(): """Wrapper for turning an async iterable into a blocking generator.""" if input is None: return + input_aiter = input.__aiter__() try: while True: - yield trio.from_thread.run(input.__anext__) + yield trio.from_thread.run(input_aiter.__anext__) except StopAsyncIteration: pass diff --git a/slurry/sections/_filters.py b/slurry/sections/_filters.py index a592cca..ad20ea0 100644 --- a/slurry/sections/_filters.py +++ b/slurry/sections/_filters.py @@ -29,7 +29,7 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with aclosing(source) as aiter: + async with aclosing(source.__aiter__()) as aiter: try: for _ in range(self.count): await aiter.__anext__() diff --git a/slurry/sections/weld.py b/slurry/sections/weld.py index 784a6a9..c8934e2 100644 --- a/slurry/sections/weld.py +++ b/slurry/sections/weld.py @@ -21,7 +21,7 @@ async def pump(section, input: Optional[AsyncIterable[Any]], output: trio.Memory await section.pump(input, output.send) except trio.BrokenResourceError: pass - if input: + if input and hasattr(input, "aclose") and callable(input.aclose): await input.aclose() await output.aclose() diff --git a/tests/fixtures.py b/tests/fixtures.py index 63d8822..9d81da5 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -79,3 +79,10 @@ def fibonacci(self, i): def refine(self, input: Iterable[Any], output: Callable[[Any], None]): for i in range(self.i): output(self.fibonacci(i)) + +class AsyncNonIteratorIterable: + def __init__(self, source_aiterable): + self.source_aiterable = source_aiterable + + def __aiter__(self): + return self.source_aiterable.__aiter__() diff --git a/tests/test_filters.py b/tests/test_filters.py index b85d8ee..72ac907 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,7 +1,7 @@ from slurry import Pipeline from slurry.sections import Merge, RateLimit, Skip, SkipWhile, Filter, Changes -from .fixtures import produce_increasing_integers, produce_mappings +from .fixtures import AsyncNonIteratorIterable, produce_increasing_integers, produce_mappings async def test_skip(autojump_clock): async with Pipeline.create( @@ -10,6 +10,13 @@ async def test_skip(autojump_clock): result = [i async for i in aiter] assert result == [5, 6, 7, 8, 9] +async def test_skip_input_non_iterator_iterable(autojump_clock): + async with Pipeline.create( + Skip(5, AsyncNonIteratorIterable(produce_increasing_integers(1, max=10))) + ) as pipeline, pipeline.tap() as aiter: + result = [i async for i in aiter] + assert result == [5, 6, 7, 8, 9] + async def test_skip_short_stream(autojump_clock): async with Pipeline.create( Skip(5, produce_increasing_integers(1)) diff --git a/tests/test_threading.py b/tests/test_threading.py index 834f4fd..add6744 100644 --- a/tests/test_threading.py +++ b/tests/test_threading.py @@ -2,7 +2,7 @@ from slurry import Pipeline from slurry.sections import Map -from .fixtures import produce_increasing_integers, SyncSquares +from .fixtures import AsyncNonIteratorIterable, produce_increasing_integers, SyncSquares async def test_thread_section(autojump_clock): async with Pipeline.create( @@ -12,6 +12,14 @@ async def test_thread_section(autojump_clock): result = [i async for i in aiter] assert result == [0, 1, 4, 9, 16] +async def test_thread_section_input_non_iterator_iterable(autojump_clock): + async with Pipeline.create( + AsyncNonIteratorIterable(produce_increasing_integers(1, max=5)), + SyncSquares() + ) as pipeline, pipeline.tap() as aiter: + result = [i async for i in aiter] + assert result == [0, 1, 4, 9, 16] + async def test_thread_section_early_break(autojump_clock): async with Pipeline.create( produce_increasing_integers(1, max=5), @@ -39,4 +47,4 @@ async def test_thread_section_section_input(autojump_clock): SyncSquares() ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] - assert result == [0, 1, 4] \ No newline at end of file + assert result == [0, 1, 4]