diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py
index 44a63a2fc3..d17aead120 100644
--- a/pymongo/asynchronous/network.py
+++ b/pymongo/asynchronous/network.py
@@ -15,11 +15,8 @@
 """Internal network layer helper methods."""
 from __future__ import annotations
 
-import asyncio
 import datetime
-import errno
 import logging
-import socket
 import time
 from typing import (
     TYPE_CHECKING,
@@ -40,19 +37,16 @@
     NotPrimaryError,
     OperationFailure,
     ProtocolError,
-    _OperationCancelled,
 )
 from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
 from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
 from pymongo.monitoring import _is_speculative_authenticate
 from pymongo.network_layer import (
-    _POLL_TIMEOUT,
     _UNPACK_COMPRESSION_HEADER,
     _UNPACK_HEADER,
-    BLOCKING_IO_ERRORS,
+    async_receive_data,
     async_sendall,
 )
-from pymongo.socket_checker import _errno_from_exception
 
 if TYPE_CHECKING:
     from bson import CodecOptions
@@ -318,9 +312,7 @@ async def receive_message(
         else:
             deadline = None
     # Ignore the response's request id.
-    length, _, response_to, op_code = _UNPACK_HEADER(
-        await _receive_data_on_socket(conn, 16, deadline)
-    )
+    length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
     # No request_id for exhaust cursor "getMore".
     if request_id is not None:
         if request_id != response_to:
@@ -336,11 +328,11 @@
         )
     if op_code == 2012:
         op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
-            await _receive_data_on_socket(conn, 9, deadline)
+            await async_receive_data(conn, 9, deadline)
         )
-        data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
+        data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
     else:
-        data = await _receive_data_on_socket(conn, length - 16, deadline)
+        data = await async_receive_data(conn, length - 16, deadline)
 
     try:
         unpack_reply = _UNPACK_REPLY[op_code]
@@ -349,66 +341,3 @@
             f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
         ) from None
     return unpack_reply(data)
-
-
-async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
-    """Block until at least one byte is read, or a timeout, or a cancel."""
-    sock = conn.conn
-    timed_out = False
-    # Check if the connection's socket has been manually closed
-    if sock.fileno() == -1:
-        return
-    while True:
-        # SSLSocket can have buffered data which won't be caught by select.
-        if hasattr(sock, "pending") and sock.pending() > 0:
-            readable = True
-        else:
-            # Wait up to 500ms for the socket to become readable and then
-            # check for cancellation.
-            if deadline:
-                remaining = deadline - time.monotonic()
-                # When the timeout has expired perform one final check to
-                # see if the socket is readable. This helps avoid spurious
-                # timeouts on AWS Lambda and other FaaS environments.
-                if remaining <= 0:
-                    timed_out = True
-                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
-            else:
-                timeout = _POLL_TIMEOUT
-            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
-        if conn.cancel_context.cancelled:
-            raise _OperationCancelled("operation cancelled")
-        if readable:
-            return
-        if timed_out:
-            raise socket.timeout("timed out")
-        await asyncio.sleep(0)
-
-
-async def _receive_data_on_socket(
-    conn: AsyncConnection, length: int, deadline: Optional[float]
-) -> memoryview:
-    buf = bytearray(length)
-    mv = memoryview(buf)
-    bytes_read = 0
-    while bytes_read < length:
-        try:
-            await wait_for_read(conn, deadline)
-            # CSOT: Update timeout. When the timeout has expired perform one
-            # final non-blocking recv. This helps avoid spurious timeouts when
-            # the response is actually already buffered on the client.
-            if _csot.get_timeout() and deadline is not None:
-                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
-            chunk_length = conn.conn.recv_into(mv[bytes_read:])
-        except BLOCKING_IO_ERRORS:
-            raise socket.timeout("timed out") from None
-        except OSError as exc:
-            if _errno_from_exception(exc) == errno.EINTR:
-                continue
-            raise
-        if chunk_length == 0:
-            raise OSError("connection closed")
-
-        bytes_read += chunk_length
-
-    return mv
diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py
index 82a6228acc..4b57620d83 100644
--- a/pymongo/network_layer.py
+++ b/pymongo/network_layer.py
@@ -16,15 +16,21 @@
 from __future__ import annotations
 
 import asyncio
