diff --git a/src/timeout_executor/executor.py b/src/timeout_executor/executor.py index 7a2c97b..6d182da 100644 --- a/src/timeout_executor/executor.py +++ b/src/timeout_executor/executor.py @@ -16,7 +16,7 @@ import anyio import cloudpickle -from typing_extensions import ParamSpec, Self, TypeVar, override +from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, override from timeout_executor.const import SUBPROCESS_COMMAND, TIMEOUT_EXECUTOR_INPUT_FILE from timeout_executor.logging import logger @@ -25,7 +25,7 @@ from timeout_executor.types import Callback, CallbackArgs, ExecutorArgs, ProcessCallback if TYPE_CHECKING: - from collections.abc import Coroutine, Iterable + from collections.abc import Awaitable, Coroutine, Iterable from timeout_executor.main import TimeoutExecutor @@ -35,6 +35,7 @@ T = TypeVar("T", infer_variance=True) P2 = ParamSpec("P2") T2 = TypeVar("T2", infer_variance=True) +AnyAwaitable: TypeAlias = "Awaitable[T] | Coroutine[Any, Any, T]" class Executor(Callback[P, T], Generic[P, T]): @@ -179,7 +180,7 @@ def remove_callback(self, callback: ProcessCallback[P, T]) -> Self: @overload def apply_func( timeout_or_executor: float | TimeoutExecutor, - func: Callable[P2, Coroutine[Any, Any, T2]], + func: Callable[P2, AnyAwaitable[T2]], *args: P2.args, **kwargs: P2.kwargs, ) -> AsyncResult[P2, T2]: ... @@ -223,7 +224,7 @@ def apply_func( @overload async def delay_func( timeout_or_executor: float | TimeoutExecutor, - func: Callable[P2, Coroutine[Any, Any, T2]], + func: Callable[P2, AnyAwaitable[T2]], *args: P2.args, **kwargs: P2.kwargs, ) -> AsyncResult[P2, T2]: ... diff --git a/src/timeout_executor/main.py b/src/timeout_executor/main.py index 1af4c01..1d7a6b3 100644 --- a/src/timeout_executor/main.py +++ b/src/timeout_executor/main.py @@ -4,13 +4,13 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, Callable, Generic, overload -from typing_extensions import ParamSpec, Self, TypeVar, override +from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, override from timeout_executor.executor import apply_func, delay_func from timeout_executor.types import Callback, ProcessCallback if TYPE_CHECKING: - from collections.abc import Coroutine, Iterable + from collections.abc import Awaitable, Coroutine, Iterable from timeout_executor.result import AsyncResult @@ -19,6 +19,7 @@ P = ParamSpec("P") T = TypeVar("T", infer_variance=True) AnyT = TypeVar("AnyT", infer_variance=True, default=Any) +AnyAwaitable: TypeAlias = "Awaitable[T] | Coroutine[Any, Any, T]" class TimeoutExecutor(Callback[Any, AnyT], Generic[AnyT]): @@ -35,10 +36,7 @@ def timeout(self) -> float: @overload def apply( - self, - func: Callable[P, Coroutine[Any, Any, T]], - *args: P.args, - **kwargs: P.kwargs, + self, func: Callable[P, AnyAwaitable[T]], *args: P.args, **kwargs: P.kwargs ) -> AsyncResult[P, T]: ... @overload def apply( @@ -61,10 +59,7 @@ def apply( @overload async def delay( - self, - func: Callable[P, Coroutine[Any, Any, T]], - *args: P.args, - **kwargs: P.kwargs, + self, func: Callable[P, AnyAwaitable[T]], *args: P.args, **kwargs: P.kwargs ) -> AsyncResult[P, T]: ... @overload async def delay( @@ -87,10 +82,7 @@ async def delay( @overload async def apply_async( - self, - func: Callable[P, Coroutine[Any, Any, T]], - *args: P.args, - **kwargs: P.kwargs, + self, func: Callable[P, AnyAwaitable[T]], *args: P.args, **kwargs: P.kwargs ) -> AsyncResult[P, T]: ... @overload async def apply_async( diff --git a/src/timeout_executor/subprocess.py b/src/timeout_executor/subprocess.py index 2c49076..0b45e05 100644 --- a/src/timeout_executor/subprocess.py +++ b/src/timeout_executor/subprocess.py @@ -9,18 +9,19 @@ import anyio import cloudpickle -from typing_extensions import ParamSpec, TypeVar +from typing_extensions import ParamSpec, TypeAlias, TypeVar from timeout_executor.const import TIMEOUT_EXECUTOR_INPUT_FILE from timeout_executor.serde import dumps_error if TYPE_CHECKING: - from collections.abc import Coroutine + from collections.abc import Awaitable, Coroutine __all__ = [] P = ParamSpec("P") T = TypeVar("T", infer_variance=True) +AnyAwaitable: TypeAlias = "Awaitable[T] | Coroutine[Any, Any, T]" def run_in_subprocess() -> None: @@ -82,14 +83,12 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> T: def _output_to_file_async( file: Path | anyio.Path, -) -> Callable[ - [Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]] -]: +) -> Callable[[Callable[P, AnyAwaitable[T]]], Callable[P, Coroutine[Any, Any, T]]]: if isinstance(file, Path): file = anyio.Path(file) def wrapper( - func: Callable[P, Coroutine[Any, Any, T]], + func: Callable[P, AnyAwaitable[T]], ) -> Callable[P, Coroutine[Any, Any, T]]: async def inner(*args: P.args, **kwargs: P.kwargs) -> T: dump = b""