From 7fbeca9793bb88561a8b9972df926193d54a1225 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Tue, 1 Oct 2024 15:06:12 -0400 Subject: [PATCH 01/10] [PYTHON-4803] Big endian fix for binary bson vectors (#1885) --- bson/binary.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bson/binary.py b/bson/binary.py index 47c52d4892..96b61b6dab 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -432,7 +432,7 @@ def from_vector( raise NotImplementedError("%s not yet supported" % dtype) metadata = struct.pack("<sB", dtype.value, padding) - data = struct.pack(f"{len(vector)}{format_str}", *vector) + data = struct.pack(f"<{len(vector)}{format_str}", *vector) return cls(metadata + data, subtype=VECTOR_SUBTYPE) def as_vector(self) -> BinaryVector: @@ -454,7 +454,7 @@ def as_vector(self) -> BinaryVector: if dtype == BinaryVectorDtype.INT8: dtype_format = "b" - format_string = f"{n_values}{dtype_format}" + format_string = f"<{n_values}{dtype_format}" vector = list(struct.unpack_from(format_string, self, position)) return BinaryVector(vector, dtype, padding) @@ -465,13 +465,16 @@ def as_vector(self) -> BinaryVector: raise ValueError( "Corrupt data. N bytes for a float32 vector must be a multiple of 4." ) - vector = list(struct.unpack_from(f"{n_values}f", self, position)) + dtype_format = "f" + format_string = f"<{n_values}{dtype_format}" + vector = list(struct.unpack_from(format_string, self, position)) return BinaryVector(vector, dtype, padding) elif dtype == BinaryVectorDtype.PACKED_BIT: # data packed as uint8 dtype_format = "B" - unpacked_uint8s = list(struct.unpack_from(f"{n_values}{dtype_format}", self, position)) + format_string = f"<{n_values}{dtype_format}" + unpacked_uint8s = list(struct.unpack_from(format_string, self, position)) return BinaryVector(unpacked_uint8s, dtype, padding) else: From 02794079802f264be021707d20fc64e292ef74b7 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 1 Oct 2024 14:31:21 -0500 Subject: [PATCH 02/10] PYTHON-4806 Fix expected metadata in mockupdb tests (#1888) --- hatch.toml | 2 +- test/mockupdb/test_handshake.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/hatch.toml b/hatch.toml index d5293a1d7f..60bd0af014 100644 --- a/hatch.toml +++ b/hatch.toml @@ -43,7 +43,7 @@ features = ["test"] test = "pytest -v --durations=5 --maxfail=10 {args}" test-eg = "bash ./.evergreen/run-tests.sh {args}" test-async = "pytest -v --durations=5 --maxfail=10 -m default_async {args}" -test-mockupdb = ["pip install -U git+https://github.com/ajdavis/mongo-mockup-db@master", "test -m mockupdb"] +test-mockupdb = ["pip install -U git+https://github.com/mongodb-labs/mongo-mockup-db@master", "test -m mockupdb"] [envs.encryption] skip-install = true diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 19e10f9617..8193714a86 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -26,7 +26,7 @@ from bson.objectid import ObjectId -from pymongo import MongoClient +from pymongo import MongoClient, has_c from pymongo import version as pymongo_version from pymongo.errors import OperationFailure from pymongo.server_api import ServerApi, ServerApiVersion @@ -39,7 +39,11 @@ def _check_handshake_data(request): data = request["client"] assert data["application"] == {"name": "my app"} - assert data["driver"] == {"name": "PyMongo", "version": pymongo_version} + if has_c(): + name = "PyMongo|c" + else: + name = "PyMongo" + assert data["driver"] == {"name": name, "version": pymongo_version} # Keep it simple, just check these fields exist.
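The handshake change above keys the expected driver name off whether PyMongo's C extensions are compiled in. A quick way to check a given installation, as a minimal sketch using the public has_c helpers the patched test imports:

    # Check whether the C extensions are active. With them compiled in, the
    # handshake metadata reports the driver name as "PyMongo|c" instead of
    # "PyMongo", which is exactly what the updated assertion expects.
    import bson
    import pymongo

    name = "PyMongo|c" if pymongo.has_c() else "PyMongo"
    print(bson.has_c(), pymongo.has_c(), name)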
assert "os" in data From 7848feb09a12bb6a14fb18deb8b873d8c2eff8a9 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 1 Oct 2024 18:32:41 -0400 Subject: [PATCH 03/10] PYTHON-4786 - Fix UpdateResult.did_upsert TypeError (#1878) --- pymongo/results.py | 7 ++-- test/asynchronous/test_client_bulk_write.py | 40 +++++++++++++++++++++ test/asynchronous/test_collection.py | 13 +++++++ test/test_client_bulk_write.py | 40 +++++++++++++++++++++ test/test_collection.py | 13 +++++++ test/test_results.py | 22 ++++++++++++ 6 files changed, 133 insertions(+), 2 deletions(-) diff --git a/pymongo/results.py b/pymongo/results.py index b34f6c4926..d17ff1c3ea 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -171,9 +171,12 @@ def upserted_id(self) -> Any: @property def did_upsert(self) -> bool: - """Whether or not an upsert took place.""" + """Whether an upsert took place. + + .. versionadded:: 4.9 + """ assert self.__raw_result is not None - return len(self.__raw_result.get("upserted", {})) > 0 + return "upserted" in self.__raw_result class DeleteResult(_WriteResult): diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 80cfd30bde..9464337809 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -550,6 +550,46 @@ async def test_returns_error_if_auto_encryption_configured(self): "bulk_write does not currently support automatic encryption", context.exception._message ) + @async_client_context.require_version_min(8, 0, 0, -24) + @async_client_context.require_no_serverless + async def test_upserted_result(self): + client = await self.async_rs_or_single_client() + + collection = client.db["coll"] + self.addAsyncCleanup(collection.drop) + await collection.drop() + + models = [] + models.append( + UpdateOne( + namespace="db.coll", + filter={"_id": "a"}, + update={"$set": {"x": 1}}, + upsert=True, + ) + ) + models.append( + UpdateOne( + namespace="db.coll", + filter={"_id": None}, + update={"$set": {"x": 1}}, + upsert=True, + ) + ) + models.append( + UpdateOne( + namespace="db.coll", + filter={"_id": None}, + update={"$set": {"x": 1}}, + ) + ) + result = await client.bulk_write(models=models, verbose_results=True) + + self.assertEqual(result.upserted_count, 2) + self.assertEqual(result.update_results[0].did_upsert, True) + self.assertEqual(result.update_results[1].did_upsert, True) + self.assertEqual(result.update_results[2].did_upsert, False) + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#11-multi-batch-bulkwrites class TestClientBulkWriteCSOT(AsyncIntegrationTest): diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 74a4a5151d..612090b69f 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -1444,6 +1444,19 @@ async def test_update_one(self): self.assertRaises(InvalidOperation, lambda: result.upserted_id) self.assertFalse(result.acknowledged) + async def test_update_result(self): + db = self.db + await db.drop_collection("test") + + result = await db.test.update_one({"x": 0}, {"$inc": {"x": 1}}, upsert=True) + self.assertEqual(result.did_upsert, True) + + result = await db.test.update_one({"_id": None, "x": 0}, {"$inc": {"x": 1}}, upsert=True) + self.assertEqual(result.did_upsert, True) + + result = await db.test.update_one({"_id": None}, {"$inc": {"x": 1}}) + self.assertEqual(result.did_upsert, False) + async def test_update_many(self): db = 
self.db await db.drop_collection("test") diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index d1aff03fc9..58b5015dd2 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -550,6 +550,46 @@ def test_returns_error_if_auto_encryption_configured(self): "bulk_write does not currently support automatic encryption", context.exception._message ) + @client_context.require_version_min(8, 0, 0, -24) + @client_context.require_no_serverless + def test_upserted_result(self): + client = self.rs_or_single_client() + + collection = client.db["coll"] + self.addCleanup(collection.drop) + collection.drop() + + models = [] + models.append( + UpdateOne( + namespace="db.coll", + filter={"_id": "a"}, + update={"$set": {"x": 1}}, + upsert=True, + ) + ) + models.append( + UpdateOne( + namespace="db.coll", + filter={"_id": None}, + update={"$set": {"x": 1}}, + upsert=True, + ) + ) + models.append( + UpdateOne( + namespace="db.coll", + filter={"_id": None}, + update={"$set": {"x": 1}}, + ) + ) + result = client.bulk_write(models=models, verbose_results=True) + + self.assertEqual(result.upserted_count, 2) + self.assertEqual(result.update_results[0].did_upsert, True) + self.assertEqual(result.update_results[1].did_upsert, True) + self.assertEqual(result.update_results[2].did_upsert, False) + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#11-multi-batch-bulkwrites class TestClientBulkWriteCSOT(IntegrationTest): diff --git a/test/test_collection.py b/test/test_collection.py index dab59cf1b2..a2c3b0b0b6 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1429,6 +1429,19 @@ def test_update_one(self): self.assertRaises(InvalidOperation, lambda: result.upserted_id) self.assertFalse(result.acknowledged) + def test_update_result(self): + db = self.db + db.drop_collection("test") + + result = db.test.update_one({"x": 0}, {"$inc": {"x": 1}}, upsert=True) + self.assertEqual(result.did_upsert, True) + + result = db.test.update_one({"_id": None, "x": 0}, {"$inc": {"x": 1}}, upsert=True) + self.assertEqual(result.did_upsert, True) + + result = db.test.update_one({"_id": None}, {"$inc": {"x": 1}}) + self.assertEqual(result.did_upsert, False) + def test_update_many(self): db = self.db db.drop_collection("test") diff --git a/test/test_results.py b/test/test_results.py index 19e086a9a5..deb09d7ed4 100644 --- a/test/test_results.py +++ b/test/test_results.py @@ -122,6 +122,28 @@ def test_update_result(self): self.assertEqual(raw_result["n"], result.matched_count) self.assertEqual(raw_result["nModified"], result.modified_count) self.assertEqual(raw_result["upserted"], result.upserted_id) + self.assertEqual(result.did_upsert, True) + + raw_result_2 = { + "n": 1, + "nModified": 1, + "upserted": [ + {"index": 5, "_id": 1}, + ], + } + self.repr_test(UpdateResult, raw_result_2) + + result = UpdateResult(raw_result_2, True) + self.assertEqual(result.did_upsert, True) + + raw_result_3 = { + "n": 1, + "nModified": 1, + } + self.repr_test(UpdateResult, raw_result_3) + + result = UpdateResult(raw_result_3, True) + self.assertEqual(result.did_upsert, False) result = UpdateResult(raw_result, False) self.assertEqual(raw_result, result.raw_result) From 1c284307250dec4fb6b7b161d72e1876c067b4cb Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 1 Oct 2024 17:52:16 -0500 Subject: [PATCH 04/10] PYTHON-4808 Add changelog for 4.10.1 (#1890) --- doc/changelog.rst | 17 +++++++++++++++++ 1 file changed, 17 
insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index 6c8b8261ac..76e91c2b27 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,23 @@ Changelog ========= +Changes in Version 4.10.1 +------------------------- + +Version 4.10.1 is a bug fix release. + +- Fixed a bug where :meth:`~pymongo.results.UpdateResult.did_upsert` would raise a ``TypeError``. +- Fixed Binary BSON subtype (9) support on big-endian operating systems (such as zSeries). + +Issues Resolved +............... + +See the `PyMongo 4.10.1 release notes in JIRA`_ for the list of resolved issues +in this release. + +.. _PyMongo 4.10.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40788 + + Changes in Version 4.10.0 ------------------------- From 77cd7ab9f6dc48e72a3bae94d2cca2e4200e6978 Mon Sep 17 00:00:00 2001 From: "mongodb-dbx-release-bot[bot]" <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:53:25 +0000 Subject: [PATCH 05/10] BUMP 4.10.1 Signed-off-by: mongodb-dbx-release-bot[bot] <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> --- pymongo/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/_version.py b/pymongo/_version.py index 3de24a8e14..c0232ba514 100644 --- a/pymongo/_version.py +++ b/pymongo/_version.py @@ -18,7 +18,7 @@ import re from typing import List, Tuple, Union -__version__ = "4.11.0.dev0" +__version__ = "4.10.1" def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]: From da059a6b0afdf971abe6ffbdc5ca4aec09c61b0d Mon Sep 17 00:00:00 2001 From: "mongodb-dbx-release-bot[bot]" <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 23:09:24 +0000 Subject: [PATCH 06/10] BUMP 4.11.0.dev0 Signed-off-by: mongodb-dbx-release-bot[bot] <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> --- pymongo/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/_version.py b/pymongo/_version.py index c0232ba514..3de24a8e14 100644 --- a/pymongo/_version.py +++ b/pymongo/_version.py @@ -18,7 +18,7 @@ import re from typing import List, Tuple, Union -__version__ = "4.10.1" +__version__ = "4.11.0.dev0" def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]: From 2a83349f7159c0117848cf3ab1a67b6ad7d6cf0d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 2 Oct 2024 11:34:43 -0500 Subject: [PATCH 07/10] PYTHON-4812 Update changelog for 4.9.2 and 4.9.1 [master] (#1892) --- doc/changelog.rst | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index 76e91c2b27..574ecad763 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -36,6 +36,36 @@ in this release. .. _PyMongo 4.10 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40553 +Changes in Version 4.9.2 +------------------------- + +- Fixed a bug where :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` could deadlock. +- Fixed a bug where PyMongo could fail to import on Windows if ``asyncio`` is misconfigured. +- Fixed a bug where :meth:`~pymongo.results.UpdateResult.did_upsert` would raise a ``TypeError``. + +Issues Resolved +............... + +See the `PyMongo 4.9.2 release notes in JIRA`_ for the list of resolved issues +in this release. + +.. 
_PyMongo 4.9.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40732 + + +Changes in Version 4.9.1 +------------------------- + +- Add missing documentation about the fact the async API is in beta state. + +Issues Resolved +............... + +See the `PyMongo 4.9.1 release notes in JIRA`_ for the list of resolved issues +in this release. + +.. _PyMongo 4.9.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40720 + + Changes in Version 4.9.0 ------------------------- From af23139b4ab7aeba5da71b571809cac6474391a1 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Thu, 3 Oct 2024 10:27:22 -0700 Subject: [PATCH 08/10] PYTHON-4805 Migrate test_connections_survive_primary_stepdown_spec.py to async (#1889) --- test/asynchronous/helpers.py | 11 ++ ...nnections_survive_primary_stepdown_spec.py | 148 ++++++++++++++++++ test/helpers.py | 11 ++ ...nnections_survive_primary_stepdown_spec.py | 10 +- test/utils.py | 48 ++++-- tools/synchro.py | 3 + 6 files changed, 217 insertions(+), 14 deletions(-) create mode 100644 test/asynchronous/test_connections_survive_primary_stepdown_spec.py diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 46f66af62d..b5fc5d8ac4 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -42,6 +42,7 @@ from bson.son import SON from pymongo import common, message +from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] from pymongo.uri_parser import parse_uri @@ -150,6 +151,16 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs): return authdb.command(cmd) +async def async_repl_set_step_down(client, **kwargs): + """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" + cmd = SON([("replSetStepDown", 1)]) + cmd.update(kwargs) + + # Unfreeze a secondary to ensure a speedy election. + await client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY) + await client.admin.command(cmd) + + class client_knobs: def __init__( self, diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py new file mode 100644 index 0000000000..289cf49751 --- /dev/null +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -0,0 +1,148 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
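The did_upsert fix called out in the changelog entries above (implemented in PATCH 03) is easy to exercise directly. A sketch against a hypothetical local mongod; before the fix, computing len() of the raw "upserted" value (the upserted document's _id) raised TypeError for ids such as ObjectId or None:

    from pymongo import MongoClient

    coll = MongoClient().db.coll  # assumes a local test server
    coll.drop()

    # An upsert happened, even though the upserted _id (None) has no len().
    result = coll.update_one({"_id": None}, {"$set": {"x": 1}}, upsert=True)
    print(result.did_upsert)  # True

    # The document now exists, so no upsert takes place.
    result = coll.update_one({"_id": None}, {"$set": {"x": 2}})
    print(result.did_upsert)  # False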
+ +"""Test compliance with the connections survive primary step down spec.""" +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.helpers import async_repl_set_step_down +from test.utils import ( + CMAPListener, + async_ensure_all_connected, +) + +from bson import SON +from pymongo import monitoring +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.errors import NotPrimaryError +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest): + listener: CMAPListener + coll: AsyncCollection + + @classmethod + @async_client_context.require_replica_set + async def _setup_class(cls): + await super()._setup_class() + cls.listener = CMAPListener() + cls.client = await cls.unmanaged_async_rs_or_single_client( + event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + ) + + # Ensure connections to all servers in replica set. This is to test + # that the is_writable flag is properly updated for connections that + # survive a replica set election. + await async_ensure_all_connected(cls.client) + cls.listener.reset() + + cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) + cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) + + @classmethod + async def _tearDown_class(cls): + await cls.client.close() + + async def asyncSetUp(self): + # Note that all ops use same write-concern as self.db (majority). + await self.db.drop_collection("step-down") + await self.db.create_collection("step-down") + self.listener.reset() + + async def set_fail_point(self, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await self.client.admin.command(cmd) + + def verify_pool_cleared(self): + self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 1) + + def verify_pool_not_cleared(self): + self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 0) + + @async_client_context.require_version_min(4, 2, -1) + async def test_get_more_iteration(self): + # Insert 5 documents with WC majority. + await self.coll.insert_many([{"data": k} for k in range(5)]) + # Start a find operation and retrieve first batch of results. + batch_size = 2 + cursor = self.coll.find(batch_size=batch_size) + for _ in range(batch_size): + await cursor.next() + # Force step-down the primary. + await async_repl_set_step_down(self.client, replSetStepDown=5, force=True) + # Get await anext batch of results. + for _ in range(batch_size): + await cursor.next() + # Verify pool not cleared. + self.verify_pool_not_cleared() + # Attempt insertion to mark server description as stale and prevent a + # NotPrimaryError on the subsequent operation. + try: + await self.coll.insert_one({}) + except NotPrimaryError: + pass + # Next insert should succeed on the new primary without clearing pool. + await self.coll.insert_one({}) + self.verify_pool_not_cleared() + + async def run_scenario(self, error_code, retry, pool_status_checker): + # Set fail point. + await self.set_fail_point( + {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": error_code}} + ) + self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) + # Insert record and verify failure. 
+ with self.assertRaises(NotPrimaryError) as exc: + await self.coll.insert_one({"test": 1}) + self.assertEqual(exc.exception.details["code"], error_code) # type: ignore[call-overload] + # Retry before CMAPListener assertion if retry_before=True. + if retry: + await self.coll.insert_one({"test": 1}) + # Verify pool cleared/not cleared. + pool_status_checker() + # Always retry here to ensure discovery of new primary. + await self.coll.insert_one({"test": 1}) + + @async_client_context.require_version_min(4, 2, -1) + @async_client_context.require_test_commands + async def test_not_primary_keep_connection_pool(self): + await self.run_scenario(10107, True, self.verify_pool_not_cleared) + + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_version_max(4, 1, 0, -1) + @async_client_context.require_test_commands + async def test_not_primary_reset_connection_pool(self): + await self.run_scenario(10107, False, self.verify_pool_cleared) + + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_test_commands + async def test_shutdown_in_progress(self): + await self.run_scenario(91, False, self.verify_pool_cleared) + + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_test_commands + async def test_interrupted_at_shutdown(self): + await self.run_scenario(11600, False, self.verify_pool_cleared) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/helpers.py b/test/helpers.py index bf6186d1a0..11d5ab0374 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -42,6 +42,7 @@ from bson.son import SON from pymongo import common, message +from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] from pymongo.uri_parser import parse_uri @@ -150,6 +151,16 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs): return authdb.command(cmd) +def repl_set_step_down(client, **kwargs): + """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" + cmd = SON([("replSetStepDown", 1)]) + cmd.update(kwargs) + + # Unfreeze a secondary to ensure a speedy election. 
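The synchronous repl_set_step_down helper whose body continues just below mirrors the async version added earlier in this patch. A hedged usage sketch, assuming a local replica set named rs0:

    from pymongo import MongoClient
    from pymongo.errors import NotPrimaryError
    from test.helpers import repl_set_step_down  # path within this repo

    client = MongoClient("mongodb://localhost:27017/?replicaSet=rs0")
    repl_set_step_down(client, replSetStepDown=5, force=True)
    try:
        client.db.coll.insert_one({})  # may fail once during the election
    except NotPrimaryError:
        pass
    client.db.coll.insert_one({})  # succeeds against the new primary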
+ client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY) + client.admin.command(cmd) + + class client_knobs: def __init__( self, diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index fba7675743..54cc4e0482 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -20,10 +20,10 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest +from test.helpers import repl_set_step_down from test.utils import ( CMAPListener, ensure_all_connected, - repl_set_step_down, ) from bson import SON @@ -32,6 +32,8 @@ from pymongo.synchronous.collection import Collection from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): listener: CMAPListener @@ -39,8 +41,8 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): @classmethod @client_context.require_replica_set - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.listener = CMAPListener() cls.client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 @@ -56,7 +58,7 @@ def setUpClass(cls): cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.client.close() def setUp(self): diff --git a/test/utils.py b/test/utils.py index 9615034899..9c78cff3ad 100644 --- a/test/utils.py +++ b/test/utils.py @@ -599,6 +599,44 @@ def discover(): ) +async def async_ensure_all_connected(client: AsyncMongoClient) -> None: + """Ensure that the client's connection pool has socket connections to all + members of a replica set. Raises ConfigurationError when called with a + non-replica set client. + + Depending on the use-case, the caller may need to clear any event listeners + that are configured on the client. + """ + hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" not in hello: + raise ConfigurationError("cluster is not a replica set") + + target_host_list = set(hello["hosts"] + hello.get("passives", [])) + connected_host_list = {hello["me"]} + + # Run hello until we have connected to each host at least once. + async def discover(): + i = 0 + while i < 100 and connected_host_list != target_host_list: + hello: dict = await client.admin.command( + HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY + ) + connected_host_list.update([hello["me"]]) + i += 1 + return connected_host_list + + try: + + async def predicate(): + return target_host_list == await discover() + + await async_wait_until(predicate, "connected to all hosts") + except AssertionError as exc: + raise AssertionError( + f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" + ) + + def one(s): """Get one element of a set""" return next(iter(s)) @@ -761,16 +799,6 @@ async def async_wait_until(predicate, success_description, timeout=10): await asyncio.sleep(interval) -def repl_set_step_down(client, **kwargs): - """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" - cmd = SON([("replSetStepDown", 1)]) - cmd.update(kwargs) - - # Unfreeze a secondary to ensure a speedy election. 
- client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY) - client.admin.command(cmd) - - def is_mongos(client): res = client.admin.command(HelloCompat.LEGACY_CMD) return res.get("msg", "") == "isdbgrid" diff --git a/tools/synchro.py b/tools/synchro.py index 3333b0de2e..d8ec9ae46f 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -105,6 +105,8 @@ "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", "async_set_fail_point": "set_fail_point", + "async_ensure_all_connected": "ensure_all_connected", + "async_repl_set_step_down": "repl_set_step_down", } docstring_replacements: dict[tuple[str, str], str] = { @@ -186,6 +188,7 @@ def async_only_test(f: str) -> bool: "test_client_bulk_write.py", "test_client_context.py", "test_collection.py", + "test_connections_survive_primary_stepdown_spec.py", "test_cursor.py", "test_database.py", "test_encryption.py", From 7380097dbca42580f9547bbd632f1efe96afc460 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 3 Oct 2024 13:39:04 -0400 Subject: [PATCH 09/10] PYTHON-3959 - NULL Initialize PyObjects (#1859) --- bson/_cbsonmodule.c | 24 ++++++++++++------------ pymongo/_cmessagemodule.c | 34 +++++++++++++++++----------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 34b407b940..223c392280 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -207,7 +207,7 @@ static PyObject* _test_long_long_to_str(PyObject* self, PyObject* args) { * * Returns a new ref */ static PyObject* _error(char* name) { - PyObject* error; + PyObject* error = NULL; PyObject* errors = PyImport_ImportModule("bson.errors"); if (!errors) { return NULL; @@ -279,7 +279,7 @@ static PyObject* datetime_from_millis(long long millis) { * micros = diff * 1000 111000 * Resulting in datetime(1, 1, 1, 1, 1, 1, 111000) -- the expected result */ - PyObject* datetime; + PyObject* datetime = NULL; int diff = (int)(((millis % 1000) + 1000) % 1000); int microseconds = diff * 1000; Time64_T seconds = (millis - diff) / 1000; @@ -294,7 +294,7 @@ static PyObject* datetime_from_millis(long long millis) { timeinfo.tm_sec, microseconds); if(!datetime) { - PyObject *etype, *evalue, *etrace; + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; /* * Calling _error clears the error state, so fetch it first. @@ -350,8 +350,8 @@ static PyObject* datetime_ms_from_millis(PyObject* self, long long millis){ return NULL; } - PyObject* dt; - PyObject* ll_millis; + PyObject* dt = NULL; + PyObject* ll_millis = NULL; if (!(ll_millis = PyLong_FromLongLong(millis))){ return NULL; @@ -1790,7 +1790,7 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) { PyObject* result; unsigned char check_keys; unsigned char top_level = 1; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer; PyObject* raw_bson_document_bytes_obj; @@ -2512,8 +2512,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, * Wrap any non-InvalidBSON errors in InvalidBSON. */ if (PyErr_Occurred()) { - PyObject *etype, *evalue, *etrace; - PyObject *InvalidBSON; + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; + PyObject *InvalidBSON = NULL; /* * Calling _error clears the error state, so fetch it first. 
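The unifying theme of the C changes in this patch: every local PyObject pointer that an error path may clean up now starts life as NULL. A minimal sketch of why this matters (hypothetical function, not from the codebase), in the same C as the surrounding module:

    #include <Python.h>

    /* If allocation fails before 'factor' or 'result' is assigned, the
     * cleanup path below runs Py_XDECREF on them. On an uninitialized
     * pointer that is undefined behavior; on NULL it is a harmless no-op. */
    static PyObject* scaled_by_ten(PyObject* value) {
        PyObject* factor = NULL;
        PyObject* result = NULL;
        factor = PyLong_FromLong(10);
        if (!factor) {
            goto done;              /* both locals are NULL: safe cleanup */
        }
        result = PyNumber_Multiply(value, factor);
    done:
        Py_XDECREF(factor);         /* Py_XDECREF(NULL) does nothing */
        return result;              /* NULL propagates any error */
    }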
@@ -2585,8 +2585,8 @@ static int _element_to_dict(PyObject* self, const char* string, if (!*name) { /* If NULL is returned then wrap the UnicodeDecodeError in an InvalidBSON error */ - PyObject *etype, *evalue, *etrace; - PyObject *InvalidBSON; + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; + PyObject *InvalidBSON = NULL; PyErr_Fetch(&etype, &evalue, &etrace); if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) { @@ -2620,7 +2620,7 @@ static PyObject* _cbson_element_to_dict(PyObject* self, PyObject* args) { /* TODO: Support buffer protocol */ char* string; PyObject* bson; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; unsigned position; unsigned max; @@ -2732,7 +2732,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { int32_t size; Py_ssize_t total_size; const char* string; - PyObject* bson; + PyObject* bson = NULL; codec_options_t options; PyObject* result = NULL; PyObject* options_obj; diff --git a/pymongo/_cmessagemodule.c b/pymongo/_cmessagemodule.c index f95b949380..b5adbeec32 100644 --- a/pymongo/_cmessagemodule.c +++ b/pymongo/_cmessagemodule.c @@ -45,7 +45,7 @@ struct module_state { * * Returns a new ref */ static PyObject* _error(char* name) { - PyObject* error; + PyObject* error = NULL; PyObject* errors = PyImport_ImportModule("pymongo.errors"); if (!errors) { return NULL; @@ -75,9 +75,9 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) { int begin, cur_size, max_size = 0; int num_to_skip; int num_to_return; - PyObject* query; - PyObject* field_selector; - PyObject* options_obj; + PyObject* query = NULL; + PyObject* field_selector = NULL; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer = NULL; int length_location, message_length; @@ -221,12 +221,12 @@ static PyObject* _cbson_op_msg(PyObject* self, PyObject* args) { /* NOTE just using a random number as the request_id */ int request_id = rand(); unsigned int flags; - PyObject* command; + PyObject* command = NULL; char* identifier = NULL; Py_ssize_t identifier_length = 0; - PyObject* docs; - PyObject* doc; - PyObject* options_obj; + PyObject* docs = NULL; + PyObject* doc = NULL; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer = NULL; int length_location, message_length; @@ -535,12 +535,12 @@ static PyObject* _cbson_encode_batched_op_msg(PyObject* self, PyObject* args) { unsigned char op; unsigned char ack; - PyObject* command; - PyObject* docs; + PyObject* command = NULL; + PyObject* docs = NULL; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); @@ -592,12 +592,12 @@ _cbson_batched_op_msg(PyObject* self, PyObject* args) { unsigned char ack; int request_id; int position; - PyObject* command; - PyObject* docs; + PyObject* command = NULL; + PyObject* docs = NULL; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); @@ -868,12 +868,12 @@ _cbson_encode_batched_write_command(PyObject* self, PyObject* args) { char *ns = NULL; unsigned char op; Py_ssize_t ns_len; - PyObject* command; - PyObject* docs; + PyObject* command = NULL; + PyObject* docs = NULL; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; - PyObject* options_obj; + PyObject* 
options_obj = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); From b111cbf5d5dab906a94d2c4b2a209cfde2971a94 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 3 Oct 2024 15:18:33 -0400 Subject: [PATCH 10/10] PYTHON-4636 - Avoid blocking I/O calls in async code paths (#1870) Co-authored-by: Shane Harvey --- pymongo/asynchronous/network.py | 81 +---------- pymongo/network_layer.py | 230 +++++++++++++++++++++++++++++-- pymongo/pyopenssl_context.py | 13 +- pymongo/synchronous/network.py | 77 +---------- test/asynchronous/test_client.py | 6 +- test/test_client.py | 6 +- tools/synchro.py | 1 + 7 files changed, 248 insertions(+), 166 deletions(-) diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 44a63a2fc3..d17aead120 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -15,11 +15,8 @@ """Internal network layer helper methods.""" from __future__ import annotations -import asyncio import datetime -import errno import logging -import socket import time from typing import ( TYPE_CHECKING, @@ -40,19 +37,16 @@ NotPrimaryError, OperationFailure, ProtocolError, - _OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, _UNPACK_HEADER, - BLOCKING_IO_ERRORS, + async_receive_data, async_sendall, ) -from pymongo.socket_checker import _errno_from_exception if TYPE_CHECKING: from bson import CodecOptions @@ -318,9 +312,7 @@ async def receive_message( else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER( - await _receive_data_on_socket(conn, 16, deadline) - ) + length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline)) # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: @@ -336,11 +328,11 @@ async def receive_message( ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - await _receive_data_on_socket(conn, 9, deadline) + await async_receive_data(conn, 9, deadline) ) - data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id) + data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id) else: - data = await _receive_data_on_socket(conn, length - 16, deadline) + data = await async_receive_data(conn, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -349,66 +341,3 @@ async def receive_message( f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" ) from None return unpack_reply(data) - - -async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. 
This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - await asyncio.sleep(0) - - - async def _receive_data_on_socket( - conn: AsyncConnection, length: int, deadline: Optional[float] - ) -> memoryview: - buf = bytearray(length) - mv = memoryview(buf) - bytes_read = 0 - while bytes_read < length: - try: - await wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. - if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: - continue - raise - if chunk_length == 0: - raise OSError("connection closed") - - bytes_read += chunk_length - - return mv diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 82a6228acc..4b57620d83 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -16,15 +16,21 @@ from __future__ import annotations import asyncio +import errno import socket import struct import sys +import time from asyncio import AbstractEventLoop, Future from typing import ( + TYPE_CHECKING, + Optional, Union, ) -from pymongo import ssl_support +from pymongo import _csot, ssl_support +from pymongo.errors import _OperationCancelled +from pymongo.socket_checker import _errno_from_exception try: from ssl import SSLError, SSLSocket @@ -51,6 +57,10 @@ BLOCKING_IO_WRITE_ERROR, ) +if TYPE_CHECKING: + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.synchronous.pool import Connection + _UNPACK_HEADER = struct.Struct("<iiii").unpack _UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack _POLL_TIMEOUT = 0.5 @@ -80,12 +90,9 @@ async def _async_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop ) -> None: view = memoryview(buf) - fd = sock.fileno() sent = 0 def _is_ready(fut: Future) -> None: - loop.remove_writer(fd) - loop.remove_reader(fd) if fut.done(): return fut.set_result(None) @@ -101,33 +108,240 @@ def _is_ready(fut: Future) -> None: if isinstance(exc, BLOCKING_IO_READ_ERROR): fut = loop.create_future() loop.add_reader(fd, _is_ready, fut) - await fut + try: + await fut + finally: + loop.remove_reader(fd) if isinstance(exc, BLOCKING_IO_WRITE_ERROR): fut = loop.create_future() loop.add_writer(fd, _is_ready, fut) - await fut + try: + await fut + finally: + loop.remove_writer(fd) if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): fut = loop.create_future() loop.add_reader(fd, _is_ready, fut) + try: + loop.add_writer(fd, _is_ready, fut) + await fut + finally: + loop.remove_reader(fd) + loop.remove_writer(fd) + + async def _async_receive_ssl( + conn: _sslConn, length: int, loop: AbstractEventLoop + ) -> memoryview: + mv = memoryview(bytearray(length)) + total_read = 0 + + def _is_ready(fut: Future) -> None: + if fut.done(): + return + fut.set_result(None) + + while total_read < length: + try: + read = conn.recv_into(mv[total_read:]) + if read == 0: + raise OSError("connection closed") + total_read += read + except BLOCKING_IO_ERRORS as exc: + fd = conn.fileno() + # Check for closed socket.
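The try/finally blocks threaded through these hunks are the substance of the fix: previously the reader/writer callbacks were deregistered inside _is_ready itself, so a task cancelled while awaiting the future leaked its registration. The general shape of the corrected pattern, as a self-contained sketch (not PyMongo's API):

    import asyncio
    import socket

    async def wait_readable(sock: socket.socket) -> None:
        # Register interest in readability, await a future, and always
        # deregister in finally so a cancelled waiter cannot leave a stale
        # callback behind on a possibly reused file descriptor.
        loop = asyncio.get_running_loop()
        fut: asyncio.Future[None] = loop.create_future()

        def _on_readable() -> None:
            if not fut.done():
                fut.set_result(None)

        loop.add_reader(sock.fileno(), _on_readable)
        try:
            await fut
        finally:
            loop.remove_reader(sock.fileno())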
+ if fd == -1: + raise SSLError("Underlying socket has been closed") from None + if isinstance(exc, BLOCKING_IO_READ_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + try: + await fut + finally: + loop.remove_reader(fd) + if isinstance(exc, BLOCKING_IO_WRITE_ERROR): + fut = loop.create_future() loop.add_writer(fd, _is_ready, fut) - await fut + try: + await fut + finally: + loop.remove_writer(fd) + if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + try: + loop.add_writer(fd, _is_ready, fut) + await fut + finally: + loop.remove_reader(fd) + loop.remove_writer(fd) + return mv + else: # The default Windows asyncio event loop does not support loop.add_reader/add_writer: # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support + # Note: In PYTHON-4493 we plan to replace this code with asyncio streams. async def _async_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop ) -> None: view = memoryview(buf) total_length = len(buf) total_sent = 0 + # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success + # down to 1ms. + backoff = 0.001 while total_sent < total_length: try: sent = sock.send(view[total_sent:]) except BLOCKING_IO_ERRORS: - await asyncio.sleep(0.5) + await asyncio.sleep(backoff) sent = 0 + if sent > 0: + backoff = max(backoff / 2, 0.001) + else: + backoff = min(backoff * 2, 0.512) total_sent += sent + async def _async_receive_ssl( + conn: _sslConn, length: int, dummy: AbstractEventLoop + ) -> memoryview: + mv = memoryview(bytearray(length)) + total_read = 0 + # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success + # down to 1ms. + backoff = 0.001 + while total_read < length: + try: + read = conn.recv_into(mv[total_read:]) + if read == 0: + raise OSError("connection closed") + except BLOCKING_IO_ERRORS: + await asyncio.sleep(backoff) + read = 0 + if read > 0: + backoff = max(backoff / 2, 0.001) + else: + backoff = min(backoff * 2, 0.512) + total_read += read + return mv + def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: sock.sendall(buf) + + +async def _poll_cancellation(conn: AsyncConnection) -> None: + while True: + if conn.cancel_context.cancelled: + return + + await asyncio.sleep(_POLL_TIMEOUT) + + +async def async_receive_data( + conn: AsyncConnection, length: int, deadline: Optional[float] +) -> memoryview: + sock = conn.conn + sock_timeout = sock.gettimeout() + timeout: Optional[Union[float, int]] + if deadline: + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. 
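On Windows, where the default event loop does not support add_reader/add_writer for these sockets, the patch replaces the old fixed 500ms sleep with the adaptive backoff seen above. Distilled into a sketch (send_chunk is a hypothetical non-blocking send returning 0 when the socket would block):

    import asyncio

    async def sendall_with_backoff(send_chunk, data: bytes) -> None:
        total, backoff = 0, 0.001
        while total < len(data):
            sent = send_chunk(data[total:])
            if sent > 0:
                backoff = max(backoff / 2, 0.001)   # reward progress
            else:
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, 0.512)   # back off while blocked
            total += sent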
+ timeout = max(deadline - time.monotonic(), 0) + else: + timeout = sock_timeout + + sock.settimeout(0.0) + loop = asyncio.get_event_loop() + cancellation_task = asyncio.create_task(_poll_cancellation(conn)) + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] + else: + read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] + tasks = [read_task, cancellation_task] + done, pending = await asyncio.wait( + tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + await asyncio.wait(pending) + if len(done) == 0: + raise socket.timeout("timed out") + if read_task in done: + return read_task.result() + raise _OperationCancelled("operation cancelled") + finally: + sock.settimeout(sock_timeout) + + +async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: + mv = memoryview(bytearray(length)) + bytes_read = 0 + while bytes_read < length: + chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:]) + if chunk_length == 0: + raise OSError("connection closed") + bytes_read += chunk_length + return mv + + +# Sync version: +def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: + """Block until at least one byte is read, or a timeout, or a cancel.""" + sock = conn.conn + timed_out = False + # Check if the connection's socket has been manually closed + if sock.fileno() == -1: + return + while True: + # SSLSocket can have buffered data which won't be caught by select. + if hasattr(sock, "pending") and sock.pending() > 0: + readable = True + else: + # Wait up to 500ms for the socket to become readable and then + # check for cancellation. + if deadline: + remaining = deadline - time.monotonic() + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + if remaining <= 0: + timed_out = True + timeout = max(min(remaining, _POLL_TIMEOUT), 0) + else: + timeout = _POLL_TIMEOUT + readable = conn.socket_checker.select(sock, read=True, timeout=timeout) + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") + if readable: + return + if timed_out: + raise socket.timeout("timed out") + + +def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: + buf = bytearray(length) + mv = memoryview(buf) + bytes_read = 0 + while bytes_read < length: + try: + wait_for_read(conn, deadline) + # CSOT: Update timeout. When the timeout has expired perform one + # final non-blocking recv. This helps avoid spurious timeouts when + # the response is actually already buffered on the client. 
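The new async_receive_data above replaces the old poll loop by racing two tasks with asyncio.wait: one reading the reply, one polling for cancellation. The pattern in isolation, as a sketch:

    import asyncio

    async def race(read_coro, cancel_coro, timeout: float):
        # Whichever task finishes first wins; the loser is cancelled and
        # awaited so nothing lingers past this call.
        read_task = asyncio.create_task(read_coro)
        cancel_task = asyncio.create_task(cancel_coro)
        done, pending = await asyncio.wait(
            {read_task, cancel_task},
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.wait(pending)
        if not done:
            raise TimeoutError("timed out")
        if read_task in done:
            return read_task.result()
        raise RuntimeError("operation cancelled")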
+ if _csot.get_timeout() and deadline is not None: + conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) + chunk_length = conn.conn.recv_into(mv[bytes_read:]) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + except OSError as exc: + if _errno_from_exception(exc) == errno.EINTR: + continue + raise + if chunk_length == 0: + raise OSError("connection closed") + + bytes_read += chunk_length + + return mv diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 4f6f6f4a89..50d8680a74 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -105,13 +105,19 @@ def _ragged_eof(exc: BaseException) -> bool: # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets class _sslConn(_SSL.Connection): def __init__( - self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool + self, + ctx: _SSL.Context, + sock: Optional[_socket.socket], + suppress_ragged_eofs: bool, + is_async: bool = False, ): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs super().__init__(ctx, sock) + self._is_async = is_async def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: + is_async = kwargs.pop("allow_async", True) and self._is_async timeout = self.gettimeout() if timeout: start = _time.monotonic() @@ -119,6 +125,8 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: try: return call(*args, **kwargs) except BLOCKING_IO_ERRORS as exc: + if is_async: + raise exc # Check for closed socket. if self.fileno() == -1: if timeout and _time.monotonic() - start > timeout: @@ -139,6 +147,7 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: continue def do_handshake(self, *args: Any, **kwargs: Any) -> None: + kwargs["allow_async"] = False return self._call(super().do_handshake, *args, **kwargs) def recv(self, *args: Any, **kwargs: Any) -> bytes: @@ -381,7 +390,7 @@ async def a_wrap_socket( """Wrap an existing Python socket connection and return a TLS socket object. """ - ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs) + ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True) loop = asyncio.get_running_loop() if session: ssl_conn.set_session(session) diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index c1978087a9..7206dca735 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,9 +16,7 @@ from __future__ import annotations import datetime -import errno import logging -import socket import time from typing import ( TYPE_CHECKING, @@ -39,19 +37,16 @@ NotPrimaryError, OperationFailure, ProtocolError, - _OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, _UNPACK_HEADER, - BLOCKING_IO_ERRORS, + receive_data, sendall, ) -from pymongo.socket_checker import _errno_from_exception if TYPE_CHECKING: from bson import CodecOptions @@ -317,7 +312,7 @@ def receive_message( else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline)) + length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) # No request_id for exhaust cursor "getMore". 
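For orientation while reading receive_message: the first 16 bytes it unpacks are the standard MongoDB wire header, four little-endian int32s. A worked example with hypothetical values:

    import struct

    # messageLength, requestID, responseTo, opCode -- format "<iiii".
    header = struct.pack("<iiii", 16 + 10, 42, 0, 2013)  # 2013 = OP_MSG
    length, request_id, response_to, op_code = struct.unpack("<iiii", header)
    assert (length, op_code) == (26, 2013)
    # receive_message then reads length - 16 more bytes, or, when the op
    # code is 2012 (OP_COMPRESSED), a 9-byte compression header plus the
    # compressed payload (length - 25 bytes).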
if request_id is not None: if request_id != response_to: @@ -332,12 +327,10 @@ def receive_message( f"message size ({max_message_size!r})" ) if op_code == 2012: - op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - _receive_data_on_socket(conn, 9, deadline) - ) - data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id) + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) + data = decompress(receive_data(conn, length - 25, deadline), compressor_id) else: - data = _receive_data_on_socket(conn, length - 16, deadline) + data = receive_data(conn, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -346,63 +339,3 @@ def receive_message( f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" ) from None return unpack_reply(data) - - -def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - - -def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: - buf = bytearray(length) - mv = memoryview(buf) - bytes_read = 0 - while bytes_read < length: - try: - wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. 
- if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: - continue - raise - if chunk_length == 0: - raise OSError("connection closed") - - bytes_read += chunk_length - - return mv diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 5c06331790..2052d1cd7f 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1713,6 +1713,7 @@ def compression_settings(client): # No error await client.pymongo_test.test.find_one() + @async_client_context.require_sync async def test_reset_during_update_pool(self): client = await self.async_rs_or_single_client(minPoolSize=10) await client.admin.command("ping") @@ -1737,10 +1738,7 @@ async def _run(self): await asyncio.sleep(0.001) def run(self): - if _IS_SYNC: - self._run() - else: - asyncio.run(self._run()) + self._run() t = ResetPoolThread(pool) t.start() diff --git a/test/test_client.py b/test/test_client.py index c88a8fd9b4..936c38b8c6 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1671,6 +1671,7 @@ def compression_settings(client): # No error client.pymongo_test.test.find_one() + @client_context.require_sync def test_reset_during_update_pool(self): client = self.rs_or_single_client(minPoolSize=10) client.admin.command("ping") @@ -1695,10 +1696,7 @@ def _run(self): time.sleep(0.001) def run(self): - if _IS_SYNC: - self._run() - else: - asyncio.run(self._run()) + self._run() t = ResetPoolThread(pool) t.start() diff --git a/tools/synchro.py b/tools/synchro.py index d8ec9ae46f..585fc5fefd 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -43,6 +43,7 @@ "AsyncConnection": "Connection", "async_command": "command", "async_receive_message": "receive_message", + "async_receive_data": "receive_data", "async_sendall": "sendall", "asynchronous": "synchronous", "Asynchronous": "Synchronous",
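Circling back to the first patch in the series: with the explicit "<" (little-endian) prefixes in place, BSON vector encoding is byte-identical regardless of host endianness. A round-trip sketch:

    from bson.binary import Binary, BinaryVectorDtype

    # Pack three float32 values into a BSON binary (subtype 9) vector and
    # decode them again; the "<" struct prefixes keep the bytes identical
    # on little- and big-endian hosts such as zSeries.
    buf = Binary.from_vector([1.0, 2.0, 3.0], BinaryVectorDtype.FLOAT32)
    vec = buf.as_vector()
    print(vec.dtype, vec.padding, vec.data)  # FLOAT32, 0, [1.0, 2.0, 3.0]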