Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore zero copy writes on Python 3.12.9+/3.13.2+ #10137

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/10137.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Restored support for zero copy writes when running on Python 3.12.9+ or Python 3.13.2+ -- by :user:`bdraco`.

Zero copy writes were previously disabled due to :cve:`2024-12254` which is resolved in these Python versions.
17 changes: 16 additions & 1 deletion aiohttp/http_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Http related parsers and protocol."""

import asyncio
import sys
import zlib
from typing import ( # noqa
Any,
Expand All @@ -24,6 +25,17 @@
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")


# Payloads smaller than this are cheaper to join and send with a single
# transport.write() than to hand to transport.writelines().
MIN_PAYLOAD_FOR_WRITELINES = 2048

# transport.writelines() is unsafe to use on 3.13.x before 3.13.2 and on
# anything before 3.12.9 because of CVE-2024-12254
# (https://github.com/python/cpython/pull/127656); on interpreters older
# than 3.12 it is also no faster than a plain write().
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9


class HttpVersion(NamedTuple):
major: int
minor: int
Expand Down Expand Up @@ -90,7 +102,10 @@ def _writelines(self, chunks: Iterable[bytes]) -> None:
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(b"".join(chunks))
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
transport.write(b"".join(chunks))
else:
transport.writelines(chunks)

async def write(
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
Expand Down
109 changes: 108 additions & 1 deletion tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import array
import asyncio
import zlib
from typing import Any, Iterable
from typing import Any, Generator, Iterable
from unittest import mock

import pytest
Expand All @@ -13,6 +13,18 @@
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
def enable_writelines() -> Generator[None, None, None]:
    """Force the writer to behave as if transport.writelines() were safe."""
    patcher = mock.patch("aiohttp.http_writer.SKIP_WRITELINES", False)
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()


@pytest.fixture
def force_writelines_small_payloads() -> Generator[None, None, None]:
    """Drop the writelines size threshold so even tiny payloads use it."""
    patcher = mock.patch("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES", 1)
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()


@pytest.fixture
def buf() -> bytearray:
    """Provide a fresh, empty output buffer for each test."""
    buffer = bytearray()
    return buffer
Expand Down Expand Up @@ -136,6 +148,33 @@ async def test_write_large_payload_deflate_compression_data_in_eof(
assert zlib.decompress(content) == (b"data" * 4096) + payload


@pytest.mark.usefixtures("enable_writelines")
async def test_write_large_payload_deflate_compression_data_in_eof_writelines(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """A large compressed eof payload must be sent via transport.writelines()."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")

    await writer.write(b"data" * 4096)
    # The first write is small enough to go through transport.write() only.
    assert transport.write.called  # type: ignore[attr-defined]
    collected = [call.args[0] for call in transport.write.mock_calls]  # type: ignore[attr-defined]
    transport.write.reset_mock()  # type: ignore[attr-defined]
    assert not transport.writelines.called  # type: ignore[attr-defined]

    # This payload compresses to 20447 bytes
    payload = b"".join(
        bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)
    )
    await writer.write_eof(payload)
    # The large compressed tail is flushed with writelines(), not write().
    assert not transport.write.called  # type: ignore[attr-defined]
    assert transport.writelines.called  # type: ignore[attr-defined]
    collected.extend(transport.writelines.mock_calls[0].args[0])  # type: ignore[attr-defined]
    assert zlib.decompress(b"".join(collected)) == (b"data" * 4096) + payload


async def test_write_payload_chunked_filter(
protocol: BaseProtocol,
transport: asyncio.Transport,
Expand Down Expand Up @@ -207,6 +246,26 @@ async def test_write_payload_deflate_compression_chunked(
assert content == expected


@pytest.mark.usefixtures("enable_writelines")
@pytest.mark.usefixtures("force_writelines_small_payloads")
async def test_write_payload_deflate_compression_chunked_writelines(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Chunked deflate output is emitted through transport.writelines()."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()
    await writer.write(b"data")
    await writer.write_eof()

    # Flatten each writelines() invocation into one bytes chunk.
    flattened = [
        b"".join(call.args[0])
        for call in transport.writelines.mock_calls  # type: ignore[attr-defined]
    ]
    assert all(flattened)  # no empty writes were issued
    expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
    assert b"".join(flattened) == expected


async def test_write_payload_deflate_and_chunked(
buf: bytearray,
protocol: BaseProtocol,
Expand Down Expand Up @@ -243,6 +302,26 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof(
assert content == expected


@pytest.mark.usefixtures("enable_writelines")
@pytest.mark.usefixtures("force_writelines_small_payloads")
async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Data passed to write_eof() joins the chunked writelines() stream."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()
    await writer.write(b"data")
    await writer.write_eof(b"end")

    # Flatten each writelines() invocation into one bytes chunk.
    flattened = [
        b"".join(call.args[0])
        for call in transport.writelines.mock_calls  # type: ignore[attr-defined]
    ]
    assert all(flattened)  # no empty writes were issued
    expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
    assert b"".join(flattened) == expected


async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
protocol: BaseProtocol,
transport: asyncio.Transport,
Expand All @@ -269,6 +348,34 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
assert zlib.decompress(content) == (b"data" * 4096) + payload


@pytest.mark.usefixtures("enable_writelines")
@pytest.mark.usefixtures("force_writelines_small_payloads")
async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """A large chunked deflate body round-trips through writelines() intact."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()

    await writer.write(b"data" * 4096)
    # This payload compresses to 1111 bytes
    payload = b"".join(bytes((*range(0, i), *range(i, 0, -1))) for i in range(255))
    await writer.write_eof(payload)
    assert not transport.write.called  # type: ignore[attr-defined]

    body_parts = []
    for call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
        # Strip the chunked framing from each call: the first item is the
        # chunk-size header and the last is the trailing CRLF.
        framed = list(call.args[0])
        body_parts.extend(framed[1:-1])

    assert all(body_parts)
    assert zlib.decompress(b"".join(body_parts)) == (b"data" * 4096) + payload


async def test_write_payload_deflate_compression_chunked_connection_lost(
protocol: BaseProtocol,
transport: asyncio.Transport,
Expand Down
Loading