Skip to content

Commit

Permalink
Implement _async_receive_ssl
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp committed Sep 19, 2024
1 parent adb1670 commit 3d399da
Showing 1 changed file with 52 additions and 4 deletions.
56 changes: 52 additions & 4 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,44 @@ def _is_ready(fut: Future) -> None:
loop.add_reader(fd, _is_ready, fut)
loop.add_writer(fd, _is_ready, fut)
await fut

async def _async_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop
) -> memoryview:
mv = memoryview(bytearray(length))
fd = conn.fileno()
read = 0

def _is_ready(fut: Future) -> None:
loop.remove_writer(fd)
loop.remove_reader(fd)
if fut.done():
return
fut.set_result(None)

while read < length:
try:
read += conn.recv_into(mv[read:])
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
# Check for closed socket.
if fd == -1:
raise SSLError("Underlying socket has been closed") from None
if isinstance(exc, BLOCKING_IO_READ_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
await fut
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
fut = loop.create_future()
loop.add_writer(fd, _is_ready, fut)
await fut
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
loop.add_writer(fd, _is_ready, fut)
await fut
return mv

else:
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
Expand All @@ -138,6 +176,20 @@ async def _async_sendall_ssl(
sent = 0
total_sent += sent

async def _async_receive_ssl(
conn: _sslConn, length: int, dummy: AbstractEventLoop
) -> memoryview:
mv = memoryview(bytearray(length))
total_read = 0
while total_read < length:
try:
read = conn.recv_into(mv[total_read:])
except BLOCKING_IO_ERRORS:
await asyncio.sleep(0.5)
read = 0
total_read += read
return mv


def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
sock.sendall(buf)
Expand Down Expand Up @@ -181,10 +233,6 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo
return mv


async def _async_receive_ssl(conn: _sslConn, length: int, loop: AbstractEventLoop) -> memoryview: # noqa: ARG001
return memoryview(b"")


# Sync version:
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
Expand Down

0 comments on commit 3d399da

Please sign in to comment.