+import errno
 import socket
 import struct
 import sys
+import time
 from asyncio import AbstractEventLoop, Future
 from typing import (
+    TYPE_CHECKING,
+    Optional,
     Union,
 )
 
-from pymongo import ssl_support
+from pymongo import _csot, ssl_support
+from pymongo.errors import _OperationCancelled
+from pymongo.socket_checker import _errno_from_exception
 
 try:
     from ssl import SSLError, SSLSocket
@@ -51,6 +57,10 @@
     BLOCKING_IO_WRITE_ERROR,
 )
 
+if TYPE_CHECKING:
+    from pymongo.asynchronous.pool import AsyncConnection
+    from pymongo.synchronous.pool import Connection
+
 _UNPACK_HEADER = struct.Struct("<iiii").unpack
 _UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
 _POLL_TIMEOUT = 0.5
@@ -80,12 +90,9 @@ async def _async_sendall_ssl(
         sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
     ) -> None:
         view = memoryview(buf)
-        fd = sock.fileno()
         sent = 0
 
         def _is_ready(fut: Future) -> None:
-            loop.remove_writer(fd)
-            loop.remove_reader(fd)
             if fut.done():
                 return
             fut.set_result(None)
@@ -101,33 +108,240 @@ def _is_ready(fut: Future) -> None:
                 if isinstance(exc, BLOCKING_IO_READ_ERROR):
                     fut = loop.create_future()
                     loop.add_reader(fd, _is_ready, fut)
-                    await fut
+                    try:
+                        await fut
+                    finally:
+                        loop.remove_reader(fd)
                 if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
                     fut = loop.create_future()
                     loop.add_writer(fd, _is_ready, fut)
-                    await fut
+                    try:
+                        await fut
+                    finally:
+                        loop.remove_writer(fd)
                 if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
                     fut = loop.create_future()
                     loop.add_reader(fd, _is_ready, fut)
+                    try:
+                        loop.add_writer(fd, _is_ready, fut)
+                        await fut
+                    finally:
+                        loop.remove_reader(fd)
+                        loop.remove_writer(fd)
+
+    async def _async_receive_ssl(
+        conn: _sslConn, length: int, loop: AbstractEventLoop
+    ) -> memoryview:
+        mv = memoryview(bytearray(length))
+        total_read = 0
+
+        def _is_ready(fut: Future) -> None:
+            if fut.done():
+                return
+            fut.set_result(None)
+
+        while total_read < length:
+            try:
+                read = conn.recv_into(mv[total_read:])
+                if read == 0:
+                    raise OSError("connection closed")
+                total_read += read
+            except BLOCKING_IO_ERRORS as exc:
+                fd = conn.fileno()
+                # Check for closed socket.
+                if fd == -1:
+                    raise SSLError("Underlying socket has been closed") from None
+                if isinstance(exc, BLOCKING_IO_READ_ERROR):
+                    fut = loop.create_future()
+                    loop.add_reader(fd, _is_ready, fut)
+                    try:
+                        await fut
+                    finally:
+                        loop.remove_reader(fd)
+                if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
+                    fut = loop.create_future()
                     loop.add_writer(fd, _is_ready, fut)
-                    await fut
+                    try:
+                        await fut
+                    finally:
+                        loop.remove_writer(fd)
+                if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
+                    fut = loop.create_future()
+                    loop.add_reader(fd, _is_ready, fut)
+                    try:
+                        loop.add_writer(fd, _is_ready, fut)
+                        await fut
+                    finally:
+                        loop.remove_reader(fd)
+                        loop.remove_writer(fd)
+        return mv
+
 else:
     # The default Windows asyncio event loop does not support loop.add_reader/add_writer:
     # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
+    # Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
     async def _async_sendall_ssl(
         sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
     ) -> None:
         view = memoryview(buf)
         total_length = len(buf)
         total_sent = 0
