Skip to content

Commit

Permalink
feat: Update type annotations for Python 3.9+ (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Jan 23, 2025
1 parent 86ae705 commit 131f1f7
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 76 deletions.
1 change: 1 addition & 0 deletions changes/73.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update and modernize the type annotations of the `taskgroup` module for Python 3.9 or later
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ underlines = ["-", "", ""]
[tool.ruff]
line-length = 88
src = ["src"]
target-version = "py39"
preview = true

[tool.ruff.lint]
ignore = ["E203","E731","E501"]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
Expand All @@ -58,16 +63,13 @@ select = [
# "C", # flake8-comprehensions
# "B", # flake8-bugbear
]
ignore = ["E203","E731","E501"]
preview = true
target-version = "py39"

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["aiotools"]
known-local-folder = ["src"]
split-on-trailing-comma = true

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"src/aiotools/taskgroup/__init__.py" = ["F405"]

[tool.black]
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ universal = false
testpaths = tests

[mypy]
mypy_path = src
mypy_path = src:tests

[mypy-pytest.*]
ignore_missing_imports = true
64 changes: 39 additions & 25 deletions src/aiotools/taskgroup/persistent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import itertools
import logging
Expand All @@ -13,10 +15,11 @@
Coroutine,
Optional,
Sequence,
Type,
Union,
TypeVar,
)

from typing_extensions import Self

from .. import compat
from .types import AsyncExceptionHandler

Expand All @@ -31,20 +34,26 @@
_log = logging.getLogger(__name__)
_all_ptaskgroups: weakref.WeakSet["PersistentTaskGroup"] = weakref.WeakSet()

T_co = TypeVar("T_co", covariant=True)

async def _default_exc_handler(exc_type, exc_obj, exc_tb) -> None:

async def _default_exc_handler(
exc_type: type[BaseException],
exc_obj: BaseException,
exc_tb: TracebackType,
) -> None:
traceback.print_exc()


class PersistentTaskGroup:
_base_error: Optional[BaseException]
_exc_handler: AsyncExceptionHandler
_tasks: set[asyncio.Task]
_on_completed_fut: Optional[asyncio.Future]
_current_taskgroup_token: Optional[Token["PersistentTaskGroup"]]
_tasks: set[asyncio.Task[Any]]
_on_completed_fut: Optional[asyncio.Future[Any]]
_current_taskgroup_token: Optional[Token[PersistentTaskGroup]]

@classmethod
def all_ptaskgroups(cls) -> Sequence["PersistentTaskGroup"]:
def all_ptaskgroups(cls) -> Sequence[PersistentTaskGroup]:
return list(_all_ptaskgroups)

def __init__(
Expand Down Expand Up @@ -78,10 +87,10 @@ def get_name(self) -> str:

def create_task(
self,
coro: Coroutine[Any, Any, Any],
coro: Coroutine[Any, Any, T_co] | Awaitable[T_co],
*,
name: Optional[str] = None,
) -> Awaitable[Any]:
) -> asyncio.Future[T_co]:
if not self._entered:
# When used as object attribute, auto-enter.
self._entered = True
Expand All @@ -91,11 +100,11 @@ def create_task(

def _create_task_with_name(
self,
coro: Coroutine[Any, Any, Any],
coro: Coroutine[Any, Any, T_co] | Awaitable[T_co],
*,
name: Optional[str] = None,
cb: Callable[[asyncio.Task], Any],
) -> Awaitable[Any]:
cb: Callable[[asyncio.Task[Any]], None],
) -> asyncio.Future[T_co]:
loop = compat.get_running_loop()
result_future = loop.create_future()
child_task = loop.create_task(
Expand Down Expand Up @@ -143,9 +152,9 @@ async def shutdown(self) -> None:

async def _task_wrapper(
self,
coro: Coroutine,
result_future: weakref.ref[asyncio.Future],
) -> Any:
coro: Coroutine[Any, Any, T_co] | Awaitable[T_co],
result_future: weakref.ref[asyncio.Future[T_co]],
) -> T_co | None:
loop = compat.get_running_loop()
task = compat.current_task()
fut = result_future()
Expand All @@ -168,7 +177,11 @@ async def _task_wrapper(
try:
if fut is not None:
fut.set_exception(e)
await self._exc_handler(*sys.exc_info())
exc_info = sys.exc_info()
assert exc_info[0] is not None
assert exc_info[1] is not None
assert exc_info[2] is not None
await self._exc_handler(*exc_info)
except Exception as exc:
# If there are exceptions inside the exception handler
# we report it as soon as possible using the event loop's
Expand All @@ -182,10 +195,11 @@ async def _task_wrapper(
"exception": exc,
"task": task,
})
return None
finally:
del fut

def _on_task_done(self, task: asyncio.Task) -> None:
def _on_task_done(self, task: asyncio.Task[Any]) -> None:
try:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
Expand Down Expand Up @@ -213,23 +227,23 @@ def _on_task_done(self, task: asyncio.Task) -> None:
finally:
self._tasks.discard(task)

async def __aenter__(self) -> "PersistentTaskGroup":
async def __aenter__(self) -> Self:
self._parent_task = compat.current_task()
self._current_taskgroup_token = current_ptaskgroup.set(self)
self._entered = True
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
assert self._parent_task is not None
self._exiting = True
propagate_cancellation_error: Optional[
Union[Type[BaseException], BaseException]
] = None
propagate_cancellation_error: Optional[type[BaseException] | BaseException] = (
None
)

