diff --git a/CHANGES/10137.misc.rst b/CHANGES/10137.misc.rst new file mode 100644 index 00000000000..43b19c33f32 --- /dev/null +++ b/CHANGES/10137.misc.rst @@ -0,0 +1,3 @@ +Restored support for zero copy writes when using Python 3.12 versions 3.12.9 and later or Python 3.13.2+ -- by :user:`bdraco`. + +Zero copy writes were previously disabled due to :cve:`2024-12254` which is resolved in these Python versions. diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index edd19ed65da..8cab85882fd 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -1,6 +1,7 @@ """Http related parsers and protocol.""" import asyncio +import sys import zlib from typing import ( # noqa Any, @@ -24,6 +25,17 @@ __all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11") +MIN_PAYLOAD_FOR_WRITELINES = 2048 +IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2) +IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9) +SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9 +# writelines is not safe for use +# on Python 3.12+ until 3.12.9 +# on Python 3.13+ until 3.13.2 +# and on older versions it not any faster than write +# CVE-2024-12254: https://github.com/python/cpython/pull/127656 + + class HttpVersion(NamedTuple): major: int minor: int @@ -90,7 +102,10 @@ def _writelines(self, chunks: Iterable[bytes]) -> None: transport = self._protocol.transport if transport is None or transport.is_closing(): raise ClientConnectionResetError("Cannot write to closing transport") - transport.write(b"".join(chunks)) + if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES: + transport.write(b"".join(chunks)) + else: + transport.writelines(chunks) async def write( self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000 diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 65afc05ae10..95dabeab377 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -2,7 +2,7 @@ import array import asyncio import zlib -from typing import Any, Iterable +from typing import Any, Generator, Iterable from unittest import mock import pytest @@ -13,6 +13,18 @@ from aiohttp.test_utils import make_mocked_coro +@pytest.fixture +def enable_writelines() -> Generator[None, None, None]: + with mock.patch("aiohttp.http_writer.SKIP_WRITELINES", False): + yield + + +@pytest.fixture +def force_writelines_small_payloads() -> Generator[None, None, None]: + with mock.patch("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES", 1): + yield + + @pytest.fixture def buf() -> bytearray: return bytearray() @@ -136,6 +148,33 @@ async def test_write_large_payload_deflate_compression_data_in_eof( assert zlib.decompress(content) == (b"data" * 4096) + payload +@pytest.mark.usefixtures("enable_writelines") +async def test_write_large_payload_deflate_compression_data_in_eof_writelines( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + + await msg.write(b"data" * 4096) + assert transport.write.called # type: ignore[attr-defined] + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + transport.write.reset_mock() # type: ignore[attr-defined] + assert not transport.writelines.called # type: ignore[attr-defined] + + # This payload compresses to 20447 bytes + payload = b"".join( + [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] + ) + await msg.write_eof(payload) + assert not transport.write.called # type: ignore[attr-defined] + assert transport.writelines.called # type: ignore[attr-defined] + chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined] + content = b"".join(chunks) + assert zlib.decompress(content) == (b"data" * 4096) + payload + + async def test_write_payload_chunked_filter( protocol: BaseProtocol, transport: asyncio.Transport, @@ -207,6 +246,26 @@ async def test_write_payload_deflate_compression_chunked( assert content == expected +@pytest.mark.usefixtures("enable_writelines") +@pytest.mark.usefixtures("force_writelines_small_payloads") +async def test_write_payload_deflate_compression_chunked_writelines( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n" + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof() + + chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert content == expected + + async def test_write_payload_deflate_and_chunked( buf: bytearray, protocol: BaseProtocol, @@ -243,6 +302,26 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof( assert content == expected +@pytest.mark.usefixtures("enable_writelines") +@pytest.mark.usefixtures("force_writelines_small_payloads") +async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n" + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof(b"end") + + chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert content == expected + + async def test_write_large_payload_deflate_compression_chunked_data_in_eof( protocol: BaseProtocol, transport: asyncio.Transport, @@ -269,6 +348,34 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof( assert zlib.decompress(content) == (b"data" * 4096) + payload +@pytest.mark.usefixtures("enable_writelines") +@pytest.mark.usefixtures("force_writelines_small_payloads") +async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + + await msg.write(b"data" * 4096) + # This payload compresses to 1111 bytes + payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) + await msg.write_eof(payload) + assert not transport.write.called # type: ignore[attr-defined] + + chunks = [] + for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined] + chunked_payload = list(write_lines_call[1][0])[1:] + chunked_payload.pop() + chunks.extend(chunked_payload) + + assert all(chunks) + content = b"".join(chunks) + assert zlib.decompress(content) == (b"data" * 4096) + payload + + async def test_write_payload_deflate_compression_chunked_connection_lost( protocol: BaseProtocol, transport: asyncio.Transport,