+        # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
+        # down to 1ms.
+        backoff = 0.001
         while total_sent < total_length:
             try:
                 sent = sock.send(view[total_sent:])
             except BLOCKING_IO_ERRORS:
-                await asyncio.sleep(0.5)
+                await asyncio.sleep(backoff)
                 sent = 0
+            if sent > 0:
+                backoff = max(backoff / 2, 0.001)
+            else:
+                backoff = min(backoff * 2, 0.512)
             total_sent += sent
 
+    async def _async_receive_ssl(
+        conn: _sslConn, length: int, dummy: AbstractEventLoop
+    ) -> memoryview:
+        mv = memoryview(bytearray(length))
+        total_read = 0
+        # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
+        # down to 1ms.
+        backoff = 0.001
+        while total_read < length:
+            try:
+                read = conn.recv_into(mv[total_read:])
+                if read == 0:
+                    raise OSError("connection closed")
+            except BLOCKING_IO_ERRORS:
+                await asyncio.sleep(backoff)
+                read = 0
+            if read > 0:
+                backoff = max(backoff / 2, 0.001)
+            else:
+                backoff = min(backoff * 2, 0.512)
+            total_read += read
+        return mv
+
 
 def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
     sock.sendall(buf)
+
+
+async def _poll_cancellation(conn: AsyncConnection) -> None:
+    while True:
+        if conn.cancel_context.cancelled:
+            return
+
+        await asyncio.sleep(_POLL_TIMEOUT)
+
+
+async def async_receive_data(
+    conn: AsyncConnection, length: int, deadline: Optional[float]
+) -> memoryview:
+    sock = conn.conn
+    sock_timeout = sock.gettimeout()
+    timeout: Optional[Union[float, int]]
+    if deadline:
+        # When the timeout has expired perform one final check to
+        # see if the socket is readable. This helps avoid spurious
+        # timeouts on AWS Lambda and other FaaS environments.
+        timeout = max(deadline - time.monotonic(), 0)
+    else:
+        timeout = sock_timeout
+
+    sock.settimeout(0.0)
+    loop = asyncio.get_event_loop()
+    cancellation_task = asyncio.create_task(_poll_cancellation(conn))
+    try:
+        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
+            read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop))  # type: ignore[arg-type]
+        else:
+            read_task = asyncio.create_task(_async_receive(sock, length, loop))  # type: ignore[arg-type]
+        tasks = [read_task, cancellation_task]
+        done, pending = await asyncio.wait(
+            tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
+        )
+        for task in pending:
+            task.cancel()
+        await asyncio.wait(pending)
+        if len(done) == 0:
+            raise socket.timeout("timed out")
+        if read_task in done:
+            return read_task.result()
+        raise _OperationCancelled("operation cancelled")
+    finally:
+        sock.settimeout(sock_timeout)
+
+
+async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
+    mv = memoryview(bytearray(length))
+    bytes_read = 0
+    while bytes_read < length:
+        chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
+        if chunk_length == 0:
+            raise OSError("connection closed")
+        bytes_read += chunk_length
+    return mv
+
+
+# Sync version:
+def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
+    """Block until at least one byte is read, or a timeout, or a cancel."""
+    sock = conn.conn
+    timed_out = False
+    # Check if the connection's socket has been manually closed
+    if sock.fileno() == -1:
+        return
+    while True:
+        # SSLSocket can have buffered data which won't be caught by select.
+        if hasattr(sock, "pending") and sock.pending() > 0:
+            readable = True
+        else:
+            # Wait up to 500ms for the socket to become readable and then
+            # check for cancellation.
+            if deadline:
+                remaining = deadline - time.monotonic()
+                # When the timeout has expired perform one final check to
+                # see if the socket is readable. This helps avoid spurious
+                # timeouts on AWS Lambda and other FaaS environments.
+                if remaining <= 0:
+                    timed_out = True
+                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
+            else:
+                timeout = _POLL_TIMEOUT
+            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
+        if conn.cancel_context.cancelled:
+            raise _OperationCancelled("operation cancelled")
+        if readable:
+            return
+        if timed_out:
+            raise socket.timeout("timed out")
+
+
+def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
+    buf = bytearray(length)
+    mv = memoryview(buf)
+    bytes_read = 0
+    while bytes_read < length:
+        try:
+            wait_for_read(conn, deadline)
+            # CSOT: Update timeout. When the timeout has expired perform one
+            # final non-blocking recv. This helps avoid spurious timeouts when
+            # the response is actually already buffered on the client.
+            if _csot.get_timeout() and deadline is not None:
+                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
+            chunk_length = conn.conn.recv_into(mv[bytes_read:])
+        except BLOCKING_IO_ERRORS:
+            raise socket.timeout("timed out") from None
+        except OSError as exc:
+            if _errno_from_exception(exc) == errno.EINTR:
+                continue
+            raise
+        if chunk_length == 0:
+            raise OSError("connection closed")
+
+        bytes_read += chunk_length
+
+    return mv
diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py
index 4f6f6f4a89..50d8680a74 100644
--- a/pymongo/pyopenssl_context.py
+++ b/pymongo/pyopenssl_context.py
@@ -105,13 +105,19 @@ def _ragged_eof(exc: BaseException) -> bool:
 # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
 class _sslConn(_SSL.Connection):
     def __init__(
-        self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
+        self,
+        ctx: _SSL.Context,
+        sock: Optional[_socket.socket],
+        suppress_ragged_eofs: bool,
+        is_async: bool = False,
     ):
         self.socket_checker = _SocketChecker()
         self.suppress_ragged_eofs = suppress_ragged_eofs
         super().__init__(ctx, sock)
