PYTHON-4981 - Create workaround for asyncio.Task.cancelling support in older Python versions #2009

Merged
merged 4 commits into from
Nov 20, 2024
49 changes: 49 additions & 0 deletions pymongo/_asyncio_task.py
@@ -0,0 +1,49 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""


from __future__ import annotations

import asyncio
import sys
from typing import Any, Coroutine, Optional


# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
class _Task(asyncio.Task):
    def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
        super().__init__(coro, name=name)
        self._cancel_requests = 0
        asyncio._register_task(self)
Member

We can't use private asyncio APIs. Can we use this pattern instead?

_Task(asyncio.create_task(...))

Member

Oh I see the problem. That wouldn't work because asyncio.all_tasks() would return the non-wrapper class.

Member

The problem still remains, though. If we can't do this using only public APIs, then we can't add this workaround at all.

Contributor Author
@NoahStapp, Nov 19, 2024

Any implementation needs to support our _Task getting cancelled by external components like unittest or pytest. The approach you outline, where _Task is just a wrapper around the actual task, doesn't work in this case: the loop will only interact with the task it owns, rather than the _Task itself. When our testing framework goes to cancel all remaining tasks, it only has access to the loop, which won't have a reference to any _Task instances.

Subclassing asyncio.Task and overriding the cancel method to support pre-3.11 Python versions solves this issue.
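
The failure mode described above can be sketched as follows. This is a minimal illustration, not PyMongo code: `WrapperTask` is a hypothetical wrapper class, and the loop over `asyncio.all_tasks()` stands in for whatever sweep a test runner performs at teardown.

```python
import asyncio


class WrapperTask:
    """Hypothetical wrapper that holds a real task and records cancel requests."""

    def __init__(self, task: asyncio.Task) -> None:
        self.inner = task
        self.cancel_requested = False

    def cancel(self) -> bool:
        self.cancel_requested = True
        return self.inner.cancel()


async def _sleep_forever() -> None:
    await asyncio.sleep(3600)


async def demo() -> tuple:
    wrapper = WrapperTask(asyncio.create_task(_sleep_forever()))
    await asyncio.sleep(0)  # let the inner task start running

    # Stand-in for a test runner's teardown: it can only reach tasks
    # through the event loop, so it never sees the wrapper object.
    current = asyncio.current_task()
    for task in asyncio.all_tasks():
        if task is not current:
            task.cancel()

    await asyncio.gather(wrapper.inner, return_exceptions=True)
    # The inner task is cancelled, but the wrapper never recorded the request.
    return wrapper.inner.cancelled(), wrapper.cancel_requested
```

Because the loop only knows about the inner `asyncio.Task`, the wrapper's bookkeeping is bypassed whenever something other than our own code requests the cancellation.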

Contributor Author

asyncio._register_task is documented in the asyncio docs here: https://docs.python.org/3/library/asyncio-extending.html as the only way to extend Task functionality.

Contributor Author

The above passes all tests for me locally and doesn't use any private asyncio methods or import hacking, only MethodType tricks.

Member

If we own the task and its cancellation, can't we keep a mapping of tasks to cancelled state ourselves?

Contributor Author
@NoahStapp, Nov 19, 2024

We don't own the cancellation when the test runner cancels the task at the end of the test that created it. We could write our own teardown method that uses such a mapping to run before the runner's teardown, but that seems messy.
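
For illustration, here is a rough sketch of the mapping approach and of the gap it leaves. The names `tracked_create_task`, `tracked_cancel`, and `_cancel_requested` are hypothetical helpers invented for this example, not anything in PyMongo or asyncio.

```python
import asyncio
import weakref

# Hypothetical registry: task -> whether *our* code asked for cancellation.
_cancel_requested = weakref.WeakKeyDictionary()


def tracked_create_task(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _cancel_requested[task] = False
    return task


def tracked_cancel(task: asyncio.Task) -> bool:
    _cancel_requested[task] = True
    return task.cancel()


async def _sleep_forever() -> None:
    await asyncio.sleep(3600)


async def demo() -> tuple:
    ours = tracked_create_task(_sleep_forever())
    theirs = tracked_create_task(_sleep_forever())
    await asyncio.sleep(0)

    tracked_cancel(ours)  # cancellation we own: recorded in the mapping
    theirs.cancel()       # external cancellation (e.g. a test runner): not recorded

    await asyncio.gather(ours, theirs, return_exceptions=True)
    return _cancel_requested[ours], _cancel_requested[theirs]
```

The mapping is only accurate for cancellations that go through `tracked_cancel`; a direct `task.cancel()` from outside leaves the entry stale, which is the case raised in the comment above.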

Member

Okay, I relent for now, since we plan to remove this code anyway once we figure out the real issue causing the hangs.

Contributor Author

I'll add a TODO for us to come back and revisit this at that time.


    def cancel(self, msg: Optional[str] = None) -> bool:
        self._cancel_requests += 1
        return super().cancel(msg=msg)

    def uncancel(self) -> int:
        if self._cancel_requests > 0:
            self._cancel_requests -= 1
        return self._cancel_requests

    def cancelling(self) -> int:
        return self._cancel_requests


def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
    if sys.version_info >= (3, 11):
        return asyncio.create_task(coro, name=name)
    return _Task(coro, name=name)
Member

Okay if we're going down this route, can we at least only use this workaround on <=3.10?
Like:

if version >= 3.11:
    return asyncio.create_task(...)
return _Task(coro, name=name)

Contributor Author
@NoahStapp, Nov 19, 2024

If we're fine with _Task.cancelling and asyncio.Task.cancelling having the same name but different actual purposes, then yes. asyncio.Task.cancelling technically returns the number of pending cancellation requests for a task rather than a boolean. I could also implement a counter instead of a boolean so the two match.

Member

How would they differ? Oh I see, the stdlib returns a count of the cancellation requests as an integer. Is that what you mean?

Also I see this:

This method is used by asyncio’s internals and isn’t expected to be used by end-user code. See uncancel() for more details.

So we are going against their guidance in multiple ways.

Member

Yes, let's implement the cancel/uncancel counter pattern. It's simple enough.
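
As a rough sketch of how such a counter behaves, the example below uses a hypothetical `CountingTask` subclass that mirrors the shape of `_Task` in this PR but omits the `asyncio._register_task` call. It is an illustration under those assumptions, not the merged implementation.

```python
import asyncio
from typing import Any, Coroutine, Optional


class CountingTask(asyncio.Task):
    """Hypothetical Task subclass that counts pending cancellation requests."""

    def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
        super().__init__(coro, name=name)
        self._cancel_requests = 0

    def cancel(self, msg: Optional[str] = None) -> bool:
        self._cancel_requests += 1
        return super().cancel(msg=msg)

    def uncancel(self) -> int:
        # Never drop below zero, matching the semantics of a pending-request count.
        if self._cancel_requests > 0:
            self._cancel_requests -= 1
        return self._cancel_requests

    def cancelling(self) -> int:
        return self._cancel_requests


async def _sleep_forever() -> None:
    await asyncio.sleep(3600)


async def demo() -> tuple:
    task = CountingTask(_sleep_forever())
    await asyncio.sleep(0)  # let the task start

    task.cancel()
    first = task.cancelling()
    task.cancel()
    second = task.cancelling()
    remaining = task.uncancel()

    await asyncio.gather(task, return_exceptions=True)
    return first, second, remaining, task.cancelled()
```

Each `cancel()` call increments the count and each `uncancel()` decrements it, so `cancelling()` is truthy whenever at least one request is still pending. That lets callers use it as a boolean while keeping the integer return type of the stdlib method.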

1 change: 0 additions & 1 deletion pymongo/asynchronous/client_bulk.py
@@ -476,7 +476,6 @@ async def _process_results_cursor(
                         if op_type == "delete":
                             res = DeleteResult(doc, acknowledged=True)  # type: ignore[assignment]
                         full_result[f"{op_type}Results"][original_index] = res
-
             except Exception as exc:
                 # Attempt to close the cursor, then raise top-level error.
                 if cmd_cursor.alive:
7 changes: 7 additions & 0 deletions pymongo/asynchronous/encryption.py
@@ -15,6 +15,7 @@
 """Support for explicit client-side field level encryption."""
 from __future__ import annotations

+import asyncio
 import contextlib
 import enum
 import socket
@@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
         # BSON encoding/decoding errors are unrelated to encryption so
         # we should propagate them unchanged.
         raise
+    except asyncio.CancelledError:
+        raise
     except Exception as exc:
         raise EncryptionError(exc) from exc

@@ -200,6 +203,8 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
                 conn.close()
         except (PyMongoError, MongoCryptError):
             raise  # Propagate pymongo errors directly.
+        except asyncio.CancelledError:
+            raise
         except Exception as error:
             # Wrap I/O errors in PyMongo exceptions.
             _raise_connection_failure((host, port), error)
@@ -722,6 +727,8 @@ async def create_encrypted_collection(
                 await database.create_collection(name=name, **kwargs),
                 encrypted_fields,
             )
+        except asyncio.CancelledError:
+            raise
         except Exception as exc:
             raise EncryptedCollectionError(exc, encrypted_fields) from exc

5 changes: 5 additions & 0 deletions pymongo/asynchronous/monitor.py
@@ -238,6 +238,9 @@ async def _run(self) -> None:
         except ReferenceError:
             # Topology was garbage-collected.
             await self.close()
+        finally:
+            if self._executor._stopped:
+                await self._rtt_monitor.close()

     async def _check_server(self) -> ServerDescription:
         """Call hello or read the next streaming response.
@@ -254,6 +257,8 @@ async def _check_server(self) -> ServerDescription:
                 details = cast(Mapping[str, Any], exc.details)
                 await self._topology.receive_cluster_time(details.get("$clusterTime"))
                 raise
+        except asyncio.CancelledError:
+            raise
         except ReferenceError:
             raise
         except Exception as error:
7 changes: 4 additions & 3 deletions pymongo/network_layer.py
@@ -29,6 +29,7 @@
 )

 from pymongo import _csot, ssl_support
+from pymongo._asyncio_task import create_task
 from pymongo.errors import _OperationCancelled
 from pymongo.socket_checker import _errno_from_exception

@@ -259,12 +260,12 @@ async def async_receive_data(

     sock.settimeout(0.0)
     loop = asyncio.get_event_loop()
-    cancellation_task = asyncio.create_task(_poll_cancellation(conn))
+    cancellation_task = create_task(_poll_cancellation(conn))
     try:
         if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
-            read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop))  # type: ignore[arg-type]
+            read_task = create_task(_async_receive_ssl(sock, length, loop))  # type: ignore[arg-type]
         else:
-            read_task = asyncio.create_task(_async_receive(sock, length, loop))  # type: ignore[arg-type]
+            read_task = create_task(_async_receive(sock, length, loop))  # type: ignore[arg-type]
         tasks = [read_task, cancellation_task]
         done, pending = await asyncio.wait(
             tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
12 changes: 8 additions & 4 deletions pymongo/periodic_executor.py
@@ -23,6 +23,7 @@
 import weakref
 from typing import Any, Optional

+from pymongo._asyncio_task import create_task
 from pymongo.lock import _create_lock

 _IS_SYNC = False
@@ -61,10 +62,11 @@ def __repr__(self) -> str:
     def open(self) -> None:
         """Start. Multiple calls have no effect."""
         self._stopped = False
-        started = self._task and not self._task.done()

-        if not started:
-            self._task = asyncio.get_event_loop().create_task(self._run(), name=self._name)
+        if self._task is None or (
+            self._task.done() and not self._task.cancelled() and not self._task.cancelling()  # type: ignore[unused-ignore, attr-defined]
+        ):
+            self._task = create_task(self._run(), name=self._name)

     def close(self, dummy: Any = None) -> None:
         """Stop. To restart, call open().
@@ -83,7 +85,7 @@ async def join(self, timeout: Optional[int] = None) -> None:
                 pass
             except asyncio.exceptions.CancelledError:
                 # Task was already finished, or not yet started.
-                pass
+                raise

     def wake(self) -> None:
         """Execute the target function soon."""
@@ -97,6 +99,8 @@ def skip_sleep(self) -> None:

     async def _run(self) -> None:
         while not self._stopped:
+            if self._task and self._task.cancelling():  # type: ignore[unused-ignore, attr-defined]
+                raise asyncio.CancelledError
             try:
                 if not await self._target():
                     self._stopped = True
1 change: 0 additions & 1 deletion pymongo/synchronous/client_bulk.py
@@ -474,7 +474,6 @@ def _process_results_cursor(
                         if op_type == "delete":
                             res = DeleteResult(doc, acknowledged=True)  # type: ignore[assignment]
                         full_result[f"{op_type}Results"][original_index] = res
-
             except Exception as exc:
                 # Attempt to close the cursor, then raise top-level error.
                 if cmd_cursor.alive:
7 changes: 7 additions & 0 deletions pymongo/synchronous/encryption.py
@@ -15,6 +15,7 @@
 """Support for explicit client-side field level encryption."""
 from __future__ import annotations

+import asyncio
 import contextlib
 import enum
 import socket
@@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
         # BSON encoding/decoding errors are unrelated to encryption so
         # we should propagate them unchanged.
         raise
+    except asyncio.CancelledError:
+        raise
     except Exception as exc:
         raise EncryptionError(exc) from exc

@@ -200,6 +203,8 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
                 conn.close()
         except (PyMongoError, MongoCryptError):
             raise  # Propagate pymongo errors directly.
+        except asyncio.CancelledError:
+            raise
         except Exception as error:
             # Wrap I/O errors in PyMongo exceptions.
             _raise_connection_failure((host, port), error)
@@ -716,6 +721,8 @@ def create_encrypted_collection(
                 database.create_collection(name=name, **kwargs),
                 encrypted_fields,
             )
+        except asyncio.CancelledError:
+            raise
         except Exception as exc:
             raise EncryptedCollectionError(exc, encrypted_fields) from exc

5 changes: 5 additions & 0 deletions pymongo/synchronous/monitor.py
@@ -238,6 +238,9 @@ def _run(self) -> None:
         except ReferenceError:
             # Topology was garbage-collected.
             self.close()
+        finally:
+            if self._executor._stopped:
+                self._rtt_monitor.close()

     def _check_server(self) -> ServerDescription:
         """Call hello or read the next streaming response.
@@ -254,6 +257,8 @@ def _check_server(self) -> ServerDescription:
                 details = cast(Mapping[str, Any], exc.details)
                 self._topology.receive_cluster_time(details.get("$clusterTime"))
                 raise
+        except asyncio.CancelledError:
+            raise
         except ReferenceError:
             raise
         except Exception as error:
8 changes: 4 additions & 4 deletions test/__init__.py
@@ -868,8 +868,9 @@ def reset_client_context():
     if _IS_SYNC:
         # sync tests don't need to reset a client context
         return
-    client_context.client.close()
-    client_context.client = None
+    elif client_context.client is not None:
+        client_context.client.close()
+        client_context.client = None
     client_context._init_client()


@@ -1135,7 +1136,7 @@ class IntegrationTest(PyMongoTestCase):

     @client_context.require_connection
     def setUp(self) -> None:
-        if not _IS_SYNC and client_context.client is not None:
+        if not _IS_SYNC:
             reset_client_context()
         if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
             raise SkipTest("this test does not support load balancers")
@@ -1210,7 +1211,6 @@ def teardown():
             c.drop_database("pymongo_test_mike")
             c.drop_database("pymongo_test_bernie")
         c.close()
-
     print_running_clients()


8 changes: 4 additions & 4 deletions test/asynchronous/__init__.py
@@ -870,8 +870,9 @@ async def reset_client_context():
     if _IS_SYNC:
         # sync tests don't need to reset a client context
         return
-    await async_client_context.client.close()
-    async_client_context.client = None
+    elif async_client_context.client is not None:
+        await async_client_context.client.close()
+        async_client_context.client = None
     await async_client_context._init_client()


@@ -1153,7 +1154,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):

     @async_client_context.require_connection
     async def asyncSetUp(self) -> None:
-        if not _IS_SYNC and async_client_context.client is not None:
+        if not _IS_SYNC:
             await reset_client_context()
         if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
             raise SkipTest("this test does not support load balancers")
@@ -1228,7 +1229,6 @@ async def async_teardown():
             await c.drop_database("pymongo_test_mike")
             await c.drop_database("pymongo_test_bernie")
         await c.close()
-
     print_running_clients()


5 changes: 5 additions & 0 deletions test/asynchronous/test_client.py
@@ -1280,6 +1280,7 @@ async def get_x(db):
     async def test_server_selection_timeout(self):
         client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
         self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
+        await client.close()

         client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)

@@ -1292,18 +1293,22 @@ async def test_server_selection_timeout(self):
         self.assertRaises(
             ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
         )
+        await client.close()

         client = AsyncMongoClient(
             "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
         )
         self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
+        await client.close()

         client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
         self.assertAlmostEqual(0, client.options.server_selection_timeout)
+        await client.close()

         # Test invalid timeout in URI ignored and set to default.
         client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
         self.assertAlmostEqual(30, client.options.server_selection_timeout)
+        await client.close()

         client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
         self.assertAlmostEqual(30, client.options.server_selection_timeout)
5 changes: 5 additions & 0 deletions test/test_client.py
@@ -1243,6 +1243,7 @@ def get_x(db):
     def test_server_selection_timeout(self):
         client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
         self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
+        client.close()

         client = MongoClient(serverSelectionTimeoutMS=0, connect=False)

@@ -1253,16 +1254,20 @@ def test_server_selection_timeout(self):
         self.assertRaises(
             ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
         )
+        client.close()

         client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
         self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
+        client.close()

         client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
         self.assertAlmostEqual(0, client.options.server_selection_timeout)
+        client.close()

         # Test invalid timeout in URI ignored and set to default.
         client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
         self.assertAlmostEqual(30, client.options.server_selection_timeout)
+        client.close()

         client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
         self.assertAlmostEqual(30, client.options.server_selection_timeout)