Skip to content

Commit

Permalink
Explicitly discard task refs on completion, replacing WeakSet with set (
Browse files Browse the repository at this point in the history
#60)

Co-authored-by: Joongi Kim <[email protected]>
  • Loading branch information
heckad and achimnol authored Jan 23, 2025
1 parent 6ddc572 commit 86ae705
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 72 deletions.
1 change: 1 addition & 0 deletions changes/60.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Track task references explicitly and discard them after their completion in taskgroups
2 changes: 1 addition & 1 deletion src/aiotools/taskgroup/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def __aenter__(self):
async def __aexit__(self, et, exc, tb):
try:
return await super().__aexit__(et, exc, tb)
except BaseExceptionGroup as eg:
except BaseExceptionGroup as eg: # noqa: F821 (this module is not used in Python older than 3.11)
# Just wrap the exception group as TaskGroupError for backward
# compatibility. In Python 3.11 or higher, TaskGroupError
# also inherits BaseExceptionGroup, so the standard except*
Expand Down
58 changes: 30 additions & 28 deletions src/aiotools/taskgroup/base_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
except ImportError:
has_contextvars = False
import itertools
import weakref

from ..compat import current_task, get_running_loop
from .common import create_task_with_name, patch_task
Expand Down Expand Up @@ -60,7 +59,7 @@ def __init__(self, *, name=None):
self._loop = get_running_loop()
self._parent_task = None
self._parent_cancel_requested = False
self._tasks = weakref.WeakSet()
self._tasks = set()
self._unfinished_tasks = 0
self._errors = []
self._base_error = None
Expand Down Expand Up @@ -206,36 +205,39 @@ def _abort(self):
t.cancel()

def _on_task_done(self, task):
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
try:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0

if self._exiting and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)
if self._exiting and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)

if task.cancelled():
return
if task.cancelled():
return

exc = task.exception()
if exc is None:
return
exc = task.exception()
if exc is None:
return

self._errors.append(exc)
if self._is_base_error(exc) and self._base_error is None:
self._base_error = exc

if self._parent_task.done():
# Not sure if this case is possible, but we want to handle
# it anyways.
self._loop.call_exception_handler({
"message": (
f"Task {task!r} has errored out but its parent "
f"task {self._parent_task} is already completed"
),
"exception": exc,
"task": task,
})
return
self._errors.append(exc)
if self._is_base_error(exc) and self._base_error is None:
self._base_error = exc

if self._parent_task.done():
# Not sure if this case is possible, but we want to handle
# it anyways.
self._loop.call_exception_handler({
"message": (
f"Task {task!r} has errored out but its parent "
f"task {self._parent_task} is already completed"
),
"exception": exc,
"task": task,
})
return
finally:
self._tasks.discard(task)

self._abort()
if not self._parent_task.__cancel_requested__:
Expand Down
43 changes: 23 additions & 20 deletions src/aiotools/taskgroup/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def _default_exc_handler(exc_type, exc_obj, exc_tb) -> None:
class PersistentTaskGroup:
_base_error: Optional[BaseException]
_exc_handler: AsyncExceptionHandler
_tasks: "weakref.WeakSet[asyncio.Task]"
_tasks: set[asyncio.Task]
_on_completed_fut: Optional[asyncio.Future]
_current_taskgroup_token: Optional[Token["PersistentTaskGroup"]]

Expand All @@ -62,7 +62,7 @@ def __init__(
self._unfinished_tasks = 0
self._on_completed_fut = None
self._parent_task = compat.current_task()
self._tasks = weakref.WeakSet()
self._tasks = set()
self._current_taskgroup_token = None
_all_ptaskgroups.add(self)
if exception_handler is None:
Expand Down Expand Up @@ -186,29 +186,32 @@ async def _task_wrapper(
del fut

def _on_task_done(self, task: asyncio.Task) -> None:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
assert self._parent_task is not None
try:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
assert self._parent_task is not None

if self._on_completed_fut is not None and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)
if self._on_completed_fut is not None and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)