if (
exc_val is not None
Expand Down
84 changes: 44 additions & 40 deletions src/aiotools/taskgroup/persistent_compat.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from __future__ import annotations

import asyncio
import itertools
import logging
import sys
import traceback

try:
from contextvars import ContextVar, Token

has_contextvars = True
except ImportError:
has_contextvars = False
import weakref
from contextvars import ContextVar, Token
from types import TracebackType
from typing import (
Any,
Expand All @@ -19,41 +15,46 @@
Coroutine,
Optional,
Sequence,
Type,
TypeVar,
)

from typing_extensions import Self

from .. import compat
from .common import create_task_with_name, patch_task
from .types import AsyncExceptionHandler

__all__ = [
"PersistentTaskGroup",
"current_ptaskgroup",
]

if has_contextvars:
current_ptaskgroup: ContextVar["PersistentTaskGroup"] = ContextVar(
"current_ptaskgroup"
)
__all__.append("current_ptaskgroup")
current_ptaskgroup: ContextVar["PersistentTaskGroup"] = ContextVar("current_ptaskgroup")

_ptaskgroup_idx = itertools.count()
_log = logging.getLogger(__name__)
_all_ptaskgroups: "weakref.WeakSet[PersistentTaskGroup]" = weakref.WeakSet()
_all_ptaskgroups: weakref.WeakSet[PersistentTaskGroup] = weakref.WeakSet()

T_co = TypeVar("T_co", covariant=True)


async def _default_exc_handler(exc_type, exc_obj, exc_tb) -> None:
async def _default_exc_handler(
exc_type: type[BaseException],
exc_obj: BaseException,
exc_tb: TracebackType,
) -> None:
traceback.print_exc()


class PersistentTaskGroup:
_base_error: Optional[BaseException]
_exc_handler: AsyncExceptionHandler
_tasks: set[asyncio.Task]
_on_completed_fut: Optional[asyncio.Future]
_current_taskgroup_token: Optional["Token[PersistentTaskGroup]"]
_tasks: set[asyncio.Task[Any]]
_on_completed_fut: Optional[asyncio.Future[Any]]
_current_taskgroup_token: Optional[Token[PersistentTaskGroup]]

@classmethod
def all_ptaskgroups(cls) -> Sequence["PersistentTaskGroup"]:
def all_ptaskgroups(cls) -> Sequence[PersistentTaskGroup]:
return list(_all_ptaskgroups)

def __init__(
Expand Down Expand Up @@ -87,10 +88,10 @@ def get_name(self) -> str:

def create_task(
self,
coro: Coroutine[Any, Any, Any],
coro: Coroutine[Any, Any, T_co] | Awaitable[T_co],
*,
name: Optional[str] = None,
) -> Awaitable[Any]:
) -> asyncio.Future[T_co]:
if not self._entered:
# When used as object attribute, auto-enter.
self._entered = True
Expand All @@ -100,11 +101,11 @@ def create_task(

def _create_task_with_name(
self,
coro: Coroutine[Any, Any, Any],
coro: Coroutine[Any, Any, T_co] | Awaitable[T_co],
*,
name: Optional[str] = None,
cb: Callable[[asyncio.Task], Any],
) -> Awaitable[Any]:
cb: Callable[[asyncio.Task[None]], Any],
) -> asyncio.Future[T_co]:
loop = compat.get_running_loop()
result_future = loop.create_future()
child_task = create_task_with_name(
Expand Down Expand Up @@ -152,9 +153,9 @@ async def shutdown(self) -> None:

async def _task_wrapper(
self,
coro: Coroutine,
result_future: "weakref.ref[asyncio.Future]",
) -> Any:
coro: Coroutine[Any, Any, T_co] | Awaitable[T_co],
result_future: weakref.ref[asyncio.Future[T_co]],
) -> T_co | None:
loop = compat.get_running_loop()
task = compat.current_task()
fut = result_future()
Expand All @@ -177,7 +178,11 @@ async def _task_wrapper(
try:
if fut is not None:
fut.set_exception(e)
await self._exc_handler(*sys.exc_info())
exc_info = sys.exc_info()
assert exc_info[0] is not None
assert exc_info[1] is not None
assert exc_info[2] is not None
await self._exc_handler(*exc_info)
except Exception as exc:
# If there are exceptions inside the exception handler
# we report it as soon as possible using the event loop's
Expand All @@ -191,10 +196,11 @@ async def _task_wrapper(
"exception": exc,
"task": task,
})
return None
finally:
del fut

def _on_task_done(self, task: asyncio.Task) -> None:
def _on_task_done(self, task: asyncio.Task[None]) -> None:
try:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
Expand Down Expand Up @@ -222,20 +228,19 @@ def _on_task_done(self, task: asyncio.Task) -> None:
finally:
self._tasks.discard(task)

async def __aenter__(self) -> "PersistentTaskGroup":
async def __aenter__(self) -> Self:
self._parent_task = compat.current_task()
patch_task(self._parent_task)
if has_contextvars:
self._current_taskgroup_token = current_ptaskgroup.set(self)
self._current_taskgroup_token = current_ptaskgroup.set(self)
self._entered = True
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
self._exiting = True
propagate_cancelation = False

Expand All @@ -262,10 +267,9 @@ async def __aexit__(
prop_ex = await self._wait_completion()
if prop_ex is not None:
propagate_cancelation = prop_ex
if has_contextvars:
if self._current_taskgroup_token:
current_ptaskgroup.reset(self._current_taskgroup_token)
self._current_taskgroup_token = None
if self._current_taskgroup_token:
current_ptaskgroup.reset(self._current_taskgroup_token)
self._current_taskgroup_token = None

if propagate_cancelation:
# The wrapping task was cancelled; since we're done with
Expand Down
Loading

0 comments on commit 131f1f7

Please sign in to comment.