From 0f84ad6ed98900ba43f0942c9a06ce0f0b073559 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Fri, 20 Sep 2024 10:06:03 -0700 Subject: [PATCH 01/13] PYTHON-4769 Avoid pytest collection overhead when running perf benchmarks (#1869) --- .evergreen/run-tests.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 66df6b26ca..8d7a9f082a 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -224,6 +224,9 @@ if [ -n "$PERF_TEST" ]; then python -m pip install simplejson start_time=$(date +%s) TEST_SUITES="perf" + # PYTHON-4769 Run perf_test.py directly otherwise pytest's test collection negatively + # affects the benchmark results. + TEST_ARGS="test/performance/perf_test.py $TEST_ARGS" fi echo "Running $AUTH tests over $SSL with python $(which python)" From e03f8f24f2387882fcaa5d3099d2cef7ae100816 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 20 Sep 2024 16:50:59 -0500 Subject: [PATCH 02/13] PYTHON-4781 Handle errors on Async PyMongo import (#1873) --- pymongo/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 7ee177bdae..8116788bc3 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -88,7 +88,6 @@ from pymongo import _csot from pymongo._version import __version__, get_version_string, version_tuple -from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION from pymongo.cursor import CursorType from pymongo.operations import ( @@ -105,6 +104,14 @@ from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +try: + from pymongo.asynchronous.mongo_client import AsyncMongoClient +except Exception as e: + # PYTHON-4781: Importing asyncio can fail on Windows. 
+ import warnings as _warnings + + _warnings.warn(f"Failed to import Async PyMongo: {e!r}", ImportWarning, stacklevel=2) + version = __version__ """Current version of PyMongo.""" From 7742b7f24fd4a16f22d620471cfca5f88cf0b628 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 30 Sep 2024 14:14:12 -0400 Subject: [PATCH 03/13] PYTHON-4797 - Convert test.test_raw_bson to async (#1882) --- test/asynchronous/test_raw_bson.py | 219 +++++++++++++++++++++++++++++ test/test_raw_bson.py | 5 +- tools/synchro.py | 5 +- 3 files changed, 225 insertions(+), 4 deletions(-) create mode 100644 test/asynchronous/test_raw_bson.py diff --git a/test/asynchronous/test_raw_bson.py b/test/asynchronous/test_raw_bson.py new file mode 100644 index 0000000000..70832ea668 --- /dev/null +++ b/test/asynchronous/test_raw_bson.py @@ -0,0 +1,219 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import datetime +import sys +import uuid + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest + +from bson import Code, DBRef, decode, encode +from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation +from bson.codec_options import CodecOptions +from bson.errors import InvalidBSON +from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument +from bson.son import SON + +_IS_SYNC = False + + +class TestRawBSONDocument(AsyncIntegrationTest): + # {'_id': ObjectId('556df68b6e32ab21a95e0785'), + # 'name': 'Sherlock', + # 'addresses': [{'street': 'Baker Street'}]} + bson_string = ( + b"Z\x00\x00\x00\x07_id\x00Um\xf6\x8bn2\xab!\xa9^\x07\x85\x02name\x00\t" + b"\x00\x00\x00Sherlock\x00\x04addresses\x00&\x00\x00\x00\x030\x00\x1e" + b"\x00\x00\x00\x02street\x00\r\x00\x00\x00Baker Street\x00\x00\x00\x00" + ) + document = RawBSONDocument(bson_string) + + async def asyncTearDown(self): + if async_client_context.connected: + await self.client.pymongo_test.test_raw.drop() + + def test_decode(self): + self.assertEqual("Sherlock", self.document["name"]) + first_address = self.document["addresses"][0] + self.assertIsInstance(first_address, RawBSONDocument) + self.assertEqual("Baker Street", first_address["street"]) + + def test_raw(self): + self.assertEqual(self.bson_string, self.document.raw) + + def test_empty_doc(self): + doc = RawBSONDocument(encode({})) + with self.assertRaises(KeyError): + doc["does-not-exist"] + + def test_invalid_bson_sequence(self): + bson_byte_sequence = encode({"a": 1}) + encode({}) + with self.assertRaisesRegex(InvalidBSON, "invalid object length"): + RawBSONDocument(bson_byte_sequence) + + def test_invalid_bson_eoo(self): + invalid_bson_eoo = encode({"a": 1})[:-1] + b"\x01" + with self.assertRaisesRegex(InvalidBSON, "bad eoo"): + RawBSONDocument(invalid_bson_eoo) + + @async_client_context.require_connection + async def test_round_trip(self): + db 
= self.client.get_database( + "pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument) + ) + await db.test_raw.insert_one(self.document) + result = await db.test_raw.find_one(self.document["_id"]) + assert result is not None + self.assertIsInstance(result, RawBSONDocument) + self.assertEqual(dict(self.document.items()), dict(result.items())) + + @async_client_context.require_connection + async def test_round_trip_raw_uuid(self): + coll = self.client.get_database("pymongo_test").test_raw + uid = uuid.uuid4() + doc = {"_id": 1, "bin4": Binary(uid.bytes, 4), "bin3": Binary(uid.bytes, 3)} + raw = RawBSONDocument(encode(doc)) + await coll.insert_one(raw) + self.assertEqual(await coll.find_one(), doc) + uuid_coll = coll.with_options( + codec_options=coll.codec_options.with_options( + uuid_representation=UuidRepresentation.STANDARD + ) + ) + self.assertEqual( + await uuid_coll.find_one(), {"_id": 1, "bin4": uid, "bin3": Binary(uid.bytes, 3)} + ) + + # Test that the raw bytes haven't changed. + raw_coll = coll.with_options(codec_options=DEFAULT_RAW_BSON_OPTIONS) + self.assertEqual(await raw_coll.find_one(), raw) + + def test_with_codec_options(self): + # {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), + # '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')} + # encoded with JAVA_LEGACY uuid representation. 
+ bson_string = ( + b"-\x00\x00\x00\x05_id\x00\x10\x00\x00\x00\x03eI_\x97\x8f\xabo\x02" + b"\xff`L\x87\xad\x85\xbf\x9f\tdate\x00\x8a\xd6\xb9\xbaM" + b"\x01\x00\x00\x00" + ) + document = RawBSONDocument( + bson_string, + codec_options=CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ), + ) + + self.assertEqual(uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), document["_id"]) + + @async_client_context.require_connection + async def test_round_trip_codec_options(self): + doc = { + "date": datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), + "_id": uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), + } + db = self.client.pymongo_test + coll = db.get_collection( + "test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY) + ) + await coll.insert_one(doc) + raw_java_legacy = CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ) + coll = db.get_collection("test_raw", codec_options=raw_java_legacy) + self.assertEqual( + RawBSONDocument(encode(doc, codec_options=raw_java_legacy)), await coll.find_one() + ) + + @async_client_context.require_connection + async def test_raw_bson_document_embedded(self): + doc = {"embedded": self.document} + db = self.client.pymongo_test + await db.test_raw.insert_one(doc) + result = await db.test_raw.find_one() + assert result is not None + self.assertEqual(decode(self.document.raw), result["embedded"]) + + # Make sure that CodecOptions are preserved. + # {'embedded': [ + # {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), + # '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')} + # ]} + # encoded with JAVA_LEGACY uuid representation. 
+ bson_string = ( + b"D\x00\x00\x00\x04embedded\x005\x00\x00\x00\x030\x00-\x00\x00\x00" + b"\tdate\x00\x8a\xd6\xb9\xbaM\x01\x00\x00\x05_id\x00\x10\x00\x00" + b"\x00\x03eI_\x97\x8f\xabo\x02\xff`L\x87\xad\x85\xbf\x9f\x00\x00" + b"\x00" + ) + rbd = RawBSONDocument( + bson_string, + codec_options=CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ), + ) + + await db.test_raw.drop() + await db.test_raw.insert_one(rbd) + result = await db.get_collection( + "test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY) + ).find_one() + assert result is not None + self.assertEqual(rbd["embedded"][0]["_id"], result["embedded"][0]["_id"]) + + @async_client_context.require_connection + async def test_write_response_raw_bson(self): + coll = self.client.get_database( + "pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument) + ).test_raw + + # No Exceptions raised while handling write response. + await coll.insert_one(self.document) + await coll.delete_one(self.document) + await coll.insert_many([self.document]) + await coll.delete_many(self.document) + await coll.update_one(self.document, {"$set": {"a": "b"}}, upsert=True) + await coll.update_many(self.document, {"$set": {"b": "c"}}) + + def test_preserve_key_ordering(self): + keyvaluepairs = [ + ("a", 1), + ("b", 2), + ("c", 3), + ] + rawdoc = RawBSONDocument(encode(SON(keyvaluepairs))) + + for rkey, elt in zip(rawdoc, keyvaluepairs): + self.assertEqual(rkey, elt[0]) + + def test_contains_code_with_scope(self): + doc = RawBSONDocument(encode({"value": Code("x=1", scope={})})) + + self.assertEqual(decode(encode(doc)), {"value": Code("x=1", {})}) + self.assertEqual(doc["value"].scope, RawBSONDocument(encode({}))) + + def test_contains_dbref(self): + doc = RawBSONDocument(encode({"value": DBRef("test", "id")})) + raw = {"$ref": "test", "$id": "id"} + raw_encoded = encode(decode(encode(raw))) + + self.assertEqual(decode(encode(doc)), {"value": DBRef("test", "id")}) + 
self.assertEqual(doc["value"].raw, raw_encoded) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 11bc80dd9f..4d9a3ceb05 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -19,8 +19,7 @@ sys.path[0:0] = [""] -from test import client_context, unittest -from test.test_client import IntegrationTest +from test import IntegrationTest, client_context, unittest from bson import Code, DBRef, decode, encode from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation @@ -29,6 +28,8 @@ from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument from bson.son import SON +_IS_SYNC = True + class TestRawBSONDocument(IntegrationTest): # {'_id': ObjectId('556df68b6e32ab21a95e0785'), diff --git a/tools/synchro.py b/tools/synchro.py index 59d6e653e5..0eca24b2cf 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -171,16 +171,17 @@ "test_change_stream.py", "test_client.py", "test_client_bulk_write.py", + "test_client_context.py", "test_collection.py", "test_cursor.py", "test_database.py", "test_encryption.py", "test_grid_file.py", "test_logger.py", + "test_monitoring.py", + "test_raw_bson.py", "test_session.py", "test_transactions.py", - "test_client_context.py", - "test_monitoring.py", ] sync_test_files = [ From 1e395de9c51aab501ed14bc994e88d96b773961a Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:17:57 -0700 Subject: [PATCH 04/13] PYTHON-4737 Migrate test_binary.py to async (#1863) --- test/asynchronous/test_client.py | 75 ++++++++++++++- test/test_binary.py | 156 ++++++++----------------------- test/test_client.py | 75 ++++++++++++++- 3 files changed, 188 insertions(+), 118 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index f610f32779..1926ad74d2 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -17,6 +17,7 @@ import _thread as thread 
import asyncio +import base64 import contextlib import copy import datetime @@ -31,13 +32,15 @@ import sys import threading import time -from typing import Iterable, Type, no_type_check +import uuid +from typing import Any, Iterable, Type, no_type_check from unittest import mock from unittest.mock import patch import pytest import pytest_asyncio +from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation from pymongo.operations import _Op sys.path[0:0] = [""] @@ -57,6 +60,7 @@ unittest, ) from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.test_binary import BinaryData from test.utils import ( NTHREADS, CMAPListener, @@ -2020,6 +2024,75 @@ def test_dict_hints_sort(self): async def test_dict_hints_create_index(self): await self.db.t.create_index({"x": pymongo.ASCENDING}) + async def test_legacy_java_uuid_roundtrip(self): + data = BinaryData.java_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) + + await async_client_context.client.pymongo_test.drop_collection("java_uuid") + db = async_client_context.client.pymongo_test + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) + + await coll.insert_many(docs) + self.assertEqual(5, await coll.count_documents({})) + async for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + async for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + await async_client_context.client.pymongo_test.drop_collection("java_uuid") + + async def test_legacy_csharp_uuid_roundtrip(self): + data = BinaryData.csharp_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) + + await async_client_context.client.pymongo_test.drop_collection("csharp_uuid") + db = async_client_context.client.pymongo_test + coll = db.get_collection("csharp_uuid", 
CodecOptions(uuid_representation=CSHARP_LEGACY)) + + await coll.insert_many(docs) + self.assertEqual(5, await coll.count_documents({})) + async for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + async for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + await async_client_context.client.pymongo_test.drop_collection("csharp_uuid") + + async def test_uri_to_uuid(self): + uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" + client = await self.async_single_client(uri, connect=False) + self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) + + async def test_uuid_queries(self): + db = async_client_context.client.pymongo_test + coll = db.test + await coll.drop() + + uu = uuid.uuid4() + await coll.insert_one({"uuid": Binary(uu.bytes, 3)}) + self.assertEqual(1, await coll.count_documents({})) + + # Test regular UUID queries (using subtype 4). + coll = db.get_collection( + "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + ) + self.assertEqual(0, await coll.count_documents({"uuid": uu})) + await coll.insert_one({"uuid": uu}) + self.assertEqual(2, await coll.count_documents({})) + docs = await coll.find({"uuid": uu}).to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(uu, docs[0]["uuid"]) + + # Test both. 
+ uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) + predicate = {"uuid": {"$in": [uu, uu_legacy]}} + self.assertEqual(2, await coll.count_documents(predicate)) + docs = await coll.find(predicate).to_list() + self.assertEqual(2, len(docs)) + await coll.drop() + class TestExhaustCursor(AsyncIntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" diff --git a/test/test_binary.py b/test/test_binary.py index 93f6d08315..567c5ae92f 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -34,53 +34,49 @@ from bson.codec_options import CodecOptions from bson.son import SON from pymongo.common import validate_uuid_representation -from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern -class TestBinary(unittest.TestCase): - csharp_data: bytes - java_data: bytes +class BinaryData: + # Generated by the Java driver + from_java = ( + b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu" + b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND" + b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+" + b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1" + b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR" + b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA" + b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z" + b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf" + b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx" + b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My" + b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB" + b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp" + b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc" + b"0MQAA" + ) + java_data = base64.b64decode(from_java) + + # Generated by the .net driver + from_csharp = ( + b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl" + b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2" + 
b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V" + b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl" + b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A" + b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z" + b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU" + b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn" + b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA" + b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT" + b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP" + b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00" + b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=" + ) + csharp_data = base64.b64decode(from_csharp) - @classmethod - def setUpClass(cls): - # Generated by the Java driver - from_java = ( - b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu" - b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND" - b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+" - b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1" - b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR" - b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA" - b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z" - b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf" - b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx" - b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My" - b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB" - b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp" - b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc" - b"0MQAA" - ) - cls.java_data = base64.b64decode(from_java) - - # Generated by the .net driver - from_csharp = ( - b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl" - b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2" - b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V" - b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl" - b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A" - 
b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z" - b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU" - b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn" - b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA" - b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT" - b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP" - b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00" - b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=" - ) - cls.csharp_data = base64.b64decode(from_csharp) +class TestBinary(unittest.TestCase): def test_binary(self): a_string = "hello world" a_binary = Binary(b"hello world") @@ -159,7 +155,7 @@ def test_uuid_subtype_4(self): def test_legacy_java_uuid(self): # Test decoding - data = self.java_data + data = BinaryData.java_data docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY)) for d in docs: self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) @@ -197,27 +193,8 @@ def test_legacy_java_uuid(self): ) self.assertEqual(data, encoded) - @client_context.require_connection - def test_legacy_java_uuid_roundtrip(self): - data = self.java_data - docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) - - client_context.client.pymongo_test.drop_collection("java_uuid") - db = client_context.client.pymongo_test - coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) - - coll.insert_many(docs) - self.assertEqual(5, coll.count_documents({})) - for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) - - coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - client_context.client.pymongo_test.drop_collection("java_uuid") - def test_legacy_csharp_uuid(self): - data = self.csharp_data + data = BinaryData.csharp_data # Test decoding docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY)) 
@@ -257,59 +234,6 @@ def test_legacy_csharp_uuid(self): ) self.assertEqual(data, encoded) - @client_context.require_connection - def test_legacy_csharp_uuid_roundtrip(self): - data = self.csharp_data - docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) - - client_context.client.pymongo_test.drop_collection("csharp_uuid") - db = client_context.client.pymongo_test - coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) - - coll.insert_many(docs) - self.assertEqual(5, coll.count_documents({})) - for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) - - coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - client_context.client.pymongo_test.drop_collection("csharp_uuid") - - def test_uri_to_uuid(self): - uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" - client = MongoClient(uri, connect=False) - self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) - - @client_context.require_connection - def test_uuid_queries(self): - db = client_context.client.pymongo_test - coll = db.test - coll.drop() - - uu = uuid.uuid4() - coll.insert_one({"uuid": Binary(uu.bytes, 3)}) - self.assertEqual(1, coll.count_documents({})) - - # Test regular UUID queries (using subtype 4). - coll = db.get_collection( - "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - ) - self.assertEqual(0, coll.count_documents({"uuid": uu})) - coll.insert_one({"uuid": uu}) - self.assertEqual(2, coll.count_documents({})) - docs = list(coll.find({"uuid": uu})) - self.assertEqual(1, len(docs)) - self.assertEqual(uu, docs[0]["uuid"]) - - # Test both. 
- uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) - predicate = {"uuid": {"$in": [uu, uu_legacy]}} - self.assertEqual(2, coll.count_documents(predicate)) - docs = list(coll.find(predicate)) - self.assertEqual(2, len(docs)) - coll.drop() - def test_pickle(self): b1 = Binary(b"123", 2) diff --git a/test/test_client.py b/test/test_client.py index bc45325f0b..2642a87fdf 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -17,6 +17,7 @@ import _thread as thread import asyncio +import base64 import contextlib import copy import datetime @@ -31,12 +32,14 @@ import sys import threading import time -from typing import Iterable, Type, no_type_check +import uuid +from typing import Any, Iterable, Type, no_type_check from unittest import mock from unittest.mock import patch import pytest +from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation from pymongo.operations import _Op sys.path[0:0] = [""] @@ -56,6 +59,7 @@ unittest, ) from test.pymongo_mocks import MockClient +from test.test_binary import BinaryData from test.utils import ( NTHREADS, CMAPListener, @@ -1978,6 +1982,75 @@ def test_dict_hints_sort(self): def test_dict_hints_create_index(self): self.db.t.create_index({"x": pymongo.ASCENDING}) + def test_legacy_java_uuid_roundtrip(self): + data = BinaryData.java_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) + + client_context.client.pymongo_test.drop_collection("java_uuid") + db = client_context.client.pymongo_test + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) + + coll.insert_many(docs) + self.assertEqual(5, coll.count_documents({})) + for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + 
client_context.client.pymongo_test.drop_collection("java_uuid") + + def test_legacy_csharp_uuid_roundtrip(self): + data = BinaryData.csharp_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) + + client_context.client.pymongo_test.drop_collection("csharp_uuid") + db = client_context.client.pymongo_test + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) + + coll.insert_many(docs) + self.assertEqual(5, coll.count_documents({})) + for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + client_context.client.pymongo_test.drop_collection("csharp_uuid") + + def test_uri_to_uuid(self): + uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" + client = self.single_client(uri, connect=False) + self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) + + def test_uuid_queries(self): + db = client_context.client.pymongo_test + coll = db.test + coll.drop() + + uu = uuid.uuid4() + coll.insert_one({"uuid": Binary(uu.bytes, 3)}) + self.assertEqual(1, coll.count_documents({})) + + # Test regular UUID queries (using subtype 4). + coll = db.get_collection( + "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + ) + self.assertEqual(0, coll.count_documents({"uuid": uu})) + coll.insert_one({"uuid": uu}) + self.assertEqual(2, coll.count_documents({})) + docs = coll.find({"uuid": uu}).to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(uu, docs[0]["uuid"]) + + # Test both. 
+ uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) + predicate = {"uuid": {"$in": [uu, uu_legacy]}} + self.assertEqual(2, coll.count_documents(predicate)) + docs = coll.find(predicate).to_list() + self.assertEqual(2, len(docs)) + coll.drop() + class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" From 3ef565fa43734dfef6bdbb7458b41a4d71451cb1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 30 Sep 2024 18:01:53 -0500 Subject: [PATCH 05/13] PYTHON-4796 Update type checkers and handle with_options typing (#1880) --- bson/__init__.py | 4 ++-- bson/decimal128.py | 2 +- bson/json_util.py | 2 +- bson/son.py | 2 +- hatch.toml | 5 +++-- pymongo/_csot.py | 3 +-- pymongo/asynchronous/collection.py | 23 +++++++++++++++++++- pymongo/asynchronous/database.py | 22 ++++++++++++++++++- pymongo/asynchronous/pool.py | 2 +- pymongo/common.py | 2 +- pymongo/compression_support.py | 2 +- pymongo/encryption_options.py | 2 +- pymongo/synchronous/collection.py | 23 +++++++++++++++++++- pymongo/synchronous/database.py | 22 ++++++++++++++++++- pymongo/synchronous/pool.py | 2 +- requirements/typing.txt | 7 ++++++ test/asynchronous/test_database.py | 2 +- test/test_database.py | 2 +- test/test_typing.py | 34 +++++++++++++++++++----------- tools/synchro.py | 2 +- 20 files changed, 132 insertions(+), 33 deletions(-) create mode 100644 requirements/typing.txt diff --git a/bson/__init__.py b/bson/__init__.py index e8ac7c4441..e866a99c8d 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -1324,7 +1324,7 @@ def decode_iter( elements = data[position : position + obj_size] position += obj_size - yield _bson_to_dict(elements, opts) # type:ignore[misc, type-var] + yield _bson_to_dict(elements, opts) # type:ignore[misc] @overload @@ -1370,7 +1370,7 @@ def decode_file_iter( raise InvalidBSON("cut off in middle of objsize") obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4 elements = size_data + file_obj.read(max(0, 
obj_size)) - yield _bson_to_dict(elements, opts) # type:ignore[type-var, arg-type, misc] + yield _bson_to_dict(elements, opts) # type:ignore[arg-type, misc] def is_valid(bson: bytes) -> bool: diff --git a/bson/decimal128.py b/bson/decimal128.py index 8581d5a3c8..016afb5eb8 100644 --- a/bson/decimal128.py +++ b/bson/decimal128.py @@ -223,7 +223,7 @@ def __init__(self, value: _VALUE_OPTIONS) -> None: "from list or tuple. Must have exactly 2 " "elements." ) - self.__high, self.__low = value # type: ignore + self.__high, self.__low = value else: raise TypeError(f"Cannot convert {value!r} to Decimal128") diff --git a/bson/json_util.py b/bson/json_util.py index 4269ba9858..6f34e4103d 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -324,7 +324,7 @@ def __new__( "JSONOptions.datetime_representation must be one of LEGACY, " "NUMBERLONG, or ISO8601 from DatetimeRepresentation." ) - self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) # type:ignore[arg-type] + self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL): raise ValueError( "JSONOptions.json_mode must be one of LEGACY, RELAXED, " diff --git a/bson/son.py b/bson/son.py index cf62717238..24275fce16 100644 --- a/bson/son.py +++ b/bson/son.py @@ -68,7 +68,7 @@ def __init__( self.update(kwargs) def __new__(cls: Type[SON[_Key, _Value]], *args: Any, **kwargs: Any) -> SON[_Key, _Value]: - instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var] + instance = super().__new__(cls, *args, **kwargs) instance.__keys = [] return instance diff --git a/hatch.toml b/hatch.toml index 8b1cf93e32..d5293a1d7f 100644 --- a/hatch.toml +++ b/hatch.toml @@ -13,8 +13,9 @@ features = ["docs","test"] test = "sphinx-build -E -b doctest doc ./doc/_build/doctest" [envs.typing] -features = ["encryption", "ocsp", "zstd", "aws"] -dependencies = ["mypy==1.2.0","pyright==1.1.290", "certifi", "typing_extensions"] 
+pre-install-commands = [ + "pip install -q -r requirements/typing.txt", +] [envs.typing.scripts] check-mypy = [ "mypy --install-types --non-interactive bson gridfs tools pymongo", diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 94328f9819..06c6b68ac9 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -75,14 +75,13 @@ def __init__(self, timeout: Optional[float]): self._timeout = timeout self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None - def __enter__(self) -> _TimeoutContext: + def __enter__(self) -> None: timeout_token = TIMEOUT.set(self._timeout) prev_deadline = DEADLINE.get() next_deadline = time.monotonic() + self._timeout if self._timeout else float("inf") deadline_token = DEADLINE.set(min(prev_deadline, next_deadline)) rtt_token = RTT.set(0.0) self._tokens = (timeout_token, deadline_token, rtt_token) - return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self._tokens: diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 1ec74aad02..5abc41a7e0 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -35,6 +35,7 @@ TypeVar, Union, cast, + overload, ) from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions @@ -332,13 +333,33 @@ def database(self) -> AsyncDatabase[_DocumentType]: """ return self._database + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncCollection[_DocumentType]: + ... + + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncCollection[_DocumentTypeArg]: + ... 
+ def with_options( self, codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> AsyncCollection[_DocumentType]: + ) -> AsyncCollection[_DocumentType] | AsyncCollection[_DocumentTypeArg]: """Get a clone of this collection changing the specified settings. >>> coll1.read_preference diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 06c0eca2c1..98a0a6ff3b 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -146,13 +146,33 @@ def name(self) -> str: """The name of this :class:`AsyncDatabase`.""" return self._name + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncDatabase[_DocumentType]: + ... + + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncDatabase[_DocumentTypeArg]: + ... + def with_options( self, codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> AsyncDatabase[_DocumentType]: + ) -> AsyncDatabase[_DocumentType] | AsyncDatabase[_DocumentTypeArg]: """Get a clone of this database changing the specified settings. 
>>> db1.read_preference diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a657042423..442d6c7ed6 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -913,7 +913,7 @@ async def _configured_socket( and not options.tls_allow_invalid_hostnames ): try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] except _CertificateError: ssl_sock.close() raise diff --git a/pymongo/common.py b/pymongo/common.py index a073eba577..fe8fdd8949 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -850,7 +850,7 @@ def get_normed_key(x: str) -> str: return x def get_setter_key(x: str) -> str: - return options.cased_key(x) # type: ignore[attr-defined] + return options.cased_key(x) else: validated_options = {} diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index 7123b90dfe..c71e4bddcf 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -26,7 +26,7 @@ def _have_snappy() -> bool: try: - import snappy # type:ignore[import] # noqa: F401 + import snappy # type:ignore[import-not-found] # noqa: F401 return True except ImportError: diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index df13026500..ee749e7ac1 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional try: - import pymongocrypt # type:ignore[import] # noqa: F401 + import pymongocrypt # type:ignore[import-untyped] # noqa: F401 # Check for pymongocrypt>=1.10. 
from pymongocrypt import synchronous as _ # noqa: F401 diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 7a41aef31f..15a1913eaa 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -34,6 +34,7 @@ TypeVar, Union, cast, + overload, ) from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions @@ -333,13 +334,33 @@ def database(self) -> Database[_DocumentType]: """ return self._database + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Collection[_DocumentType]: + ... + + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Collection[_DocumentTypeArg]: + ... + def with_options( self, codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> Collection[_DocumentType]: + ) -> Collection[_DocumentType] | Collection[_DocumentTypeArg]: """Get a clone of this collection changing the specified settings. >>> coll1.read_preference diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index c57a59e09a..a0bef55343 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -146,13 +146,33 @@ def name(self) -> str: """The name of this :class:`Database`.""" return self._name + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Database[_DocumentType]: + ... 
+ + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Database[_DocumentTypeArg]: + ... + def with_options( self, codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> Database[_DocumentType]: + ) -> Database[_DocumentType] | Database[_DocumentTypeArg]: """Get a clone of this database changing the specified settings. >>> db1.read_preference diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 94a1d10436..1b8b1f1ec9 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -909,7 +909,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. and not options.tls_allow_invalid_hostnames ): try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] except _CertificateError: ssl_sock.close() raise diff --git a/requirements/typing.txt b/requirements/typing.txt new file mode 100644 index 0000000000..1669e6bbc2 --- /dev/null +++ b/requirements/typing.txt @@ -0,0 +1,7 @@ +mypy==1.11.2 +pyright==1.1.382.post1 +typing_extensions +-r ./encryption.txt +-r ./ocsp.txt +-r ./zstd.txt +-r ./aws.txt diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index c5d62323df..61369c8542 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -711,7 +711,7 @@ def test_with_options(self): "write_concern": WriteConcern(w=1), "read_concern": ReadConcern(level="local"), } - db2 = db1.with_options(**newopts) # type: ignore[arg-type] + db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload] for opt in newopts: 
self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) diff --git a/test/test_database.py b/test/test_database.py index fe07f343c5..4973ed0134 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -702,7 +702,7 @@ def test_with_options(self): "write_concern": WriteConcern(w=1), "read_concern": ReadConcern(level="local"), } - db2 = db1.with_options(**newopts) # type: ignore[arg-type] + db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload] for opt in newopts: self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) diff --git a/test/test_typing.py b/test/test_typing.py index 6cfe40537b..441707616e 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -34,7 +34,7 @@ cast, ) -try: +if TYPE_CHECKING: from typing_extensions import NotRequired, TypedDict from bson import ObjectId @@ -49,16 +49,13 @@ class MovieWithId(TypedDict): year: int class ImplicitMovie(TypedDict): - _id: NotRequired[ObjectId] # pyright: ignore[reportGeneralTypeIssues] + _id: NotRequired[ObjectId] name: str year: int - -except ImportError: - Movie = dict # type:ignore[misc,assignment] - ImplicitMovie = dict # type: ignore[assignment,misc] - MovieWithId = dict # type: ignore[assignment,misc] - TypedDict = None - NotRequired = None # type: ignore[assignment] +else: + Movie = dict + ImplicitMovie = dict + NotRequired = None try: @@ -234,6 +231,19 @@ def execute_transaction(session): execute_transaction, read_preference=ReadPreference.PRIMARY ) + def test_with_options(self) -> None: + coll: Collection[Dict[str, Any]] = self.coll + coll.drop() + doc = {"name": "foo", "year": 1982, "other": 1} + coll.insert_one(doc) + + coll2 = coll.with_options(codec_options=CodecOptions(document_class=Movie)) + retrieved = coll2.find_one() + assert retrieved is not None + assert retrieved["name"] == "foo" + # We expect a type error here. 
+ assert retrieved["other"] == 1 # type:ignore[typeddict-item] + class TestDecode(unittest.TestCase): def test_bson_decode(self) -> None: @@ -426,7 +436,7 @@ def test_bulk_write_document_type_insertion(self): ) coll.bulk_write( [ - InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) + InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore ] # No error because it is in-line. ) @@ -443,7 +453,7 @@ def test_bulk_write_document_type_replacement(self): ) coll.bulk_write( [ - ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) + ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore ] # No error because it is in-line. ) @@ -566,7 +576,7 @@ def test_explicit_document_type(self) -> None: def test_typeddict_document_type(self) -> None: options: CodecOptions[Movie] = CodecOptions() # Suppress: Cannot instantiate type "Type[Movie]". - obj = options.document_class(name="a", year=1) # type: ignore[misc] + obj = options.document_class(name="a", year=1) assert obj["year"] == 1 assert obj["name"] == "a" diff --git a/tools/synchro.py b/tools/synchro.py index 0eca24b2cf..86506b7798 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -23,7 +23,7 @@ from os import listdir from pathlib import Path -from unasync import Rule, unasync_files # type: ignore[import] +from unasync import Rule, unasync_files # type: ignore[import-not-found] replacements = { "AsyncCollection": "Collection", From 083359f95f7ce1c202d54a0a16506ae4c5162b23 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. 
Clark" Date: Mon, 30 Sep 2024 19:09:57 -0400 Subject: [PATCH 06/13] PYTHON-1714 Add c extension use to client metadata (#1874) --- pymongo/__init__.py | 12 +----------- pymongo/common.py | 10 ++++++++++ pymongo/pool_options.py | 6 ++++++ test/asynchronous/test_client.py | 17 +++++++++++++---- test/test_client.py | 17 +++++++++++++---- tools/synchro.py | 1 + 6 files changed, 44 insertions(+), 19 deletions(-) diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 8116788bc3..6416f939e8 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -88,7 +88,7 @@ from pymongo import _csot from pymongo._version import __version__, get_version_string, version_tuple -from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION +from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c from pymongo.cursor import CursorType from pymongo.operations import ( DeleteMany, @@ -116,16 +116,6 @@ """Current version of PyMongo.""" -def has_c() -> bool: - """Is the C extension installed?""" - try: - from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401 - - return True - except ImportError: - return False - - def timeout(seconds: Optional[float]) -> ContextManager[None]: """**(Provisional)** Apply the given timeout for a block of operations. 
diff --git a/pymongo/common.py b/pymongo/common.py index fe8fdd8949..126d0ee46e 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -1060,3 +1060,13 @@ def update(self, other: Mapping[str, Any]) -> None: # type: ignore[override] def cased_key(self, key: str) -> Any: return self.__casedkeys[key.lower()] + + +def has_c() -> bool: + """Is the C extension installed?""" + try: + from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401 + + return True + except ImportError: + return False diff --git a/pymongo/pool_options.py b/pymongo/pool_options.py index 6ec97d7d1b..61486c91c6 100644 --- a/pymongo/pool_options.py +++ b/pymongo/pool_options.py @@ -33,6 +33,7 @@ MAX_POOL_SIZE, MIN_POOL_SIZE, WAIT_QUEUE_TIMEOUT, + has_c, ) if TYPE_CHECKING: @@ -363,6 +364,11 @@ def __init__( # }, # 'platform': 'CPython 3.8.0|MyPlatform' # } + if has_c(): + self.__metadata["driver"]["name"] = "{}|{}".format( + self.__metadata["driver"]["name"], + "c", + ) if not is_sync: self.__metadata["driver"]["name"] = "{}|{}".format( self.__metadata["driver"]["name"], diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 1926ad74d2..b6324d3bac 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -99,7 +99,7 @@ from pymongo.asynchronous.settings import TOPOLOGY_TYPE from pymongo.asynchronous.topology import _ErrorContext from pymongo.client_options import ClientOptions -from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT +from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c from pymongo.compression_support import _have_snappy, _have_zstd from pymongo.driver_info import DriverInfo from pymongo.errors import ( @@ -347,7 +347,10 @@ async def test_read_preference(self): async def test_metadata(self): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo|async" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|async" + else: + metadata["driver"]["name"] = 
"PyMongo|async" metadata["application"] = {"name": "foobar"} client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options @@ -370,7 +373,10 @@ async def test_metadata(self): with self.assertRaises(TypeError): self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. - metadata["driver"]["name"] = "PyMongo|async|FooDriver" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|async|FooDriver" + else: + metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) client = self.simple_client( "foo", @@ -1931,7 +1937,10 @@ def test_sigstop_sigcont(self): async def _test_handshake(self, env_vars, expected_env): with patch.dict("os.environ", env_vars): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo|async" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|async" + else: + metadata["driver"]["name"] = "PyMongo|async" if expected_env is not None: metadata["env"] = expected_env diff --git a/test/test_client.py b/test/test_client.py index 2642a87fdf..86b9f41ec9 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -87,7 +87,7 @@ from bson.tz_util import utc from pymongo import event_loggers, message, monitoring from pymongo.client_options import ClientOptions -from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT +from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c from pymongo.compression_support import _have_snappy, _have_zstd from pymongo.driver_info import DriverInfo from pymongo.errors import ( @@ -339,7 +339,10 @@ def test_read_preference(self): def test_metadata(self): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c" + else: + metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} client = 
self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options @@ -362,7 +365,10 @@ def test_metadata(self): with self.assertRaises(TypeError): self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. - metadata["driver"]["name"] = "PyMongo|FooDriver" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|FooDriver" + else: + metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) client = self.simple_client( "foo", @@ -1889,7 +1895,10 @@ def test_sigstop_sigcont(self): def _test_handshake(self, env_vars, expected_env): with patch.dict("os.environ", env_vars): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c" + else: + metadata["driver"]["name"] = "PyMongo" if expected_env is not None: metadata["env"] = expected_env diff --git a/tools/synchro.py b/tools/synchro.py index 86506b7798..6ce897a0b8 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -101,6 +101,7 @@ "default_async": "default", "aclose": "close", "PyMongo|async": "PyMongo", + "PyMongo|c|async": "PyMongo|c", "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", } From 821811e80d72d2ae822d11bac30e1c7a935208c2 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 30 Sep 2024 16:24:07 -0700 Subject: [PATCH 07/13] PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait (#1875) --- pymongo/asynchronous/pool.py | 7 +- pymongo/asynchronous/topology.py | 5 +- pymongo/lock.py | 147 +++++- pymongo/synchronous/pool.py | 9 +- pymongo/synchronous/topology.py | 7 +- test/asynchronous/test_client.py | 4 +- test/asynchronous/test_client_bulk_write.py | 5 +- test/asynchronous/test_cursor.py | 4 +- test/asynchronous/test_locks.py | 513 ++++++++++++++++++++ test/test_client.py | 4 +- test/test_client_bulk_write.py | 5 +- test/test_cursor.py | 4 +- 
test/test_server_selection_in_window.py | 13 +- tools/synchro.py | 14 +- 14 files changed, 693 insertions(+), 48 deletions(-) create mode 100644 test/asynchronous/test_locks.py diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 442d6c7ed6..a9f02d650a 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -992,7 +992,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _ALock(_create_lock()) + _lock = _create_lock() + self.lock = _ALock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1018,7 +1019,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self.size_cond = _ACondition(threading.Condition(_lock)) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1026,7 +1027,7 @@ def __init__( # The second portion of the wait queue. 
# Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self._max_connecting_cond = _ACondition(threading.Condition(_lock)) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 4e778cbc17..82af4257ba 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _ALock(_create_lock()) - self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _ALock(_lock) + self._condition = _ACondition(self._settings.condition_class(_lock)) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/pymongo/lock.py b/pymongo/lock.py index b05f6acffb..0cbfb4a57e 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -14,17 +14,20 @@ from __future__ import annotations import asyncio +import collections import os import threading import time import weakref -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") # References to instances of _create_lock _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet() +_T = TypeVar("_T") + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and @@ -43,7 +46,14 @@ def _release_locks() -> None: lock.release() +# Needed only for synchro.py compat. 
+def _Lock(lock: threading.Lock) -> threading.Lock: + return lock + + class _ALock: + __slots__ = ("_lock",) + def __init__(self, lock: threading.Lock) -> None: self._lock = lock @@ -81,9 +91,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: self.release() +def _safe_set_result(fut: asyncio.Future) -> None: + # Ensure the future hasn't been cancelled before calling set_result. + if not fut.done(): + fut.set_result(False) + + class _ACondition: + __slots__ = ("_condition", "_waiters") + def __init__(self, condition: threading.Condition) -> None: self._condition = condition + self._waiters: collections.deque = collections.deque() async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: if timeout > 0: @@ -99,30 +118,116 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: await asyncio.sleep(0) async def wait(self, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait(0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False - - async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait_for(predicate, 0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. 
+ """ + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._waiters.append((loop, fut)) + self.release() + try: + try: + try: + await asyncio.wait_for(fut, timeout) + return True + except asyncio.TimeoutError: + return False # Return false on timeout for sync pool compat. + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except asyncio.exceptions.CancelledError as e: + err = e + + self._waiters.remove((loop, fut)) + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self.notify(1) + raise + + async def wait_for(self, predicate: Callable[[], _T]) -> _T: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result def notify(self, n: int = 1) -> None: - self._condition.notify(n) + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. 
+ + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + idx = 0 + to_remove = [] + for loop, fut in self._waiters: + if idx >= n: + break + + if fut.done(): + continue + + try: + loop.call_soon_threadsafe(_safe_set_result, fut) + except RuntimeError: + # Loop was closed, ignore. + to_remove.append((loop, fut)) + continue + + idx += 1 + + for waiter in to_remove: + self._waiters.remove(waiter) def notify_all(self) -> None: - self._condition.notify_all() + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + def locked(self) -> bool: + """Only needed for tests in test_locks.""" + return self._condition._lock.locked() # type: ignore[attr-defined] def release(self) -> None: self._condition.release() diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1b8b1f1ec9..eb007a3471 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -62,7 +62,7 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -988,7 +988,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _create_lock() + _lock = _create_lock() + self.lock = _Lock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1014,7 +1015,7 @@ def __init__( # The first portion of the wait queue. 
# Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self.size_cond = threading.Condition(_lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1022,7 +1023,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self._max_connecting_cond = threading.Condition(_lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index e8070e30ab..a350c1702e 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -39,7 +39,7 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _create_lock() - self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _Lock(_lock) + self._condition = self._settings.condition_class(_lock) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index b6324d3bac..5c06331790 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2433,7 +2433,9 @@ async def test_reconnect(self): # But it can reconnect. 
c.revive_host("a:1") - await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + await (await c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(await c.address, ("a", 1)) async def _test_network_error(self, operation_callback): diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 3a17299453..80cfd30bde 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -30,6 +30,7 @@ ) from unittest.mock import patch +import pymongo from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( @@ -597,7 +598,9 @@ async def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - await client.admin.command("ping") # Init the client first. + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + await client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 33eaacee96..e79ad00641 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1414,7 +1414,7 @@ async def test_to_list_length(self): async def test_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) @@ -1456,7 +1456,7 @@ async def test_command_cursor_to_list_length(self): async def 
test_command_cursor_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py new file mode 100644 index 0000000000..e0e7f2fc8d --- /dev/null +++ b/test/asynchronous/test_locks.py @@ -0,0 +1,513 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for lock.py""" +from __future__ import annotations + +import asyncio +import sys +import threading +import unittest + +sys.path[0:0] = [""] + +from pymongo.lock import _ACondition + + +# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py +# Includes tests for: +# - https://github.com/python/cpython/issues/111693 +# - https://github.com/python/cpython/issues/112202 +class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_wait_cancel(self): + cond = _ACondition(threading.Condition(threading.Lock())) + 
await cond.acquire() + + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + await cond.acquire() + self.assertTrue(cond.locked()) + + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) + + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) + + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass + + self.assertTrue(cond.locked()) + + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False + + cond = _ACondition(threading.Condition(threading.Lock())) + + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() + + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting + + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + with self.assertRaises(RuntimeError): + await cond.wait() + + async def test_wait_for(self): + cond = _ACondition(threading.Condition(threading.Lock())) + presult = False + + def predicate(): + return presult + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + 
result.append(1) + cond.release() + return True + + t = asyncio.create_task(c1(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) + + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + async def test_wait_for_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) + + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) + + async def test_notify(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) + + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async 
with cond: + if await cond.wait(): + result.append(2) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _ACondition(threading.Condition(threading.Lock())) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) + + async def test_timeout_in_block(self): + condition = _ACondition(threading.Condition(threading.Lock())) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = _ACondition(threading.Condition(threading.Lock())) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. 
+        self.assertIs(err.exception, raised)
+
+    @unittest.skipIf(
+        sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
+    )
+    async def test_cancelled_error_re_aquire(self):
+        # Test that a cancelled error, received when re-acquiring the lock,
+        # will be re-raised un-modified.
+        wake = False
+        raised = None
+        cond = _ACondition(threading.Condition(threading.Lock()))
+
+        async def func():
+            nonlocal raised
+            async with cond:
+                with self.assertRaises(asyncio.CancelledError) as err:
+                    await cond.wait_for(lambda: wake)
+                raised = err.exception
+                raise raised
+
+        task = asyncio.create_task(func())
+        await asyncio.sleep(0)
+        # Task is waiting on the condition
+        await cond.acquire()
+        wake = True
+        cond.notify()
+        await asyncio.sleep(0)
+        # Task is now trying to re-acquire the lock, cancel it there.
+        task.cancel(msg="foo")  # type: ignore[call-arg]
+        cond.release()
+        with self.assertRaises(asyncio.CancelledError) as err:
+            await task
+        self.assertEqual(err.exception.args, ("foo",))
+        # We should have got the _same_ exception instance as the one
+        # originally raised.
+        self.assertIs(err.exception, raised)
+
+    @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
+    async def test_cancelled_wakeup(self):
+        # Test that a task cancelled at the "same" time as it is woken
+        # up as part of a Condition.notify() does not result in a lost wakeup.
+        # This test simulates a cancel while the target task is awaiting initial
+        # wakeup on the wakeup queue.
+        condition = _ACondition(threading.Condition(threading.Lock()))
+        state = 0
+
+        async def consumer():
+            nonlocal state
+            async with condition:
+                while True:
+                    await condition.wait_for(lambda: state != 0)
+                    if state < 0:
+                        return
+                    state -= 1
+
+        # create two consumers
+        c = [asyncio.create_task(consumer()) for _ in range(2)]
+        # wait for them to settle
+        await asyncio.sleep(0.1)
+        async with condition:
+            # produce one item and wake up one
+            state += 1
+            condition.notify(1)
+
+            # Cancel it while it is awaiting to be run.
+            # This cancellation could come from the outside
+            c[0].cancel()
+
+        # now wait for the item to be consumed
+        # if it doesn't, it means that our "notify" didn't take hold,
+        # because it raced with a cancel()
+        try:
+            async with asyncio.timeout(1):
+                await condition.wait_for(lambda: state == 0)
+        except TimeoutError:
+            pass
+        self.assertEqual(state, 0)
+
+        # clean up
+        state = -1
+        condition.notify_all()
+        await c[1]
+
+    @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
+    async def test_cancelled_wakeup_relock(self):
+        # Test that a task cancelled at the "same" time as it is woken
+        # up as part of a Condition.notify() does not result in a lost wakeup.
+        # This test simulates a cancel while the target task is acquiring the lock
+        # again.
+        condition = _ACondition(threading.Condition(threading.Lock()))
+        state = 0
+
+        async def consumer():
+            nonlocal state
+            async with condition:
+                while True:
+                    await condition.wait_for(lambda: state != 0)
+                    if state < 0:
+                        return
+                    state -= 1
+
+        # create two consumers
+        c = [asyncio.create_task(consumer()) for _ in range(2)]
+        # wait for them to settle
+        await asyncio.sleep(0.1)
+        async with condition:
+            # produce one item and wake up one
+            state += 1
+            condition.notify(1)
+
+            # now we sleep for a bit.  This allows the target task to wake up and
+            # settle on re-acquiring the lock
+            await asyncio.sleep(0)
+
+            # Cancel it while awaiting the lock
+            # This cancel could come from the outside.
+            c[0].cancel()
+
+        # now wait for the item to be consumed
+        # if it doesn't, it means that our "notify" didn't take hold,
+        # because it raced with a cancel()
+        try:
+            async with asyncio.timeout(1):
+                await condition.wait_for(lambda: state == 0)
+        except TimeoutError:
+            pass
+        self.assertEqual(state, 0)
+
+        # clean up
+        state = -1
+        condition.notify_all()
+        await c[1]
+
+
+class TestCondition(unittest.IsolatedAsyncioTestCase):
+    async def test_multiple_loops_notify(self):
+        cond = _ACondition(threading.Condition(threading.Lock()))
+
+        def tmain(cond):
+            async def atmain(cond):
+                await asyncio.sleep(1)
+                async with cond:
+                    cond.notify(1)
+
+            asyncio.run(atmain(cond))
+
+        t = threading.Thread(target=tmain, args=(cond,))
+        t.start()
+
+        async with cond:
+            self.assertTrue(await cond.wait(30))
+        t.join()
+
+    async def test_multiple_loops_notify_all(self):
+        cond = _ACondition(threading.Condition(threading.Lock()))
+        results = []
+
+        def tmain(cond, results):
+            async def atmain(cond, results):
+                await asyncio.sleep(1)
+                async with cond:
+                    res = await cond.wait(30)
+                    results.append(res)
+
+            asyncio.run(atmain(cond, results))
+
+        nthreads = 5
+        threads = []
+        for _ in range(nthreads):
+            threads.append(threading.Thread(target=tmain, args=(cond, results)))
+        for t in threads:
+            t.start()
+
+        await asyncio.sleep(2)
+        async with cond:
+            cond.notify_all()
+
+        for t in threads:
+            t.join()
+
+        self.assertEqual(results, [True] * nthreads)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_client.py b/test/test_client.py
index 86b9f41ec9..c88a8fd9b4 100644
--- a/test/test_client.py
+++ b/test/test_client.py
@@ -2389,7 +2389,9 @@ def test_reconnect(self):
         # But it can reconnect.
