diff --git a/docs/user.rst b/docs/user.rst index 7f40117..b07bf3e 100644 --- a/docs/user.rst +++ b/docs/user.rst @@ -38,9 +38,9 @@ pipeline out of individual sections. A ``PipelineSection`` is any object that is function. This currently includes the following types: AsyncIterables - Async iterables are valid only as the very first ``PipelineSection``. Subsequent - sections will use this async iterable as input source. Placing an ``AsyncIterable`` into the middle of - a sequence of pipeline sections, will cause a ``ValueError``. + Async iterables are valid only as the very first ``PipelineSection``, and must support the ``aclose()`` + method (nearly all do). Subsequent sections will use this async iterable as input source. Placing an + ``AsyncIterable`` into the middle of a sequence of pipeline sections, will cause a ``ValueError``. Sections Any :class:`Section ` abc subclass is a valid ``PipelineSection``, at any position in the pipeline. diff --git a/poetry.lock b/poetry.lock index 3b4ad73..0da312f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -25,17 +25,6 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} -[[package]] -name = "async-generator" -version = "1.10" -description = "Async generators and context managers for Python 3.5+" -optional = false -python-versions = ">=3.5" -files = [ - {file = "async_generator-1.10-py3-none-any.whl", hash = "sha256:01c7bf666359b4967d2cda0000cc2e4af16a0ae098cbffcb8472fb9e8ad6585b"}, - {file = "async_generator-1.10.tar.gz", hash = "sha256:6ebb3d106c12920aaae42ccb6f787ef5eefdcdd166ea3d628fa8476abe712144"}, -] - [[package]] name = "attrs" version = "23.2.0" @@ -895,4 +884,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "ca21f6d0cebe7f3af9c74277edf6de1541014df6f1ce1d89df8d8ed167e73d01" +content-hash = "438477d8aefee8e748be47f77e3fb979322e7773abe5308b8f3b72011f567625" diff --git a/pyproject.toml b/pyproject.toml index 2594236..b177a18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ classifiers = [ [tool.poetry.dependencies] -async-generator = "^1.10" python = "^3.8" trio = "^0.23.0" diff --git a/slurry/_pipeline.py b/slurry/_pipeline.py index a92bb57..3ad62fa 100644 --- a/slurry/_pipeline.py +++ b/slurry/_pipeline.py @@ -5,11 +5,11 @@ from contextlib import asynccontextmanager import trio -from async_generator import aclosing from .sections.weld import weld from ._tap import Tap from ._types import PipelineSection +from ._utils import aclosing class Pipeline: """The main Slurry ``Pipeline`` class. @@ -54,7 +54,7 @@ async def _pump(self): output = weld(nursery, *self.sections) # Output to taps - async with aclosing(output) as aiter: + async with aclosing(output.__aiter__()) as aiter: async for item in aiter: self._taps = set(filter(lambda tap: not tap.closed, self._taps)) if not self._taps: diff --git a/slurry/_types.py b/slurry/_types.py index 8b9a747..02b5c9a 100644 --- a/slurry/_types.py +++ b/slurry/_types.py @@ -1,10 +1,18 @@ -from typing import Any, AsyncIterable, Awaitable, Protocol, Tuple, Union, runtime_checkable +from typing import Any, AsyncIterable, Awaitable, Protocol, Tuple, TypeVar, Union, runtime_checkable -from .sections.abc import Section +from .sections import abc -PipelineSection = Union[AsyncIterable[Any], Section, Tuple["PipelineSection", ...]] +PipelineSection = Union["AsyncIterableWithAcloseableIterator[Any]", "abc.Section", Tuple["PipelineSection", ...]] + +_T_co = TypeVar("_T_co", covariant=True) @runtime_checkable class SupportsAclose(Protocol): - def aclose(self) -> Awaitable[object]: - ... + def aclose(self) -> Awaitable[object]: ... + +class AcloseableAsyncIterator(SupportsAclose, Protocol[_T_co]): + def __anext__(self) -> Awaitable[_T_co]: ... + def __aiter__(self) -> "AcloseableAsyncIterator[_T_co]": ... + +class AsyncIterableWithAcloseableIterator(Protocol[_T_co]): + def __aiter__(self) -> AcloseableAsyncIterator[_T_co]: ... diff --git a/slurry/_utils.py b/slurry/_utils.py index 204c565..9499ae0 100644 --- a/slurry/_utils.py +++ b/slurry/_utils.py @@ -1,15 +1,14 @@ -from typing import AsyncGenerator, AsyncIterator, TypeVar +from typing import AsyncGenerator, TypeVar -from ._types import SupportsAclose +from ._types import AcloseableAsyncIterator from contextlib import asynccontextmanager _T = TypeVar("_T") @asynccontextmanager -async def safe_aclosing(obj: AsyncIterator[_T]) -> AsyncGenerator[AsyncIterator[_T], None]: +async def aclosing(obj: AcloseableAsyncIterator[_T]) -> AsyncGenerator[AcloseableAsyncIterator[_T], None]: try: yield obj finally: - if isinstance(obj, SupportsAclose): - await obj.aclose() + await obj.aclose() diff --git a/slurry/environments/_multiprocessing.py b/slurry/environments/_multiprocessing.py index 2769444..50c1f6d 100644 --- a/slurry/environments/_multiprocessing.py +++ b/slurry/environments/_multiprocessing.py @@ -1,11 +1,12 @@ """Implements a section that runs in an independent python proces.""" from multiprocessing import Process, SimpleQueue -from typing import AsyncIterable, Any, Awaitable, Callable, Optional, cast +from typing import Any, Awaitable, Callable, Optional, cast import trio from ..sections.abc import SyncSection +from .._types import AsyncIterableWithAcloseableIterator class ProcessSection(SyncSection): """ProcessSection defines a section interface with a synchronous @@ -19,7 +20,9 @@ class ProcessSection(SyncSection): `_. """ - async def pump(self, input: Optional[AsyncIterable[Any]], output: Callable[[Any], Awaitable[None]]): + async def pump( + self, input: Optional[AsyncIterableWithAcloseableIterator[Any]], output: Callable[[Any], Awaitable[None]] + ): """ The ``ProcessSection`` pump method works similar to the threaded version, however since communication between processes is not as simple as it is between threads, diff --git a/slurry/environments/_threading.py b/slurry/environments/_threading.py index 0d84f66..32c1f2d 100644 --- a/slurry/environments/_threading.py +++ b/slurry/environments/_threading.py @@ -1,9 +1,10 @@ """The threading module implements a synchronous section that runs in a background thread.""" -from typing import Any, AsyncIterable, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable, Optional import trio from ..sections.abc import SyncSection +from .._types import AsyncIterableWithAcloseableIterator class ThreadSection(SyncSection): @@ -12,7 +13,7 @@ class ThreadSection(SyncSection): """ async def pump(self, - input: Optional[AsyncIterable[Any]], + input: Optional[AsyncIterableWithAcloseableIterator[Any]], output: Callable[[Any], Awaitable[None]]): """Runs the refine method in a background thread with synchronous input and output wrappers, which transparently bridges the input and outputs between the parent diff --git a/slurry/environments/_trio.py b/slurry/environments/_trio.py index dafba32..d79b4ff 100644 --- a/slurry/environments/_trio.py +++ b/slurry/environments/_trio.py @@ -1,13 +1,16 @@ """The Trio environment implements ``TrioSection``, which is a Trio-native :class:`AsyncSection `.""" -from typing import Any, AsyncIterable, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable, Optional from ..sections.abc import AsyncSection +from .._types import AsyncIterableWithAcloseableIterator class TrioSection(AsyncSection): """Since Trio is the native Slurry event loop, this environment is simple to implement. The pump method does not need to do anything special to bridge the input and output. It simply delegates directly to the refine method, as the api is identical.""" - async def pump(self, input: Optional[AsyncIterable[Any]], output: Callable[[Any], Awaitable[None]]): + async def pump( + self, input: Optional[AsyncIterableWithAcloseableIterator[Any]], output: Callable[[Any], Awaitable[None]] + ): """Calls refine.""" await self.refine(input, output) diff --git a/slurry/sections/_buffers.py b/slurry/sections/_buffers.py index a243802..038658e 100644 --- a/slurry/sections/_buffers.py +++ b/slurry/sections/_buffers.py @@ -1,12 +1,13 @@ """Pipeline sections with age- and volume-based buffers.""" from collections import deque import math -from typing import Any, AsyncIterable, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence import trio -from async_generator import aclosing from ..environments import TrioSection +from .._types import AsyncIterableWithAcloseableIterator +from .._utils import aclosing class Window(TrioSection): """Window buffer with size and age limits. @@ -26,13 +27,13 @@ class Window(TrioSection): :param max_size: The maximum buffer size. :type max_size: int :param source: Input when used as first section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] :param max_age: Maximum item age in seconds. (default: unlimited) :type max_age: float :param min_size: Minimum amount of items in the buffer to trigger an output. :type min_size: int """ - def __init__(self, max_size: int, source: Optional[AsyncIterable[Any]] = None, *, + def __init__(self, max_size: int, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None, *, max_age: float = math.inf, min_size: int = 1): super().__init__() @@ -51,7 +52,7 @@ async def refine(self, input, output): buf = deque() - async with aclosing(source) as aiter: + async with aclosing(source.__aiter__()) as aiter: async for item in aiter: now = trio.current_time() buf.append((item, now)) @@ -80,7 +81,7 @@ class Group(TrioSection): :param interval: Time in seconds from when an item arrives until the buffer is sent. :type interval: float :param source: Input when used as first section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] :param max_size: Maximum number of items in buffer, which when reached, will cause the buffer to be sent. :type max_size: int @@ -89,7 +90,7 @@ class Group(TrioSection): :param reducer: Optional reducer function used to transform the buffer to a single value. :type reducer: Optional[Callable[[Sequence[Any]], Any]] """ - def __init__(self, interval: float, source: Optional[AsyncIterable[Any]] = None, *, + def __init__(self, interval: float, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None, *, max_size: Optional[int] = None, mapper: Optional[Callable[[Any], Any]] = None, reducer: Optional[Callable[[Sequence[Any]], Any]] = None): @@ -111,7 +112,7 @@ async def refine(self, input, output): send_channel, receive_channel = trio.open_memory_channel(0) async def pull_task(): - async with send_channel, aclosing(source) as aiter: + async with send_channel, aclosing(source.__aiter__()) as aiter: async for item in aiter: await send_channel.send(item) nursery.start_soon(pull_task) @@ -152,9 +153,9 @@ class Delay(TrioSection): :param interval: Number of seconds that each item is delayed. :type interval: float :param source: Input when used as first section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] """ - def __init__(self, interval: float, source: Optional[AsyncIterable[Any]] = None): + def __init__(self, interval: float, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None): super().__init__() self.source = source self.interval = interval @@ -169,7 +170,7 @@ async def refine(self, input, output): buffer_input_channel, buffer_output_channel = trio.open_memory_channel(math.inf) async def pull_task(): - async with buffer_input_channel, aclosing(source) as aiter: + async with buffer_input_channel, aclosing(source.__aiter__()) as aiter: async for item in aiter: await buffer_input_channel.send((item, trio.current_time() + self.interval)) diff --git a/slurry/sections/_combiners.py b/slurry/sections/_combiners.py index 66dad9d..971dab1 100644 --- a/slurry/sections/_combiners.py +++ b/slurry/sections/_combiners.py @@ -3,11 +3,11 @@ import itertools import trio -from async_generator import aclosing from ..environments import TrioSection from .weld import weld from .._types import PipelineSection +from .._utils import aclosing class Chain(TrioSection): """Chains input from one or more sources. Any valid ``PipelineSection`` is an allowed source. @@ -40,7 +40,7 @@ async def refine(self, input, output): sources = self.sources async with trio.open_nursery() as nursery: for source in sources: - async with aclosing(weld(nursery, source)) as agen: + async with aclosing(weld(nursery, source).__aiter__()) as agen: async for item in agen: await output(item) @@ -66,7 +66,7 @@ async def refine(self, input, output): async with trio.open_nursery() as nursery: async def pull_task(source): - async with aclosing(weld(nursery, source)) as aiter: + async with aclosing(weld(nursery, source).__aiter__()) as aiter: async for item in aiter: await output(item) @@ -155,7 +155,8 @@ class ZipLatest(TrioSection): default value to output, until an input has arrived on a source. Defaults to ``None``. :type default: Any :param monitor: Additional asynchronous sequences to monitor. - :type monitor: Optional[Union[AsyncIterable[Any], Sequence[AsyncIterable[Any]]]] + :type monitor: Optional[Union[AsyncIterableWithAcloseableIterator[Any], + Sequence[AsyncIterableWithAcloseableIterator[Any]]]] :param place_input: Position of the pipeline input source in the output tuple. Options: ``'first'`` (default)|``'last'`` :type place_input: string @@ -203,7 +204,7 @@ async def refine(self, input, output): async with trio.open_nursery() as nursery: async def pull_task(index, source, monitor=False): - async with aclosing(weld(nursery, source)) as aiter: + async with aclosing(weld(nursery, source).__aiter__()) as aiter: async for item in aiter: results[index] = item ready[index] = True diff --git a/slurry/sections/_filters.py b/slurry/sections/_filters.py index ad20ea0..822b83e 100644 --- a/slurry/sections/_filters.py +++ b/slurry/sections/_filters.py @@ -1,10 +1,11 @@ """Pipeline sections that filters the incoming items.""" -from typing import Any, AsyncIterable, Callable, Hashable, Optional, Union +from typing import Any, Callable, Hashable, Optional, Union -from async_generator import aclosing import trio from ..environments import TrioSection +from .._types import AsyncIterableWithAcloseableIterator +from .._utils import aclosing class Skip(TrioSection): """Skips the first ``count`` items in an asynchronous sequence. @@ -14,9 +15,9 @@ class Skip(TrioSection): :param count: Number of items to skip :type count: int :param source: Input source if starting section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] """ - def __init__(self, count: int, source: Optional[AsyncIterable[Any]] = None): + def __init__(self, count: int, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None): super().__init__() self.count = count self.source = source @@ -49,9 +50,9 @@ class SkipWhile(TrioSection): :param pred: Predicate function. :type pred: Callable[[Any], bool] :param source: Input source if starting section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] """ - def __init__(self, pred, source: Optional[AsyncIterable[Any]] = None): + def __init__(self, pred, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None): super().__init__() self.pred = pred self.source = source @@ -64,7 +65,7 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with aclosing(source) as aiter: + async with aclosing(source.__aiter__()) as aiter: async for item in aiter: if not self.pred(item): await output(item) @@ -83,9 +84,9 @@ class Filter(TrioSection): :param func: Matching function. :type func: Callable[[Any], bool] :param source: Source if used as a starting section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] """ - def __init__(self, func, source: Optional[AsyncIterable[Any]] = None): + def __init__(self, func, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None): super().__init__() self.func = func self.source = source @@ -98,7 +99,7 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with aclosing(source) as aiter: + async with aclosing(source.__aiter__()) as aiter: async for item in aiter: if self.func(item): await output(item) @@ -117,9 +118,9 @@ class Changes(TrioSection): Items are compared using the != operator. :param source: Source if used as a starting section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] """ - def __init__(self, source: Optional[AsyncIterable[Any]] = None): + def __init__(self, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None): super().__init__() self.source = source @@ -133,7 +134,7 @@ async def refine(self, input, output): token = object() last = token - async with aclosing(source) as aiter: + async with aclosing(source.__aiter__()) as aiter: async for item in aiter: if last is token or item != last: last = item @@ -155,13 +156,13 @@ class RateLimit(TrioSection): :param interval: Minimum number of seconds between each sent item. :type interval: float :param source: Input when used as first section. - :type source: Optional[AsyncIterable[Any]] + :type source: Optional[AsyncIterableWithAcloseableIterator[Any]] :param subject: Subject for per subject rate limiting. :type subject: Optional[] """ def __init__(self, interval, - source: Optional[AsyncIterable[Any]] = None, + source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None, *, subject: Optional[Union[Hashable, Callable[[Any], Hashable]]] = None): super().__init__() @@ -185,7 +186,7 @@ async def refine(self, input, output): get_subject = lambda item: item[self.subject] timestamps = {} - async with aclosing(source) as aiter: + async with aclosing(source.__aiter__()) as aiter: async for item in aiter: now = trio.current_time() subject = get_subject(item) diff --git a/slurry/sections/_producers.py b/slurry/sections/_producers.py index 098c4ae..e161033 100644 --- a/slurry/sections/_producers.py +++ b/slurry/sections/_producers.py @@ -1,11 +1,11 @@ """Pipeline sections that produce data streams.""" from time import time -from typing import Any, AsyncIterable, cast +from typing import Any import trio -from async_generator import aclosing from ..environments import TrioSection +from .._utils import aclosing class Repeat(TrioSection): """Yields a single item repeatedly at regular intervals. @@ -54,7 +54,7 @@ async def repeater(item, *, task_status=trio.TASK_STATUS_IGNORED): running_repeater = await nursery.start(repeater, self.default) if input: - async with aclosing(input) as aiter: + async with aclosing(input.__aiter__()) as aiter: async for item in aiter: if running_repeater: running_repeater.cancel() @@ -66,7 +66,7 @@ class Metronome(TrioSection): If used as a middle section, the input can be used to set the value that is sent. When an input is received, it is stored and send at the next tick of the clock. If multiple - inputs are received during a tick, only the latest is sent. The preceeding inputs are + inputs are received during a tick, only the latest is sent. The preceding inputs are dropped. When an input is used, closure of the input stream will cause the metronome to close as well. diff --git a/slurry/sections/_refiners.py b/slurry/sections/_refiners.py index a464a89..9fef230 100644 --- a/slurry/sections/_refiners.py +++ b/slurry/sections/_refiners.py @@ -1,8 +1,9 @@ """Sections for transforming an input into a different output.""" -from typing import Any, AsyncIterable, Optional +from typing import Any, Optional from ..environments import TrioSection -from .._utils import safe_aclosing +from .._types import AsyncIterableWithAcloseableIterator +from .._utils import aclosing class Map(TrioSection): """Maps over an asynchronous sequence. @@ -14,7 +15,7 @@ class Map(TrioSection): :param source: Source if used as a starting section. :type source: Optional[AsyncIterable[Any]] """ - def __init__(self, func, source: Optional[AsyncIterable[Any]] = None): + def __init__(self, func, source: Optional[AsyncIterableWithAcloseableIterator[Any]] = None): self.func = func self.source = source @@ -26,6 +27,6 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with safe_aclosing(source.__aiter__()) as aiter: + async with aclosing(source.__aiter__()) as aiter: async for item in aiter: await output(self.func(item)) diff --git a/slurry/sections/abc.py b/slurry/sections/abc.py index 52ec016..3ef065e 100644 --- a/slurry/sections/abc.py +++ b/slurry/sections/abc.py @@ -1,12 +1,16 @@ """ Abstract Base Classes for building pipeline sections. """ from abc import ABC, abstractmethod -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Optional +from typing import Any, Awaitable, Callable, Iterable, Optional + +from .._types import AsyncIterableWithAcloseableIterator class Section(ABC): """Defines the basic environment api.""" @abstractmethod - async def pump(self, input: Optional[AsyncIterable[Any]], output: Callable[[Any], Awaitable[None]]): + async def pump( + self, input: Optional[AsyncIterableWithAcloseableIterator[Any]], output: Callable[[Any], Awaitable[None]] + ): """The pump method contains the machinery that takes input from previous sections, or any asynchronous iterable, processes it and pushes it to the output. @@ -16,7 +20,7 @@ async def pump(self, input: Optional[AsyncIterable[Any]], output: Callable[[Any] :param input: The input data feed. Will be ``None`` for the first ``Section``, as the first ``Section`` is expected to supply it's own input. - :type input: Optional[AsyncIterable[Any]] + :type input: Optional[AsyncIterableWithAcloseableIterator[Any]] :param output: An awaitable callable used to send output. :type output: Callable[[Any], Awaitable[None]] """ @@ -25,13 +29,15 @@ class AsyncSection(Section): """AsyncSection defines an abc for sections that are designed to run in an async event loop.""" @abstractmethod - async def refine(self, input: Optional[AsyncIterable[Any]], output: Callable[[Any], Awaitable[None]]): + async def refine( + self, input: Optional[AsyncIterableWithAcloseableIterator[Any]], output: Callable[[Any], Awaitable[None]] + ): """The async section refine method must contain the logic that iterates the input, processes the indidual items, and feeds results to the output. :param input: The input data feed. Will be ``None`` for the first ``Section``, as the first ``Section`` is expected to supply it's own input. - :type input: Optional[AsyncIterable[Any]] + :type input: Optional[AsyncIterableWithAcloseableIterator[Any]] :param output: An awaitable callable used to send output. :type output: Callable[[Any], Awaitable[None]] """ diff --git a/slurry/sections/weld.py b/slurry/sections/weld.py index f357a68..1b05fa2 100644 --- a/slurry/sections/weld.py +++ b/slurry/sections/weld.py @@ -1,13 +1,13 @@ """Contains the `weld` utility function for composing sections.""" -from typing import Any, AsyncIterable, Optional, cast +from typing import Any, Optional, cast import trio from .abc import Section -from .._types import PipelineSection, SupportsAclose +from .._types import PipelineSection, AsyncIterableWithAcloseableIterator, SupportsAclose -def weld(nursery, *sections: PipelineSection) -> AsyncIterable[Any]: +def weld(nursery, *sections: PipelineSection) -> AsyncIterableWithAcloseableIterator[Any]: """ Connects the individual parts of a sequence of pipeline sections together and starts pumps for individual Sections. It returns an async iterable which yields results of the sequence. @@ -17,7 +17,7 @@ def weld(nursery, *sections: PipelineSection) -> AsyncIterable[Any]: :param PipelineSection \\*sections: Pipeline sections. """ - async def pump(section, input: Optional[AsyncIterable[Any]], output: trio.MemorySendChannel): + async def pump(section, input: Optional[AsyncIterableWithAcloseableIterator[Any]], output: trio.MemorySendChannel): try: await section.pump(input, output.send) except trio.BrokenResourceError: @@ -43,4 +43,4 @@ async def pump(section, input: Optional[AsyncIterable[Any]], output: trio.Memory output = section section_input = output - return cast(AsyncIterable[Any], output) + return cast(AsyncIterableWithAcloseableIterator[Any], output) diff --git a/tests/fixtures.py b/tests/fixtures.py index 9b572a4..9d81da5 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -86,18 +86,3 @@ def __init__(self, source_aiterable): def __aiter__(self): return self.source_aiterable.__aiter__() - -class AsyncIteratorWithoutAclose: - def __init__(self, source_aiterable): - self.source_aiter = source_aiterable.__aiter__() - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return await self.source_aiter.__anext__() - except StopAsyncIteration: - if hasattr(self.source_aiter, "aclose"): - await self.source_aiter.aclose() - raise diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 38c58a4..0fdc5cb 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -5,7 +5,7 @@ from slurry.sections import Map from slurry.environments import TrioSection -from .fixtures import AsyncIteratorWithoutAclose, produce_increasing_integers +from .fixtures import produce_increasing_integers async def test_pipeline_create(autojump_clock): async with Pipeline.create(None): @@ -47,21 +47,15 @@ async def refine(self, input, output): assert isinstance(i, int) break -async def perform_welding(source): +async def test_welding(autojump_clock): async with Pipeline.create( - source, - (Map(lambda i: i+1),) + produce_increasing_integers(1), + (Map(lambda i: i+1),) ) as pipeline: async with pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [1, 2, 3] -async def test_welding(autojump_clock): - await perform_welding(produce_increasing_integers(1)) - -async def test_welding_with_source_no_aclose(autojump_clock): - await perform_welding(AsyncIteratorWithoutAclose(produce_increasing_integers(1))) - async def test_welding_two_generator_functions_not_allowed(autojump_clock): with pytest.raises(ValueError): async with Pipeline.create(