From 60ff030301755e02a9e9f7d748ba5e7885025abd Mon Sep 17 00:00:00 2001 From: Marco Sirabella Date: Mon, 20 May 2024 09:31:34 -0400 Subject: [PATCH] Accept AsyncIterables being passed to Response Fixes pallets/flask#5322 --- src/quart/wrappers/response.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/quart/wrappers/response.py b/src/quart/wrappers/response.py index 065eabf..a63bcdf 100644 --- a/src/quart/wrappers/response.py +++ b/src/quart/wrappers/response.py @@ -102,19 +102,24 @@ async def __anext__(self) -> bytes: class IterableBody(ResponseBody): - def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None: + def __init__(self, iterable: AsyncIterable[bytes] | Iterable) -> None: self.iter: AsyncGenerator[bytes, None] if isasyncgen(iterable): self.iter = iterable elif isgenerator(iterable): self.iter = run_sync_iterable(iterable) - else: - + elif isinstance(iterable, AsyncIterable): async def _aiter() -> AsyncGenerator[bytes, None]: - for data in iterable: # type: ignore + async for data in iterable: yield data - self.iter = _aiter() + elif isinstance(iterable, Iterable): + async def _aiter() -> AsyncGenerator[bytes, None]: + for data in iterable: + yield data + self.iter = _aiter() + else: + raise ValueError("unreachable?") async def __aenter__(self) -> IterableBody: return self @@ -262,7 +267,7 @@ class Response(SansIOResponse): def __init__( self, - response: ResponseBody | AnyStr | Iterable | None = None, + response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None, status: int | None = None, headers: dict | Headers | None = None, mimetype: str | None = None,