c.revive_host("a:1") - (c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + (c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(c.address, ("a", 1)) def _test_network_error(self, operation_callback): diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index ebbdc74c1c..d1aff03fc9 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -30,6 +30,7 @@ ) from unittest.mock import patch +import pymongo from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( ClientBulkWriteException, @@ -597,7 +598,9 @@ def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - client.admin.command("ping") # Init the client first. + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) diff --git a/test/test_cursor.py b/test/test_cursor.py index d99732aec3..7c073bf351 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1405,7 +1405,7 @@ def test_to_list_length(self): def test_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = client.pymongo.test coll.insert_many([{} for _ in range(5)]) @@ -1447,7 +1447,7 @@ def test_command_cursor_to_list_length(self): def test_command_cursor_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = client.pymongo.test 
coll.insert_many([{} for _ in range(5)]) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 8e030f61e8..7cab42cca2 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -19,6 +19,7 @@ import threading from test import IntegrationTest, client_context, unittest from test.utils import ( + CMAPListener, OvertCommandListener, SpecTestCreator, get_pool, @@ -27,6 +28,7 @@ from test.utils_selection_tests import create_topology from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference @@ -131,19 +133,20 @@ def frequencies(self, client, listener, n_finds=10): @client_context.require_multiple_mongoses def test_load_balancing(self): listener = OvertCommandListener() + cmap_listener = CMAPListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. client = self.rs_client( client_context.mongos_seeds(), appName="loadBalancingTest", - event_listeners=[listener], + event_listeners=[listener, cmap_listener], localThresholdMS=30000, minPoolSize=10, ) - self.addCleanup(client.close) wait_until(lambda: len(client.nodes) == 2, "discover both nodes") - wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections") - # Delay find commands on + # Wait for both pools to be populated. + cmap_listener.wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. 
delay_finds = { "configureFailPoint": "failCommand", "mode": {"times": 10000}, @@ -161,7 +164,7 @@ def test_load_balancing(self): freqs = self.frequencies(client, listener) self.assertLessEqual(freqs[delayed_server], 0.25) listener.reset() - freqs = self.frequencies(client, listener, n_finds=100) + freqs = self.frequencies(client, listener, n_finds=150) self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) diff --git a/tools/synchro.py b/tools/synchro.py index 6ce897a0b8..e0c194f962 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -145,7 +145,17 @@ _gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file() ] -test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()] + +def async_only_test(f: str) -> bool: + """Return True for async tests that should not be converted to sync.""" + return f in ["test_locks.py"] + + +test_files = [ + _test_base + f + for f in listdir(_test_base) + if (Path(_test_base) / f).is_file() and not async_only_test(f) +] sync_files = [ _pymongo_dest_base + f @@ -242,7 +252,7 @@ def translate_locks(lines: list[str]) -> list[str]: lock_lines = [line for line in lines if "_Lock(" in line] cond_lines = [line for line in lines if "_Condition(" in line] for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\(\))\)", line) + res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) if res: old = res[0] index = lines.index(line) From e76d411b593521cadd1f0cafc4009433ed65a246 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 30 Sep 2024 16:48:14 -0700 Subject: [PATCH 08/13] PYTHON-4794 Start running IPv6 tests again (#1879) --- test/__init__.py | 2 +- test/asynchronous/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 1a17ff14c5..af12bc032a 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -313,7 +313,7 @@ def _init_client(self): params = self.cmd_line["parsed"].get("setParameter", {}) if 
params.get("enableTestCommands") == "1": self.test_commands_enabled = True - self.has_ipv6 = self._server_started_with_ipv6() + self.has_ipv6 = self._server_started_with_ipv6() self.is_mongos = (self.hello).get("msg") == "isdbgrid" if self.is_mongos: diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 0d94331587..2a44785b2f 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -313,7 +313,7 @@ async def _init_client(self): params = self.cmd_line["parsed"].get("setParameter", {}) if params.get("enableTestCommands") == "1": self.test_commands_enabled = True - self.has_ipv6 = await self._server_started_with_ipv6() + self.has_ipv6 = await self._server_started_with_ipv6() self.is_mongos = (await self.hello).get("msg") == "isdbgrid" if self.is_mongos: From 15b22651ec9b167148cc228b1d2704bd2a3b0a41 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 30 Sep 2024 18:28:59 -0700 Subject: [PATCH 09/13] PYTHON-4801 Add beta warning to async tutorial (#1884) --- doc/async-tutorial.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/async-tutorial.rst b/doc/async-tutorial.rst index caa277f9d8..2ccf011d8e 100644 --- a/doc/async-tutorial.rst +++ b/doc/async-tutorial.rst @@ -1,6 +1,11 @@ Async Tutorial ============== +.. warning:: This API is currently in beta, meaning the classes, methods, + and behaviors described within may change before the full release. + If you come across any bugs during your use of this API, + please file a Jira ticket in the "Python Driver" project at https://jira.mongodb.org/browse/PYTHON. + .. 
code-block:: pycon from pymongo import AsyncMongoClient From 545b88cbd376a7900b1cab921716ed9c291efb73 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 30 Sep 2024 20:42:28 -0500 Subject: [PATCH 10/13] PYTHON-4800 Add changelog for 4.10.0 (#1883) --- doc/changelog.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index dfb3c79827..3b7ddd1553 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,25 @@ Changelog ========= +Changes in Version 4.10.0 +------------------------- + +- Added provisional **(BETA)** support for a new Binary BSON subtype (9) used for efficient storage and retrieval of vectors: + densely packed arrays of numbers, all of the same type. + This includes new methods :meth:`~bson.binary.Binary.from_vector` and :meth:`~bson.binary.Binary.as_vector`. +- Added C extension use to client metadata, for example: ``{"driver": {"name": "PyMongo|c", "version": "4.10.0"}, ...}`` +- Fixed a bug where :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` could deadlock. +- Fixed a bug where PyMongo could fail to import on Windows if ``asyncio`` is misconfigured. + +Issues Resolved +............... + +See the `PyMongo 4.10 release notes in JIRA`_ for the list of resolved issues +in this release. + +.. 
_PyMongo 4.10 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40553 + + Changes in Version 4.9.0 ------------------------- From ae6cfd6d102d885ac6b0873d31f0dac139b1ddae Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Mon, 30 Sep 2024 22:13:09 -0400 Subject: [PATCH 11/13] [DRIVERS-2926] [PYTHON-4577] BSON Binary Vector Subtype Support (#1813) Co-authored-by: Steven Silvester Co-authored-by: Steven Silvester --- .evergreen/resync-specs.sh | 3 + bson/binary.py | 152 +++++++++++++++++++++++- doc/api/bson/binary.rst | 8 ++ doc/changelog.rst | 1 - test/bson_binary_vector/float32.json | 42 +++++++ test/bson_binary_vector/int8.json | 57 +++++++++ test/bson_binary_vector/packed_bit.json | 50 ++++++++ test/bson_corpus/binary.json | 30 +++++ test/test_bson.py | 81 ++++++++++++- test/test_bson_binary_vector.py | 105 ++++++++++++++++ 10 files changed, 519 insertions(+), 10 deletions(-) create mode 100644 test/bson_binary_vector/float32.json create mode 100644 test/bson_binary_vector/int8.json create mode 100644 test/bson_binary_vector/packed_bit.json create mode 100644 test/test_bson_binary_vector.py diff --git a/.evergreen/resync-specs.sh b/.evergreen/resync-specs.sh index ac69449729..dca116c2d3 100755 --- a/.evergreen/resync-specs.sh +++ b/.evergreen/resync-specs.sh @@ -76,6 +76,9 @@ do atlas-data-lake-testing|data_lake) cpjson atlas-data-lake-testing/tests/ data_lake ;; + bson-binary-vector|bson_binary_vector) + cpjson bson-binary-vector/tests/ bson_binary_vector + ;; bson-corpus|bson_corpus) cpjson bson-corpus/tests/ bson_corpus ;; diff --git a/bson/binary.py b/bson/binary.py index 5fe1bacd16..47c52d4892 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -13,7 +13,10 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any, Tuple, Type, Union +import struct +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Sequence, Tuple, Type, Union from uuid import UUID """Tools for representing BSON binary data. @@ -191,21 +194,75 @@ class UuidRepresentation: """ +VECTOR_SUBTYPE = 9 +"""**(BETA)** BSON binary subtype for densely packed vector data. + +.. versionadded:: 4.10 +""" + + USER_DEFINED_SUBTYPE = 128 """BSON binary subtype for any user defined structure. """ +class BinaryVectorDtype(Enum): + """**(BETA)** Datatypes of vector subtype. + + :param FLOAT32: (0x27) Pack list of :class:`float` as float32 + :param INT8: (0x03) Pack list of :class:`int` in [-128, 127] as signed int8 + :param PACKED_BIT: (0x10) Pack list of :class:`int` in [0, 255] as unsigned uint8 + + The `PACKED_BIT` value represents a special case where vector values themselves + can only be of two values (0 or 1) but these are packed together into groups of 8, + a byte. In Python, these are displayed as ints in range [0, 255] + + Each value is of type bytes with a length of one. + + .. versionadded:: 4.10 + """ + + INT8 = b"\x03" + FLOAT32 = b"\x27" + PACKED_BIT = b"\x10" + + +@dataclass +class BinaryVector: + """**(BETA)** Vector of numbers along with metadata for binary interoperability. + .. versionadded:: 4.10 + """ + + __slots__ = ("data", "dtype", "padding") + + def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, padding: int = 0): + """ + :param data: Sequence of numbers representing the mathematical vector. + :param dtype: The data type stored in binary + :param padding: The number of bits in the final byte that are to be ignored + when a vector element's size is less than a byte + and the length of the vector is not a multiple of 8. + """ + self.data = data + self.dtype = dtype + self.padding = padding + + class Binary(bytes): """Representation of BSON binary data. 
- This is necessary because we want to represent Python strings as - the BSON string type. We need to wrap binary data so we can tell + We want to represent Python strings as the BSON string type. + We need to wrap binary data so that we can tell the difference between what should be considered binary data and what should be considered a string when we encode to BSON. - Raises TypeError if `data` is not an instance of :class:`bytes` - or `subtype` is not an instance of :class:`int`. + **(BETA)** Subtype 9 provides a space-efficient representation of 1-dimensional vector data. + Its data is prepended with two bytes of metadata. + The first (dtype) describes its data type, such as float32 or int8. + The second (padding) prescribes the number of bits to ignore in the final byte. + This is relevant when the element size of the dtype is not a multiple of 8. + + Raises TypeError if `subtype` is not an instance of :class:`int`. Raises ValueError if `subtype` is not in [0, 256). .. note:: @@ -218,7 +275,10 @@ class Binary(bytes): to use .. versionchanged:: 3.9 - Support any bytes-like type that implements the buffer protocol. + Support any bytes-like type that implements the buffer protocol. + + .. versionchanged:: 4.10 + **(BETA)** Addition of vector subtype. """ _type_marker = 5 @@ -337,6 +397,86 @@ def as_uuid(self, uuid_representation: int = UuidRepresentation.STANDARD) -> UUI f"cannot decode subtype {self.subtype} to {UUID_REPRESENTATION_NAMES[uuid_representation]}" ) + @classmethod + def from_vector( + cls: Type[Binary], + vector: list[int, float], + dtype: BinaryVectorDtype, + padding: int = 0, + ) -> Binary: + """**(BETA)** Create a BSON :class:`~bson.binary.Binary` of Vector subtype from a list of Numbers. + + To interpret the representation of the numbers, a data type must be included. + See :class:`~bson.binary.BinaryVectorDtype` for available types and descriptions. + + The dtype and padding are prepended to the binary data's value. 
+ + :param vector: List of values + :param dtype: Data type of the values + :param padding: For fractional bytes, number of bits to ignore at end of vector. + :return: Binary packed data identified by dtype and padding. + + .. versionadded:: 4.10 + """ + if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8 + format_str = "b" + if padding: + raise ValueError(f"padding does not apply to {dtype=}") + elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8 + format_str = "B" + elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32 + format_str = "f" + if padding: + raise ValueError(f"padding does not apply to {dtype=}") + else: + raise NotImplementedError("%s not yet supported" % dtype) + + metadata = struct.pack(" BinaryVector: + """**(BETA)** From the Binary, create a list of numbers, along with dtype and padding. + + :return: BinaryVector + + .. versionadded:: 4.10 + """ + + if self.subtype != VECTOR_SUBTYPE: + raise ValueError(f"Cannot decode subtype {self.subtype} as a vector.") + + position = 0 + dtype, padding = struct.unpack_from(" int: """Subtype of this binary data.""" diff --git a/doc/api/bson/binary.rst b/doc/api/bson/binary.rst index c933a687b9..084fd02d50 100644 --- a/doc/api/bson/binary.rst +++ b/doc/api/bson/binary.rst @@ -21,6 +21,14 @@ .. autoclass:: UuidRepresentation :members: + .. autoclass:: BinaryVectorDtype + :members: + :show-inheritance: + + .. autoclass:: BinaryVector + :members: + + .. autoclass:: Binary(data, subtype=BINARY_SUBTYPE) :members: :show-inheritance: diff --git a/doc/changelog.rst b/doc/changelog.rst index 3b7ddd1553..6c8b8261ac 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -19,7 +19,6 @@ in this release. .. 
_PyMongo 4.10 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40553 - Changes in Version 4.9.0 ------------------------- diff --git a/test/bson_binary_vector/float32.json b/test/bson_binary_vector/float32.json new file mode 100644 index 0000000000..bbbe00b758 --- /dev/null +++ b/test/bson_binary_vector/float32.json @@ -0,0 +1,42 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype FLOAT32", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector FLOAT32", + "valid": true, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000" + }, + { + "description": "Empty Vector FLOAT32", + "valid": true, + "vector": [], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009270000" + }, + { + "description": "Infinity Vector FLOAT32", + "valid": true, + "vector": ["-inf", 0.0, "inf"], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "2000000005766563746F72000E000000092700000080FF000000000000807F00" + }, + { + "description": "FLOAT32 with padding", + "valid": false, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 3 + } + ] +} + diff --git a/test/bson_binary_vector/int8.json b/test/bson_binary_vector/int8.json new file mode 100644 index 0000000000..7529721e5e --- /dev/null +++ b/test/bson_binary_vector/int8.json @@ -0,0 +1,57 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype INT8", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector INT8", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000903007F0700" + }, + { + "description": "Empty Vector INT8", + "valid": true, + "vector": [], + "dtype_hex": 
"0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009030000" + }, + { + "description": "Overflow Vector INT8", + "valid": false, + "vector": [128], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "Underflow Vector INT8", + "valid": false, + "vector": [-129], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "INT8 with padding", + "valid": false, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 3 + }, + { + "description": "INT8 with float inputs", + "valid": false, + "vector": [127.77, 7.77], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + } + ] +} + diff --git a/test/bson_binary_vector/packed_bit.json b/test/bson_binary_vector/packed_bit.json new file mode 100644 index 0000000000..a41cd593f5 --- /dev/null +++ b/test/bson_binary_vector/packed_bit.json @@ -0,0 +1,50 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector PACKED_BIT", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000910007F0700" + }, + { + "description": "Empty Vector PACKED_BIT", + "valid": true, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009100000" + }, + { + "description": "PACKED_BIT with padding", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 3, + "canonical_bson": "1600000005766563746F7200040000000910037F0700" + }, + { + "description": "Overflow Vector PACKED_BIT", + "valid": false, + "vector": [256], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Underflow Vector PACKED_BIT", + "valid": false, + "vector": 
[-1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + } + ] +} + diff --git a/test/bson_corpus/binary.json b/test/bson_corpus/binary.json index 20aaef743b..0e0056f3a2 100644 --- a/test/bson_corpus/binary.json +++ b/test/bson_corpus/binary.json @@ -74,6 +74,36 @@ "description": "$type query operator (conflicts with legacy $binary form with $type field)", "canonical_bson": "180000000378001000000010247479706500020000000000", "canonical_extjson": "{\"x\" : { \"$type\" : {\"$numberInt\": \"2\"}}}" + }, + { + "description": "subtype 0x09 Vector FLOAT32", + "canonical_bson": "170000000578000A0000000927000000FE420000E04000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwAAAP5CAADgQA==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector INT8", + "canonical_bson": "11000000057800040000000903007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector PACKED_BIT", + "canonical_bson": "11000000057800040000000910007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) FLOAT32", + "canonical_bson": "0F0000000578000200000009270000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) INT8", + "canonical_bson": "0F0000000578000200000009030000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) PACKED_BIT", + "canonical_bson": "0F0000000578000200000009100000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAA=\", \"subType\": \"09\"}}}" } ], "decodeErrors": [ diff --git a/test/test_bson.py b/test/test_bson.py index a0190ef2d8..96aa897d19 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -49,8 
+49,9 @@ decode_iter, encode, is_valid, + json_util, ) -from bson.binary import USER_DEFINED_SUBTYPE, Binary, UuidRepresentation +from bson.binary import USER_DEFINED_SUBTYPE, Binary, BinaryVectorDtype, UuidRepresentation from bson.code import Code from bson.codec_options import CodecOptions, DatetimeConversion from bson.datetime_ms import _DATETIME_ERROR_SUGGESTION @@ -148,6 +149,9 @@ def helper(doc): helper({"a binary": Binary(b"test", 128)}) helper({"a binary": Binary(b"test", 254)}) helper({"another binary": Binary(b"test", 2)}) + helper({"binary packed bit vector": Binary(b"\x10\x00\x7f\x07", 9)}) + helper({"binary int8 vector": Binary(b"\x03\x00\x7f\x07", 9)}) + helper({"binary float32 vector": Binary(b"'\x00\x00\x00\xfeB\x00\x00\xe0@", 9)}) helper(SON([("test dst", datetime.datetime(1993, 4, 4, 2))])) helper(SON([("test negative dst", datetime.datetime(1, 1, 1, 1, 1, 1))])) helper({"big float": float(10000000000)}) @@ -447,6 +451,20 @@ def test_basic_encode(self): encode({"test": Binary(b"test", 128)}), b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00\x00\x00\x80\x74\x65\x73\x74\x00", ) + self.assertEqual( + encode({"vector_int8": Binary.from_vector([-128, -1, 127], BinaryVectorDtype.INT8)}), + b"\x1c\x00\x00\x00\x05vector_int8\x00\x05\x00\x00\x00\t\x03\x00\x80\xff\x7f\x00", + ) + self.assertEqual( + encode({"vector_bool": Binary.from_vector([1, 127], BinaryVectorDtype.PACKED_BIT)}), + b"\x1b\x00\x00\x00\x05vector_bool\x00\x04\x00\x00\x00\t\x10\x00\x01\x7f\x00", + ) + self.assertEqual( + encode( + {"vector_float32": Binary.from_vector([-1.1, 1.1e10], BinaryVectorDtype.FLOAT32)} + ), + b"$\x00\x00\x00\x05vector_float32\x00\n\x00\x00\x00\t'\x00\xcd\xcc\x8c\xbf\xac\xe9#P\x00", + ) self.assertEqual(encode({"test": None}), b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00") self.assertEqual( encode({"date": datetime.datetime(2007, 1, 8, 0, 30, 11)}), @@ -711,9 +729,66 @@ def test_uuid_legacy(self): transformed = bin.as_uuid(UuidRepresentation.PYTHON_LEGACY) 
self.assertEqual(id, transformed) - # The C extension was segfaulting on unicode RegExs, so we have this test - # that doesn't really test anything but the lack of a segfault. + def test_vector(self): + """Tests of subtype 9""" + # We start with valid cases, across the 3 dtypes implemented. + # Work with a simple vector that can be interpreted as int8, float32, or ubyte + list_vector = [127, 7] + # As INT8, vector has length 2 + binary_vector = Binary.from_vector(list_vector, BinaryVectorDtype.INT8) + vector = binary_vector.as_vector() + assert vector.data == list_vector + # test encoding roundtrip + assert {"vector": binary_vector} == decode(encode({"vector": binary_vector})) + # test json roundtrip + assert binary_vector == json_util.loads(json_util.dumps(binary_vector)) + + # For vectors of bits, aka PACKED_BIT type, vector has length 8 * 2 + packed_bit_binary = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT) + packed_bit_vec = packed_bit_binary.as_vector() + assert packed_bit_vec.data == list_vector + + # A padding parameter permits vectors of length that aren't divisible by 8 + # The following ignores the last 3 bits in list_vector, + # hence it's length is 8 * len(list_vector) - padding + padding = 3 + padded_vec = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT, padding=padding) + assert padded_vec.as_vector().data == list_vector + # To visualize how this looks as a binary vector.. 
+ uncompressed = "" + for val in list_vector: + uncompressed += format(val, "08b") + assert uncompressed[:-padding] == "0111111100000" + + # It is worthwhile explicitly showing the values encoded to BSON + padded_doc = {"padded_vec": padded_vec} + assert ( + encode(padded_doc) + == b"\x1a\x00\x00\x00\x05padded_vec\x00\x04\x00\x00\x00\t\x10\x03\x7f\x07\x00" + ) + # and dumped to json + assert ( + json_util.dumps(padded_doc) + == '{"padded_vec": {"$binary": {"base64": "EAN/Bw==", "subType": "09"}}}' + ) + + # FLOAT32 is also implemented + float_binary = Binary.from_vector(list_vector, BinaryVectorDtype.FLOAT32) + assert all(isinstance(d, float) for d in float_binary.as_vector().data) + + # Now some invalid cases + for x in [-1, 257]: + try: + Binary.from_vector([x], BinaryVectorDtype.PACKED_BIT) + except Exception as exc: + self.assertTrue(isinstance(exc, struct.error)) + else: + self.fail("Failed to raise an exception.") + def test_unicode_regex(self): + """Tests we do not get a segfault for C extension on unicode RegExs. + This had been happening. + """ regex = re.compile("revisi\xf3n") decode(encode({"regex": regex})) diff --git a/test/test_bson_binary_vector.py b/test/test_bson_binary_vector.py new file mode 100644 index 0000000000..00c82bbb65 --- /dev/null +++ b/test/test_bson_binary_vector.py @@ -0,0 +1,105 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import binascii +import codecs +import json +import struct +from pathlib import Path +from test import unittest + +from bson import decode, encode +from bson.binary import Binary, BinaryVectorDtype + +_TEST_PATH = Path(__file__).parent / "bson_binary_vector" + + +class TestBSONBinaryVector(unittest.TestCase): + """Runs Binary Vector subtype tests. + + Follows the style of the BSON corpus specification tests. + Tests are automatically generated on import + from json files in _TEST_PATH via `create_tests`. + The actual tests are defined in the inner function `run_test` + of the test generator `create_test`.""" + + +def create_test(case_spec): + """Create standard test given specification in json. + + We use the naming convention expected (exp) and observed (obj) + to differentiate what is in the json (expected or suffix _exp) + from what is produced by the API (observed or suffix _obs) + """ + test_key = case_spec.get("test_key") + + def run_test(self): + for test_case in case_spec.get("tests", []): + description = test_case["description"] + vector_exp = test_case["vector"] + dtype_hex_exp = test_case["dtype_hex"] + dtype_alias_exp = test_case.get("dtype_alias") + padding_exp = test_case.get("padding", 0) + canonical_bson_exp = test_case.get("canonical_bson") + # Convert dtype hex string into bytes + dtype_exp = BinaryVectorDtype(int(dtype_hex_exp, 16).to_bytes(1, byteorder="little")) + + if test_case["valid"]: + # Convert bson string to bytes + cB_exp = binascii.unhexlify(canonical_bson_exp.encode("utf8")) + decoded_doc = decode(cB_exp) + binary_obs = decoded_doc[test_key] + # Handle special float cases like '-inf' + if dtype_exp in [BinaryVectorDtype.FLOAT32]: + vector_exp = [float(x) for x in vector_exp] + + # Test round-tripping canonical bson. 
+ self.assertEqual(encode(decoded_doc), cB_exp, description) + + # Test BSON to Binary Vector + vector_obs = binary_obs.as_vector() + self.assertEqual(vector_obs.dtype, dtype_exp, description) + if dtype_alias_exp: + self.assertEqual( + vector_obs.dtype, BinaryVectorDtype[dtype_alias_exp], description + ) + self.assertEqual(vector_obs.data, vector_exp, description) + self.assertEqual(vector_obs.padding, padding_exp, description) + + # Test Binary Vector to BSON + vector_exp = Binary.from_vector(vector_exp, dtype_exp, padding_exp) + cB_obs = binascii.hexlify(encode({test_key: vector_exp})).decode().upper() + self.assertEqual(cB_obs, canonical_bson_exp, description) + + else: + with self.assertRaises((struct.error, ValueError), msg=description): + Binary.from_vector(vector_exp, dtype_exp, padding_exp) + + return run_test + + +def create_tests(): + for filename in _TEST_PATH.glob("*.json"): + with codecs.open(str(filename), encoding="utf-8") as test_file: + test_method = create_test(json.load(test_file)) + setattr(TestBSONBinaryVector, "test_" + filename.stem, test_method) + + +create_tests() + + +if __name__ == "__main__": + unittest.main() From 4713afa910f12d013571778a12fda2287d0bf19d Mon Sep 17 00:00:00 2001 From: "mongodb-dbx-release-bot[bot]" <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 02:14:15 +0000 Subject: [PATCH 12/13] BUMP 4.10.0 Signed-off-by: mongodb-dbx-release-bot[bot] <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> --- pymongo/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/_version.py b/pymongo/_version.py index 5ff72d6cc8..7cc4bb8e1d 100644 --- a/pymongo/_version.py +++ b/pymongo/_version.py @@ -18,7 +18,7 @@ import re from typing import List, Tuple, Union -__version__ = "4.10.0.dev0" +__version__ = "4.10.0" def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]: From c0f7810d56555c8a285beaa9aa5fe6d2b7185eff Mon Sep 17 00:00:00 2001 From: 
"mongodb-dbx-release-bot[bot]" <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 02:31:13 +0000 Subject: [PATCH 13/13] BUMP 4.11.0.dev0 Signed-off-by: mongodb-dbx-release-bot[bot] <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com> --- pymongo/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/_version.py b/pymongo/_version.py index 7cc4bb8e1d..3de24a8e14 100644 --- a/pymongo/_version.py +++ b/pymongo/_version.py @@ -18,7 +18,7 @@ import re from typing import List, Tuple, Union -__version__ = "4.10.0" +__version__ = "4.11.0.dev0" def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]: