diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 44a63a2fc3..d17aead120 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -15,11 +15,8 @@ """Internal network layer helper methods.""" from __future__ import annotations -import asyncio import datetime -import errno import logging -import socket import time from typing import ( TYPE_CHECKING, @@ -40,19 +37,16 @@ NotPrimaryError, OperationFailure, ProtocolError, - _OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, _UNPACK_HEADER, - BLOCKING_IO_ERRORS, + async_receive_data, async_sendall, ) -from pymongo.socket_checker import _errno_from_exception if TYPE_CHECKING: from bson import CodecOptions @@ -318,9 +312,7 @@ async def receive_message( else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER( - await _receive_data_on_socket(conn, 16, deadline) - ) + length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline)) # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: @@ -336,11 +328,11 @@ async def receive_message( ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - await _receive_data_on_socket(conn, 9, deadline) + await async_receive_data(conn, 9, deadline) ) - data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id) + data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id) else: - data = await _receive_data_on_socket(conn, length - 16, deadline) + data = await async_receive_data(conn, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -349,66 +341,3 @@ async def receive_message( f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" ) from None return unpack_reply(data) - - -async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - await asyncio.sleep(0) - - -async def _receive_data_on_socket( - conn: AsyncConnection, length: int, deadline: Optional[float] -) -> memoryview: - buf = bytearray(length) - mv = memoryview(buf) - bytes_read = 0 - while bytes_read < length: - try: - await wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. - if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: - continue - raise - if chunk_length == 0: - raise OSError("connection closed") - - bytes_read += chunk_length - - return mv diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 82a6228acc..7479b6be80 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -16,15 +16,21 @@ from __future__ import annotations import asyncio +import errno import socket import struct import sys +import time from asyncio import AbstractEventLoop, Future from typing import ( + TYPE_CHECKING, + Optional, Union, ) -from pymongo import ssl_support +from pymongo import _csot, ssl_support +from pymongo.errors import _OperationCancelled +from pymongo.socket_checker import _errno_from_exception try: from ssl import SSLError, SSLSocket @@ -51,6 +57,10 @@ BLOCKING_IO_WRITE_ERROR, ) +if TYPE_CHECKING: + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.synchronous.pool import Connection + _UNPACK_HEADER = struct.Struct(" None: sock.sendall(buf) + + +async def async_receive_data( + conn: AsyncConnection, length: int, deadline: Optional[float] +) -> memoryview: + sock = conn.conn + sock_timeout = sock.gettimeout() + if deadline: + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + timeout = max(deadline - time.monotonic(), 0) + else: + timeout = sock_timeout + + sock.settimeout(0.0) + loop = asyncio.get_event_loop() + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout) + else: + return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] + except asyncio.TimeoutError as exc: + # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. + raise socket.timeout("timed out") from exc + finally: + sock.settimeout(sock_timeout) + + +async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: + mv = memoryview(bytearray(length)) + bytes_read = 0 + while bytes_read < length: + chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:]) + if chunk_length == 0: + raise OSError("connection closed") + bytes_read += chunk_length + return mv + + +async def _async_receive_ssl(conn: _sslConn, length: int, loop: AbstractEventLoop) -> memoryview: # noqa: ARG001 + return memoryview(b"") + + +# Sync version: +def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: + """Block until at least one byte is read, or a timeout, or a cancel.""" + sock = conn.conn + timed_out = False + # Check if the connection's socket has been manually closed + if sock.fileno() == -1: + return + while True: + # SSLSocket can have buffered data which won't be caught by select. + if hasattr(sock, "pending") and sock.pending() > 0: + readable = True + else: + # Wait up to 500ms for the socket to become readable and then + # check for cancellation. + if deadline: + remaining = deadline - time.monotonic() + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + if remaining <= 0: + timed_out = True + timeout = max(min(remaining, _POLL_TIMEOUT), 0) + else: + timeout = _POLL_TIMEOUT + readable = conn.socket_checker.select(sock, read=True, timeout=timeout) + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") + if readable: + return + if timed_out: + raise socket.timeout("timed out") + + +def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: + buf = bytearray(length) + mv = memoryview(buf) + bytes_read = 0 + while bytes_read < length: + try: + wait_for_read(conn, deadline) + # CSOT: Update timeout. When the timeout has expired perform one + # final non-blocking recv. This helps avoid spurious timeouts when + # the response is actually already buffered on the client. + if _csot.get_timeout() and deadline is not None: + conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) + chunk_length = conn.conn.recv_into(mv[bytes_read:]) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + except OSError as exc: + if _errno_from_exception(exc) == errno.EINTR: + continue + raise + if chunk_length == 0: + raise OSError("connection closed") + + bytes_read += chunk_length + + return mv diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index c1978087a9..7206dca735 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,9 +16,7 @@ from __future__ import annotations import datetime -import errno import logging -import socket import time from typing import ( TYPE_CHECKING, @@ -39,19 +37,16 @@ NotPrimaryError, OperationFailure, ProtocolError, - _OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, _UNPACK_HEADER, - BLOCKING_IO_ERRORS, + receive_data, sendall, ) -from pymongo.socket_checker import _errno_from_exception if TYPE_CHECKING: from bson import CodecOptions @@ -317,7 +312,7 @@ def receive_message( else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline)) + length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: @@ -332,12 +327,10 @@ def receive_message( f"message size ({max_message_size!r})" ) if op_code == 2012: - op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - _receive_data_on_socket(conn, 9, deadline) - ) - data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id) + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) + data = decompress(receive_data(conn, length - 25, deadline), compressor_id) else: - data = _receive_data_on_socket(conn, length - 16, deadline) + data = receive_data(conn, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -346,63 +339,3 @@ def receive_message( f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" ) from None return unpack_reply(data) - - -def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - - -def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: - buf = bytearray(length) - mv = memoryview(buf) - bytes_read = 0 - while bytes_read < length: - try: - wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. - if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: - continue - raise - if chunk_length == 0: - raise OSError("connection closed") - - bytes_read += chunk_length - - return mv diff --git a/tools/synchro.py b/tools/synchro.py index 59d6e653e5..f9d9ee826d 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -43,6 +43,7 @@ "AsyncConnection": "Connection", "async_command": "command", "async_receive_message": "receive_message", + "async_receive_data": "receive_data", "async_sendall": "sendall", "asynchronous": "synchronous", "Asynchronous": "Synchronous",