+        self._is_async = is_async
 
     def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
+        is_async = kwargs.pop("allow_async", True) and self._is_async
         timeout = self.gettimeout()
         if timeout:
             start = _time.monotonic()
@@ -119,6 +125,8 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
             try:
                 return call(*args, **kwargs)
             except BLOCKING_IO_ERRORS as exc:
+                if is_async:
+                    raise exc
                 # Check for closed socket.
                 if self.fileno() == -1:
                     if timeout and _time.monotonic() - start > timeout:
@@ -139,6 +147,7 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
             continue
 
     def do_handshake(self, *args: Any, **kwargs: Any) -> None:
+        kwargs["allow_async"] = False
         return self._call(super().do_handshake, *args, **kwargs)
 
     def recv(self, *args: Any, **kwargs: Any) -> bytes:
@@ -381,7 +390,7 @@ async def a_wrap_socket(
         """Wrap an existing Python socket connection and return a TLS socket
         object.
         """
-        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
+        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True)
         loop = asyncio.get_running_loop()
         if session:
             ssl_conn.set_session(session)
diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py
index c1978087a9..7206dca735 100644
--- a/pymongo/synchronous/network.py
+++ b/pymongo/synchronous/network.py
@@ -16,9 +16,7 @@
 from __future__ import annotations
 
 import datetime
-import errno
 import logging
-import socket
 import time
 from typing import (
     TYPE_CHECKING,
@@ -39,19 +37,16 @@
     NotPrimaryError,
     OperationFailure,
     ProtocolError,
-    _OperationCancelled,
 )
 from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
 from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
 from pymongo.monitoring import _is_speculative_authenticate
 from pymongo.network_layer import (
-    _POLL_TIMEOUT,
     _UNPACK_COMPRESSION_HEADER,
     _UNPACK_HEADER,
-    BLOCKING_IO_ERRORS,
+    receive_data,
     sendall,
 )
-from pymongo.socket_checker import _errno_from_exception
 
 if TYPE_CHECKING:
     from bson import CodecOptions
@@ -317,7 +312,7 @@ def receive_message(
         else:
             deadline = None
     # Ignore the response's request id.
-    length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
+    length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
     # No request_id for exhaust cursor "getMore".
     if request_id is not None:
         if request_id != response_to:
@@ -332,12 +327,10 @@
             f"message size ({max_message_size!r})"
         )
     if op_code == 2012:
-        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
-            _receive_data_on_socket(conn, 9, deadline)
-        )
-        data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id)
+        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
+        data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
     else:
