From adb16700c57d636367138bd42a7173f8c9744770 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 18 Sep 2024 16:45:56 -0700 Subject: [PATCH 01/13] PYTHON-4636 Stop blocking the I/O Loop for socket reads --- pymongo/asynchronous/network.py | 81 ++-------------------- pymongo/network_layer.py | 115 +++++++++++++++++++++++++++++++- pymongo/synchronous/network.py | 77 ++------------------- tools/synchro.py | 1 + 4 files changed, 125 insertions(+), 149 deletions(-) diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 44a63a2fc3..d17aead120 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -15,11 +15,8 @@ """Internal network layer helper methods.""" from __future__ import annotations -import asyncio import datetime -import errno import logging -import socket import time from typing import ( TYPE_CHECKING, @@ -40,19 +37,16 @@ NotPrimaryError, OperationFailure, ProtocolError, - _OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, _UNPACK_HEADER, - BLOCKING_IO_ERRORS, + async_receive_data, async_sendall, ) -from pymongo.socket_checker import _errno_from_exception if TYPE_CHECKING: from bson import CodecOptions @@ -318,9 +312,7 @@ async def receive_message( else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER( - await _receive_data_on_socket(conn, 16, deadline) - ) + length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline)) # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: @@ -336,11 +328,11 @@ async def receive_message( ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - await _receive_data_on_socket(conn, 9, deadline) + await async_receive_data(conn, 9, deadline) ) - data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id) + data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id) else: - data = await _receive_data_on_socket(conn, length - 16, deadline) + data = await async_receive_data(conn, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -349,66 +341,3 @@ async def receive_message( f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" ) from None return unpack_reply(data) - - -async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - await asyncio.sleep(0) - - -async def _receive_data_on_socket( - conn: AsyncConnection, length: int, deadline: Optional[float] -) -> memoryview: - buf = bytearray(length) - mv = memoryview(buf) - bytes_read = 0 - while bytes_read < length: - try: - await wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. - if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: - continue - raise - if chunk_length == 0: - raise OSError("connection closed") - - bytes_read += chunk_length - - return mv diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 82a6228acc..7479b6be80 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -16,15 +16,21 @@ from __future__ import annotations import asyncio +import errno import socket import struct import sys +import time from asyncio import AbstractEventLoop, Future from typing import ( + TYPE_CHECKING, + Optional, Union, ) -from pymongo import ssl_support +from pymongo import _csot, ssl_support +from pymongo.errors import _OperationCancelled +from pymongo.socket_checker import _errno_from_exception try: from ssl import SSLError, SSLSocket @@ -51,6 +57,10 @@ BLOCKING_IO_WRITE_ERROR, ) +if TYPE_CHECKING: + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.synchronous.pool import Connection + _UNPACK_HEADER = struct.Struct(" None: sock.sendall(buf) + + +async def async_receive_data( + conn: AsyncConnection, length: int, deadline: Optional[float] +) -> memoryview: + sock = conn.conn + sock_timeout = sock.gettimeout() + if deadline: + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + timeout = max(deadline - time.monotonic(), 0) + else: + timeout = sock_timeout + + sock.settimeout(0.0) + loop = asyncio.get_event_loop() + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout) + else: + return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] + except asyncio.TimeoutError as exc: + # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. + raise socket.timeout("timed out") from exc + finally: + sock.settimeout(sock_timeout) + + +async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: + mv = memoryview(bytearray(length)) + bytes_read = 0 + while bytes_read < length: + chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:]) + if chunk_length == 0: + raise OSError("connection closed") + bytes_read += chunk_length + return mv + + +async def _async_receive_ssl(conn: _sslConn, length: int, loop: AbstractEventLoop) -> memoryview: # noqa: ARG001 + return memoryview(b"") + + +# Sync version: +def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: + """Block until at least one byte is read, or a timeout, or a cancel.""" + sock = conn.conn + timed_out = False + # Check if the connection's socket has been manually closed + if sock.fileno() == -1: + return + while True: + # SSLSocket can have buffered data which won't be caught by select. + if hasattr(sock, "pending") and sock.pending() > 0: + readable = True + else: + # Wait up to 500ms for the socket to become readable and then + # check for cancellation. + if deadline: + remaining = deadline - time.monotonic() + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + if remaining <= 0: + timed_out = True + timeout = max(min(remaining, _POLL_TIMEOUT), 0) + else: + timeout = _POLL_TIMEOUT + readable = conn.socket_checker.select(sock, read=True, timeout=timeout) + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") + if readable: + return + if timed_out: + raise socket.timeout("timed out") + + +def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: + buf = bytearray(length) + mv = memoryview(buf) + bytes_read = 0 + while bytes_read < length: + try: + wait_for_read(conn, deadline) + # CSOT: Update timeout. When the timeout has expired perform one + # final non-blocking recv. This helps avoid spurious timeouts when + # the response is actually already buffered on the client. + if _csot.get_timeout() and deadline is not None: + conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) + chunk_length = conn.conn.recv_into(mv[bytes_read:]) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + except OSError as exc: + if _errno_from_exception(exc) == errno.EINTR: + continue + raise + if chunk_length == 0: + raise OSError("connection closed") + + bytes_read += chunk_length + + return mv diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index c1978087a9..7206dca735 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,9 +16,7 @@ from __future__ import annotations import datetime -import errno import logging -import socket import time from typing import ( TYPE_CHECKING, @@ -39,19 +37,16 @@ NotPrimaryError, OperationFailure, ProtocolError, - _OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, _UNPACK_HEADER, - BLOCKING_IO_ERRORS, + receive_data, sendall, ) -from pymongo.socket_checker import _errno_from_exception if TYPE_CHECKING: from bson import CodecOptions @@ -317,7 +312,7 @@ def receive_message( else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline)) + length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: @@ -332,12 +327,10 @@ def receive_message( f"message size ({max_message_size!r})" ) if op_code == 2012: - op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - _receive_data_on_socket(conn, 9, deadline) - ) - data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id) + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) + data = decompress(receive_data(conn, length - 25, deadline), compressor_id) else: - data = _receive_data_on_socket(conn, length - 16, deadline) + data = receive_data(conn, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -346,63 +339,3 @@ def receive_message( f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" ) from None return unpack_reply(data) - - -def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - - -def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: - buf = bytearray(length) - mv = memoryview(buf) - bytes_read = 0 - while bytes_read < length: - try: - wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. - if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: - continue - raise - if chunk_length == 0: - raise OSError("connection closed") - - bytes_read += chunk_length - - return mv diff --git a/tools/synchro.py b/tools/synchro.py index 59d6e653e5..f9d9ee826d 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -43,6 +43,7 @@ "AsyncConnection": "Connection", "async_command": "command", "async_receive_message": "receive_message", + "async_receive_data": "receive_data", "async_sendall": "sendall", "asynchronous": "synchronous", "Asynchronous": "Synchronous", From 3d399da7cc1f52dc012af6275a8aa904066767cb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 19 Sep 2024 15:41:35 -0400 Subject: [PATCH 02/13] Implement _async_receive_ssl --- pymongo/network_layer.py | 56 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 7479b6be80..eb7e8cd4f0 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -121,6 +121,44 @@ def _is_ready(fut: Future) -> None: loop.add_reader(fd, _is_ready, fut) loop.add_writer(fd, _is_ready, fut) await fut + + async def _async_receive_ssl( + conn: _sslConn, length: int, loop: AbstractEventLoop + ) -> memoryview: + mv = memoryview(bytearray(length)) + fd = conn.fileno() + read = 0 + + def _is_ready(fut: Future) -> None: + loop.remove_writer(fd) + loop.remove_reader(fd) + if fut.done(): + return + fut.set_result(None) + + while read < length: + try: + read += conn.recv_into(mv[read:]) + except BLOCKING_IO_ERRORS as exc: + fd = conn.fileno() + # Check for closed socket. + if fd == -1: + raise SSLError("Underlying socket has been closed") from None + if isinstance(exc, BLOCKING_IO_READ_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + await fut + if isinstance(exc, BLOCKING_IO_WRITE_ERROR): + fut = loop.create_future() + loop.add_writer(fd, _is_ready, fut) + await fut + if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + loop.add_writer(fd, _is_ready, fut) + await fut + return mv + else: # The default Windows asyncio event loop does not support loop.add_reader/add_writer: # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support @@ -138,6 +176,20 @@ async def _async_sendall_ssl( sent = 0 total_sent += sent + async def _async_receive_ssl( + conn: _sslConn, length: int, dummy: AbstractEventLoop + ) -> memoryview: + mv = memoryview(bytearray(length)) + total_read = 0 + while total_read < length: + try: + read = conn.recv_into(mv[total_read:]) + except BLOCKING_IO_ERRORS: + await asyncio.sleep(0.5) + read = 0 + total_read += read + return mv + def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: sock.sendall(buf) @@ -181,10 +233,6 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo return mv -async def _async_receive_ssl(conn: _sslConn, length: int, loop: AbstractEventLoop) -> memoryview: # noqa: ARG001 - return memoryview(b"") - - # Sync version: def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" From 7f714304099f1c74b126c7e986bc6ee7bd388239 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 20 Sep 2024 10:59:07 -0400 Subject: [PATCH 03/13] Add async support for cancellation --- pymongo/network_layer.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index eb7e8cd4f0..46805ad1cb 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -139,6 +139,8 @@ def _is_ready(fut: Future) -> None: while read < length: try: read += conn.recv_into(mv[read:]) + if read == 0: + raise OSError("connection closed") except BLOCKING_IO_ERRORS as exc: fd = conn.fileno() # Check for closed socket. @@ -195,11 +197,20 @@ def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: sock.sendall(buf) +async def _poll_cancellation(conn: AsyncConnection) -> None: + while True: + if conn.cancel_context.cancelled: + return + + await asyncio.sleep(_POLL_TIMEOUT) + + async def async_receive_data( conn: AsyncConnection, length: int, deadline: Optional[float] ) -> memoryview: sock = conn.conn sock_timeout = sock.gettimeout() + timeout: Optional[Union[float, int]] if deadline: # When the timeout has expired perform one final check to # see if the socket is readable. This helps avoid spurious @@ -210,14 +221,22 @@ async def async_receive_data( sock.settimeout(0.0) loop = asyncio.get_event_loop() + cancellation_task = asyncio.create_task(_poll_cancellation(conn)) try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout) + read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] else: - return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] - except asyncio.TimeoutError as exc: - # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. - raise socket.timeout("timed out") from exc + read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] + tasks = [read_task, cancellation_task] + result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED) + if len(result[1]) == 2: + raise socket.timeout("timed out") + finished = next(iter(result[0])) + next(iter(result[1])).cancel() + if finished == read_task: + return finished.result() # type: ignore[return-value] + else: + raise _OperationCancelled("operation cancelled") finally: sock.settimeout(sock_timeout) From d69b5f622d493ce43d353ea9ce42b3d7204edf19 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 20 Sep 2024 14:56:14 -0400 Subject: [PATCH 04/13] Async pyopenssl support --- pymongo/network_layer.py | 27 ++++++++++++++++----------- pymongo/pyopenssl_context.py | 11 +++++++++-- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 46805ad1cb..199dd6763f 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -127,7 +127,7 @@ async def _async_receive_ssl( ) -> memoryview: mv = memoryview(bytearray(length)) fd = conn.fileno() - read = 0 + total_read = 0 def _is_ready(fut: Future) -> None: loop.remove_writer(fd) @@ -136,11 +136,12 @@ def _is_ready(fut: Future) -> None: return fut.set_result(None) - while read < length: + while total_read < length: try: - read += conn.recv_into(mv[read:]) + read = conn.recv_into(mv[total_read:]) if read == 0: raise OSError("connection closed") + total_read += read except BLOCKING_IO_ERRORS as exc: fd = conn.fileno() # Check for closed socket. @@ -228,15 +229,19 @@ async def async_receive_data( else: read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] tasks = [read_task, cancellation_task] - result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED) - if len(result[1]) == 2: + done, pending = await asyncio.wait( + tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if len(done) == 0: raise socket.timeout("timed out") - finished = next(iter(result[0])) - next(iter(result[1])).cancel() - if finished == read_task: - return finished.result() # type: ignore[return-value] - else: - raise _OperationCancelled("operation cancelled") + for task in done: + if task == read_task: + return read_task.result() + else: + raise _OperationCancelled("operation cancelled") + return None # type: ignore[return-value] finally: sock.settimeout(sock_timeout) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 4f6f6f4a89..e521a92789 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -105,11 +105,16 @@ def _ragged_eof(exc: BaseException) -> bool: # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets class _sslConn(_SSL.Connection): def __init__( - self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool + self, + ctx: _SSL.Context, + sock: Optional[_socket.socket], + suppress_ragged_eofs: bool, + is_async: bool = False, ): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs super().__init__(ctx, sock) + self._is_async = is_async def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: timeout = self.gettimeout() @@ -119,6 +124,8 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: try: return call(*args, **kwargs) except BLOCKING_IO_ERRORS as exc: + if self._is_async: + raise exc # Check for closed socket. if self.fileno() == -1: if timeout and _time.monotonic() - start > timeout: @@ -381,7 +388,7 @@ async def a_wrap_socket( """Wrap an existing Python socket connection and return a TLS socket object. """ - ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs) + ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True) loop = asyncio.get_running_loop() if session: ssl_conn.set_session(session) From 45b40451325511855865c2652d471191f74940a3 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 20 Sep 2024 15:28:32 -0400 Subject: [PATCH 05/13] Fixes --- pymongo/network_layer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 199dd6763f..40fbe61b3f 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -236,12 +236,9 @@ async def async_receive_data( task.cancel() if len(done) == 0: raise socket.timeout("timed out") - for task in done: - if task == read_task: - return read_task.result() - else: - raise _OperationCancelled("operation cancelled") - return None # type: ignore[return-value] + if read_task in done: + return read_task.result() + raise _OperationCancelled("operation cancelled") finally: sock.settimeout(sock_timeout) From 0e035816f52bb01a9392f11bc2e867ccd7efd1f4 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 26 Sep 2024 12:31:50 -0400 Subject: [PATCH 06/13] Fix async ssl handshake --- pymongo/network_layer.py | 4 +++- pymongo/pyopenssl_context.py | 4 +++- test/asynchronous/test_client.py | 6 ++---- test/test_client.py | 6 ++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 40fbe61b3f..a39c192b83 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -238,7 +238,9 @@ async def async_receive_data( raise socket.timeout("timed out") if read_task in done: return read_task.result() - raise _OperationCancelled("operation cancelled") + elif conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") + return read_task.result() finally: sock.settimeout(sock_timeout) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index e521a92789..50d8680a74 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -117,6 +117,7 @@ def __init__( self._is_async = is_async def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: + is_async = kwargs.pop("allow_async", True) and self._is_async timeout = self.gettimeout() if timeout: start = _time.monotonic() @@ -124,7 +125,7 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: try: return call(*args, **kwargs) except BLOCKING_IO_ERRORS as exc: - if self._is_async: + if is_async: raise exc # Check for closed socket. if self.fileno() == -1: @@ -146,6 +147,7 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: continue def do_handshake(self, *args: Any, **kwargs: Any) -> None: + kwargs["allow_async"] = False return self._call(super().do_handshake, *args, **kwargs) def recv(self, *args: Any, **kwargs: Any) -> bytes: diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index f610f32779..a88d9bd136 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1703,6 +1703,7 @@ def compression_settings(client): # No error await client.pymongo_test.test.find_one() + @async_client_context.require_sync async def test_reset_during_update_pool(self): client = await self.async_rs_or_single_client(minPoolSize=10) await client.admin.command("ping") @@ -1727,10 +1728,7 @@ async def _run(self): await asyncio.sleep(0.001) def run(self): - if _IS_SYNC: - self._run() - else: - asyncio.run(self._run()) + self._run() t = ResetPoolThread(pool) t.start() diff --git a/test/test_client.py b/test/test_client.py index bc45325f0b..2c8407537a 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1661,6 +1661,7 @@ def compression_settings(client): # No error client.pymongo_test.test.find_one() + @client_context.require_sync def test_reset_during_update_pool(self): client = self.rs_or_single_client(minPoolSize=10) client.admin.command("ping") @@ -1685,10 +1686,7 @@ def _run(self): time.sleep(0.001) def run(self): - if _IS_SYNC: - self._run() - else: - asyncio.run(self._run()) + self._run() t = ResetPoolThread(pool) t.start() From b241f6d08a25bbb3f81fb6fcec94cbeeb6baf4b0 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 26 Sep 2024 14:14:00 -0400 Subject: [PATCH 07/13] Undo fix --- pymongo/network_layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index a39c192b83..40fbe61b3f 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -238,9 +238,7 @@ async def async_receive_data( raise socket.timeout("timed out") if read_task in done: return read_task.result() - elif conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - return read_task.result() + raise _OperationCancelled("operation cancelled") finally: sock.settimeout(sock_timeout) From 33a927c5c2a32219a863d9c07c21ea3b02e19c8a Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 1 Oct 2024 16:46:18 -0400 Subject: [PATCH 08/13] Ensure cancelled tasks are actually cancelled --- pymongo/network_layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 40fbe61b3f..d2eff756c0 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -234,6 +234,7 @@ async def async_receive_data( ) for task in pending: task.cancel() + await asyncio.sleep(0) # Ensure the task actually cancels by yielding to the loop here if len(done) == 0: raise socket.timeout("timed out") if read_task in done: From 5dc1fd2b6de6a1d82b3cdebc8aa0d564645e9adb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 2 Oct 2024 16:31:16 -0400 Subject: [PATCH 09/13] Fix waiting for pending task cancellation --- pymongo/network_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index d2eff756c0..051c33d8de 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -234,7 +234,7 @@ async def async_receive_data( ) for task in pending: task.cancel() - await asyncio.sleep(0) # Ensure the task actually cancels by yielding to the loop here + await asyncio.wait(pending) if len(done) == 0: raise socket.timeout("timed out") if read_task in done: From d9c9612be40f9c4eaf8b1fae7691dcfc9d13951c Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 2 Oct 2024 18:00:56 -0700 Subject: [PATCH 10/13] PYTHON-4636 Always remove reader/writer callbacks --- pymongo/network_layer.py | 42 ++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 051c33d8de..1aaab0b777 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -90,12 +90,9 @@ async def _async_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop ) -> None: view = memoryview(buf) - fd = sock.fileno() sent = 0 def _is_ready(fut: Future) -> None: - loop.remove_writer(fd) - loop.remove_reader(fd) if fut.done(): return fut.set_result(None) @@ -111,27 +108,34 @@ def _is_ready(fut: Future) -> None: if isinstance(exc, BLOCKING_IO_READ_ERROR): fut = loop.create_future() loop.add_reader(fd, _is_ready, fut) - await fut + try: + await fut + finally: + loop.remove_reader(fd) if isinstance(exc, BLOCKING_IO_WRITE_ERROR): fut = loop.create_future() loop.add_writer(fd, _is_ready, fut) - await fut + try: + await fut + finally: + loop.remove_writer(fd) if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): fut = loop.create_future() loop.add_reader(fd, _is_ready, fut) - loop.add_writer(fd, _is_ready, fut) - await fut + try: + loop.add_writer(fd, _is_ready, fut) + await fut + finally: + loop.remove_reader(fd) + loop.remove_writer(fd) async def _async_receive_ssl( conn: _sslConn, length: int, loop: AbstractEventLoop ) -> memoryview: mv = memoryview(bytearray(length)) - fd = conn.fileno() total_read = 0 def _is_ready(fut: Future) -> None: - loop.remove_writer(fd) - loop.remove_reader(fd) if fut.done(): return fut.set_result(None) @@ -150,16 +154,26 @@ def _is_ready(fut: Future) -> None: if isinstance(exc, BLOCKING_IO_READ_ERROR): fut = loop.create_future() loop.add_reader(fd, _is_ready, fut) - await fut + try: + await fut + finally: + loop.remove_reader(fd) if isinstance(exc, BLOCKING_IO_WRITE_ERROR): fut = loop.create_future() loop.add_writer(fd, _is_ready, fut) - await fut + try: + await fut + finally: + loop.remove_writer(fd) if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): fut = loop.create_future() loop.add_reader(fd, _is_ready, fut) - loop.add_writer(fd, _is_ready, fut) - await fut + try: + loop.add_writer(fd, _is_ready, fut) + await fut + finally: + loop.remove_reader(fd) + loop.remove_writer(fd) return mv else: From 1e9a134802e8106aba3d063069fa04980541602d Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 3 Oct 2024 10:00:39 -0700 Subject: [PATCH 11/13] PYTHON-4636 Properly check for socket close on windows --- pymongo/network_layer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 1aaab0b777..1f28942fed 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -201,6 +201,8 @@ async def _async_receive_ssl( while total_read < length: try: read = conn.recv_into(mv[total_read:]) + if read == 0: + raise OSError("connection closed") except BLOCKING_IO_ERRORS: await asyncio.sleep(0.5) read = 0 From 24221255e94438adc2d3319ce1b05787459f412b Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 17 Sep 2024 18:42:53 -0700 Subject: [PATCH 12/13] PYTHON-4770 Improve CPU overhead of async locks and latency on Windows TLS sendall --- pymongo/lock.py | 4 ++-- pymongo/network_layer.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index 0cbfb4a57e..9d3e8d7a16 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -71,7 +71,7 @@ async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool: return False if not blocking: return False - await asyncio.sleep(0) + await asyncio.sleep(0.001) def release(self) -> None: self._lock.release() @@ -115,7 +115,7 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: return False if not blocking: return False - await asyncio.sleep(0) + await asyncio.sleep(0.001) async def wait(self, timeout: Optional[float] = None) -> bool: """Wait until notified. diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 1f28942fed..a1a4124d40 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -179,18 +179,26 @@ def _is_ready(fut: Future) -> None: else: # The default Windows asyncio event loop does not support loop.add_reader/add_writer: # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support + # Note: In PYTHON-4493 we plan to replace this code with asyncio streams. async def _async_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop ) -> None: view = memoryview(buf) total_length = len(buf) total_sent = 0 + # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success + # down to 1ms. + backoff = 0.001 while total_sent < total_length: try: sent = sock.send(view[total_sent:]) except BLOCKING_IO_ERRORS: - await asyncio.sleep(0.5) + await asyncio.sleep(backoff) sent = 0 + if sent > 0: + backoff = max(backoff / 2, 0.001) + else: + backoff = min(backoff * 2, 0.512) total_sent += sent async def _async_receive_ssl( From 5b37776b1c76ff39be282e7b24cf82377e59d7cd Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 3 Oct 2024 10:39:02 -0700 Subject: [PATCH 13/13] PYTHON-4636 Fix windows recv perf --- pymongo/lock.py | 4 ++-- pymongo/network_layer.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index 9d3e8d7a16..0cbfb4a57e 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -71,7 +71,7 @@ async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool: return False if not blocking: return False - await asyncio.sleep(0.001) + await asyncio.sleep(0) def release(self) -> None: self._lock.release() @@ -115,7 +115,7 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: return False if not blocking: return False - await asyncio.sleep(0.001) + await asyncio.sleep(0) async def wait(self, timeout: Optional[float] = None) -> bool: """Wait until notified. diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index a1a4124d40..4b57620d83 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -206,14 +206,21 @@ async def _async_receive_ssl( ) -> memoryview: mv = memoryview(bytearray(length)) total_read = 0 + # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success + # down to 1ms. + backoff = 0.001 while total_read < length: try: read = conn.recv_into(mv[total_read:]) if read == 0: raise OSError("connection closed") except BLOCKING_IO_ERRORS: - await asyncio.sleep(0.5) + await asyncio.sleep(backoff) read = 0 + if read > 0: + backoff = max(backoff / 2, 0.001) + else: + backoff = min(backoff * 2, 0.512) total_read += read return mv