if task.cancelled():
_log.debug("%r in %r has been cancelled.", task, self)
return
if task.cancelled():
_log.debug("%r in %r has been cancelled.", task, self)
return

exc = task.exception()
if exc is None:
return
exc = task.exception()
if exc is None:
return

# Now the exception is BaseException.
if self._base_error is None:
self._base_error = exc
# Now the exception is BaseException.
if self._base_error is None:
self._base_error = exc

self._trigger_shutdown()
if not self._parent_task.cancelling():
self._parent_cancel_requested = True
self._trigger_shutdown()
if not self._parent_task.cancelling():
self._parent_cancel_requested = True
finally:
self._tasks.discard(task)

async def __aenter__(self) -> "PersistentTaskGroup":
self._parent_task = compat.current_task()
Expand Down
43 changes: 23 additions & 20 deletions src/aiotools/taskgroup/persistent_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def _default_exc_handler(exc_type, exc_obj, exc_tb) -> None:
class PersistentTaskGroup:
_base_error: Optional[BaseException]
_exc_handler: AsyncExceptionHandler
_tasks: "weakref.WeakSet[asyncio.Task]"
_tasks: set[asyncio.Task]
_on_completed_fut: Optional[asyncio.Future]
_current_taskgroup_token: Optional["Token[PersistentTaskGroup]"]

Expand All @@ -71,7 +71,7 @@ def __init__(
self._unfinished_tasks = 0
self._on_completed_fut = None
self._parent_task = compat.current_task()
self._tasks = weakref.WeakSet()
self._tasks = set()
self._current_taskgroup_token = None
_all_ptaskgroups.add(self)
if exception_handler is None:
Expand Down Expand Up @@ -195,29 +195,32 @@ async def _task_wrapper(
del fut

def _on_task_done(self, task: asyncio.Task) -> None:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
assert self._parent_task is not None
try:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
assert self._parent_task is not None

if self._on_completed_fut is not None and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)
if self._on_completed_fut is not None and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)

if task.cancelled():
_log.debug("%r in %r has been cancelled.", task, self)
return
if task.cancelled():
_log.debug("%r in %r has been cancelled.", task, self)
return

exc = task.exception()
if exc is None:
return
exc = task.exception()
if exc is None:
return

# Now the exception is BaseException.
if self._base_error is None:
self._base_error = exc
# Now the exception is BaseException.
if self._base_error is None:
self._base_error = exc

self._trigger_shutdown()
if not self._parent_task.__cancel_requested__: # type: ignore
self._parent_cancel_requested = True
self._trigger_shutdown()
if not self._parent_task.__cancel_requested__: # type: ignore
self._parent_cancel_requested = True
finally:
self._tasks.discard(task)

async def __aenter__(self) -> "PersistentTaskGroup":
self._parent_task = compat.current_task()
Expand Down
2 changes: 1 addition & 1 deletion src/aiotools/taskgroup/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class TaskGroupError(MultiError): # type: ignore[no-redef]

else:

class MultiError(ExceptionGroup): # type: ignore[no-redef,name-defined]
class MultiError(ExceptionGroup): # type: ignore[no-redef,name-defined] # noqa: F821
def __init__(self, msg, errors=()):
super().__init__(msg, errors)
self.__errors__ = errors
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ptaskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ async def handler(exc_type, exc_obj, exc_tb):
async with aiotools.PersistentTaskGroup(exception_handler=handler) as tg:
for _ in range(10):
tg.create_task(subtask())
except aitools.TaskGroupError:
except aiotools.TaskGroupError:
assert False, "should not reach here"

# Check if the event loop exception handler is called.
Expand Down
1 change: 0 additions & 1 deletion tests/test_taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ async def do_job(delay, result):
@pytest.mark.asyncio
async def test_taskgroup_memoryleak_with_persistent_tg():
with VirtualClock().patch_loop(), warnings.catch_warnings():
warnings.simplefilter("ignore")

async def do_job(delay):
await asyncio.sleep(delay)
Expand Down

0 comments on commit 86ae705

Please sign in to comment.