-        data = _receive_data_on_socket(conn, length - 16, deadline)
+        data = receive_data(conn, length - 16, deadline)
 
     try:
         unpack_reply = _UNPACK_REPLY[op_code]
@@ -346,63 +339,3 @@
             f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
         ) from None
     return unpack_reply(data)
-
-
-def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
-    """Block until at least one byte is read, or a timeout, or a cancel."""
-    sock = conn.conn
-    timed_out = False
-    # Check if the connection's socket has been manually closed
-    if sock.fileno() == -1:
-        return
-    while True:
-        # SSLSocket can have buffered data which won't be caught by select.
-        if hasattr(sock, "pending") and sock.pending() > 0:
-            readable = True
-        else:
-            # Wait up to 500ms for the socket to become readable and then
-            # check for cancellation.
-            if deadline:
-                remaining = deadline - time.monotonic()
-                # When the timeout has expired perform one final check to
-                # see if the socket is readable. This helps avoid spurious
-                # timeouts on AWS Lambda and other FaaS environments.
-                if remaining <= 0:
-                    timed_out = True
-                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
-            else:
-                timeout = _POLL_TIMEOUT
-            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
-        if conn.cancel_context.cancelled:
-            raise _OperationCancelled("operation cancelled")
-        if readable:
-            return
-        if timed_out:
-            raise socket.timeout("timed out")
-
-
-def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
-    buf = bytearray(length)
-    mv = memoryview(buf)
-    bytes_read = 0
-    while bytes_read < length:
-        try:
-            wait_for_read(conn, deadline)
-            # CSOT: Update timeout. When the timeout has expired perform one
-            # final non-blocking recv. This helps avoid spurious timeouts when
-            # the response is actually already buffered on the client.
-            if _csot.get_timeout() and deadline is not None:
-                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
-            chunk_length = conn.conn.recv_into(mv[bytes_read:])
-        except BLOCKING_IO_ERRORS:
-            raise socket.timeout("timed out") from None
-        except OSError as exc:
-            if _errno_from_exception(exc) == errno.EINTR:
-                continue
-            raise
-        if chunk_length == 0:
-            raise OSError("connection closed")
-
-        bytes_read += chunk_length
-
-    return mv
diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py
index 5c06331790..2052d1cd7f 100644
--- a/test/asynchronous/test_client.py
+++ b/test/asynchronous/test_client.py
@@ -1713,6 +1713,7 @@ def compression_settings(client):
         # No error
         await client.pymongo_test.test.find_one()
 
+    @async_client_context.require_sync
     async def test_reset_during_update_pool(self):
         client = await self.async_rs_or_single_client(minPoolSize=10)
         await client.admin.command("ping")
@@ -1737,10 +1738,7 @@ async def _run(self):
                     await asyncio.sleep(0.001)
 
             def run(self):
-                if _IS_SYNC:
-                    self._run()
-                else:
-                    asyncio.run(self._run())
+                self._run()
 
         t = ResetPoolThread(pool)
         t.start()
diff --git a/test/test_client.py b/test/test_client.py
index c88a8fd9b4..936c38b8c6 100644
--- a/test/test_client.py
+++ b/test/test_client.py
@@ -1671,6 +1671,7 @@ def compression_settings(client):
         # No error
         client.pymongo_test.test.find_one()
 
+    @client_context.require_sync
     def test_reset_during_update_pool(self):
         client = self.rs_or_single_client(minPoolSize=10)
         client.admin.command("ping")
@@ -1695,10 +1696,7 @@ def _run(self):
                     time.sleep(0.001)
 
             def run(self):
-                if _IS_SYNC:
-                    self._run()
-                else:
-                    asyncio.run(self._run())
+                self._run()
 
         t = ResetPoolThread(pool)
         t.start()
diff --git a/tools/synchro.py b/tools/synchro.py
index 3333b0de2e..ec4a9202ee 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -43,6 +43,7 @@
     "AsyncConnection": "Connection",
     "async_command": "command",
     "async_receive_message": "receive_message",
+    "async_receive_data": "receive_data",
     "async_sendall": "sendall",
     "asynchronous": "synchronous",
     "Asynchronous": "Synchronous",