From cbace9266afc0e7df39dc2abb7ffa4f77e827d59 Mon Sep 17 00:00:00 2001 From: Lev Vereshchagin Date: Thu, 1 Aug 2024 10:20:41 +0300 Subject: [PATCH] Force keyword arguments in some cases to avoid ambiguity (#47) --- pyproject.toml | 2 +- stompman/client.py | 12 +++---- stompman/config.py | 5 +-- stompman/connection.py | 8 ++--- stompman/connection_manager.py | 3 +- stompman/errors.py | 9 ++--- stompman/frames.py | 2 +- stompman/serde.py | 13 +++---- testing/consumer.py | 7 ++-- tests/conftest.py | 7 ++-- tests/test_client.py | 11 +++--- tests/test_connection.py | 3 +- tests/test_connection_manager.py | 61 ++++++++++++++++++-------------- 13 files changed, 76 insertions(+), 67 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 23de370..429c603 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ line-length = 120 [tool.ruff.lint] preview = true select = ["ALL"] -ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001"] +ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001", "PLR0913", "PLC2801"] extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG"] } [tool.pytest.ini_options] diff --git a/stompman/client.py b/stompman/client.py index dc9c2e4..a69ac10 100644 --- a/stompman/client.py +++ b/stompman/client.py @@ -31,7 +31,7 @@ @asynccontextmanager -async def connection_lifespan( # noqa: PLR0913 +async def connection_lifespan( *, connection: AbstractConnection, connection_parameters: ConnectionParameters, @@ -122,7 +122,7 @@ async def unsubscribe(self) -> None: del self._active_subscriptions[self.id] await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id})) - async def _run_handler(self, frame: MessageFrame) -> None: + async def _run_handler(self, *, frame: MessageFrame) -> None: try: await self.handler(frame) except self.supressed_exception_classes as exception: @@ -185,7 +185,7 @@ async def __aexit__( self._active_transactions.remove(self) async def send( - self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None + self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None ) -> None: frame = SendFrame.build( body=body, destination=destination, transaction=self.id, content_type=content_type, headers=headers @@ -295,7 +295,7 @@ async def _listen_to_frames(self) -> None: match frame: case MessageFrame(): if subscription := self._active_subscriptions.get(frame.headers["subscription"]): - task_group.create_task(subscription._run_handler(frame)) # noqa: SLF001 + task_group.create_task(subscription._run_handler(frame=frame)) # noqa: SLF001 elif self.on_unhandled_message_frame: self.on_unhandled_message_frame(frame) case ErrorFrame(): @@ -308,7 +308,7 @@ async def _listen_to_frames(self) -> None: pass async def send( - self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None + self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None ) -> None: await self._connection_manager.write_frame_reconnecting( SendFrame.build( @@ -323,7 +323,7 @@ async def begin(self) -> AsyncGenerator[Transaction, None]: ) as transaction: yield transaction - async def subscribe( # noqa: PLR0913 + async def subscribe( self, destination: str, handler: Callable[[MessageFrame], Coroutine[None, None, None]], diff --git a/stompman/config.py b/stompman/config.py index eb4af63..37080dd 100644 --- a/stompman/config.py +++ b/stompman/config.py @@ -1,9 +1,10 @@ from dataclasses import dataclass, field -from typing import NamedTuple, Self, TypedDict +from typing import Self, TypedDict from urllib.parse import unquote -class Heartbeat(NamedTuple): +@dataclass(frozen=True, slots=True) +class Heartbeat: will_send_interval_ms: int want_to_receive_interval_ms: int diff --git a/stompman/connection.py b/stompman/connection.py index 1b55702..cf03201 100644 --- a/stompman/connection.py +++ b/stompman/connection.py @@ -13,8 +13,8 @@ @dataclass(kw_only=True) class AbstractConnection(Protocol): @classmethod - async def connect( # noqa: PLR0913 - cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int + async def connect( + cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int ) -> Self | None: ... async def close(self) -> None: ... def write_heartbeat(self) -> None: ... @@ -38,8 +38,8 @@ class Connection(AbstractConnection): read_timeout: int @classmethod - async def connect( # noqa: PLR0913 - cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int + async def connect( + cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int ) -> Self | None: try: reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout) diff --git a/stompman/connection_manager.py b/stompman/connection_manager.py index ab2d0cf..b9d9e3d 100644 --- a/stompman/connection_manager.py +++ b/stompman/connection_manager.py @@ -27,6 +27,7 @@ class ConnectionManager: connect_timeout: int read_timeout: int read_max_chunk_size: int + _active_connection_state: ActiveConnectionState | None = field(default=None, init=False) _reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock) @@ -87,7 +88,7 @@ async def _get_active_connection_state(self) -> ActiveConnectionState: self._active_connection_state = ActiveConnectionState(connection=connection, lifespan=lifespan) try: - await lifespan.__aenter__() # noqa: PLC2801 + await lifespan.__aenter__() except ConnectionLostError: self._clear_active_connection_state() else: diff --git a/stompman/errors.py b/stompman/errors.py index b969c9d..c9ff44f 100644 --- a/stompman/errors.py +++ b/stompman/errors.py @@ -13,6 +13,11 @@ def __str__(self) -> str: return self.__repr__() +@dataclass(frozen=True, kw_only=True, slots=True) +class ConnectionLostError(Error): + """Raised in stompman.AbstractConnection—and handled in stompman.ConnectionManager, therefore is private.""" + + @dataclass(frozen=True, kw_only=True, slots=True) class ConnectionConfirmationTimeoutError(Error): timeout: int @@ -36,7 +41,3 @@ class FailedAllConnectAttemptsError(Error): @dataclass(frozen=True, kw_only=True, slots=True) class RepeatedConnectionLostError(Error): retry_attempts: int - - -@dataclass(frozen=True, kw_only=True, slots=True) -class ConnectionLostError(Error): ... diff --git a/stompman/frames.py b/stompman/frames.py index 38540d8..92d139e 100644 --- a/stompman/frames.py +++ b/stompman/frames.py @@ -142,7 +142,7 @@ class SendFrame: body: bytes = b"" @classmethod - def build( # noqa: PLR0913 + def build( cls, *, body: bytes, diff --git a/stompman/serde.py b/stompman/serde.py index 226a3d0..3deab18 100644 --- a/stompman/serde.py +++ b/stompman/serde.py @@ -98,7 +98,7 @@ def dump_frame(frame: AnyClientFrame | AnyRealServerFrame) -> bytes: return b"".join(lines) -def unescape_byte(byte: bytes, previous_byte: bytes | None) -> bytes | None: +def unescape_byte(*, byte: bytes, previous_byte: bytes | None) -> bytes | None: if previous_byte == BACKSLASH: return HEADER_UNESCAPE_CHARS.get(byte) if byte == BACKSLASH: @@ -123,7 +123,7 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None: just_escaped_line = False if byte != BACKSLASH: (value_buffer if key_parsed else key_buffer).extend(byte) - elif unescaped_byte := unescape_byte(byte, previous_byte): + elif unescaped_byte := unescape_byte(byte=byte, previous_byte=previous_byte): just_escaped_line = True (value_buffer if key_parsed else key_buffer).extend(unescaped_byte) @@ -136,13 +136,10 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None: return None -def make_frame_from_parts(command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame: +def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame: frame_type = COMMANDS_TO_FRAMES[command] - return ( - frame_type(headers=cast(Any, headers), body=body) # type: ignore[call-arg] - if frame_type in FRAMES_WITH_BODY - else frame_type(headers=cast(Any, headers)) # type: ignore[call-arg] - ) + headers_ = cast(Any, headers) + return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg] def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame: diff --git a/testing/consumer.py b/testing/consumer.py index 5a5b673..e026831 100644 --- a/testing/consumer.py +++ b/testing/consumer.py @@ -5,11 +5,10 @@ async def main() -> None: - async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client: - - async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029 - print(frame) # noqa: T201 + async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029 + print(frame) # noqa: T201 + async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client: await client.subscribe("DLQ", handler=handle_message, on_suppressed_exception=print) diff --git a/tests/conftest.py b/tests/conftest.py index 5d902d7..e94092b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from polyfactory.factories.dataclass_factory import DataclassFactory import stompman +from stompman.frames import HeartbeatFrame @pytest.fixture( @@ -34,8 +35,8 @@ def noop_error_handler(exception: Exception, frame: stompman.MessageFrame) -> No class BaseMockConnection(stompman.AbstractConnection): @classmethod - async def connect( # noqa: PLR0913 - cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int + async def connect( + cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int ) -> Self | None: return cls() @@ -45,7 +46,7 @@ async def write_frame(self, frame: stompman.AnyClientFrame) -> None: ... @staticmethod async def read_frames() -> AsyncGenerator[stompman.AnyServerFrame, None]: # pragma: no cover await asyncio.Future() - yield # type: ignore[misc] + yield HeartbeatFrame() @dataclass(kw_only=True, slots=True) diff --git a/tests/test_client.py b/tests/test_client.py index e03200a..7fca56d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -148,7 +148,7 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]: await asyncio.sleep(0) with pytest.raises(ConnectionConfirmationTimeoutError) as exc_info: - await EnrichedClient( # noqa: PLC2801 + await EnrichedClient( connection_class=MockConnection, connection_confirmation_timeout=connection_confirmation_timeout ).__aenter__() @@ -161,7 +161,7 @@ async def test_client_connection_lifespan_unsupported_protocol_version() -> None given_version = FAKER.pystr() with pytest.raises(UnsupportedProtocolVersionError) as exc_info: - await EnrichedClient( # noqa: PLC2801 + await EnrichedClient( connection_class=create_spying_connection( [build_dataclass(ConnectedFrame, headers={"version": given_version})] )[0] @@ -279,7 +279,8 @@ async def test_client_subscribtions_lifespan_no_active_subs_in_aexit(monkeypatch @pytest.mark.parametrize("direct_error", [True, False]) async def test_client_subscribtions_lifespan_with_active_subs_in_aexit( monkeypatch: pytest.MonkeyPatch, - direct_error: bool, # noqa: FBT001 + *, + direct_error: bool, ) -> None: subscription_id, destination = FAKER.pystr(), FAKER.pystr() monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id)) @@ -388,7 +389,7 @@ async def test_client_listen_unsubscribe_before_ack_or_nack( @pytest.mark.parametrize("ok", [True, False]) @pytest.mark.parametrize("ack", ["client", "client-individual"]) -async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack: AckMode, ok: bool) -> None: # noqa: FBT001 +async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack: AckMode, *, ok: bool) -> None: subscription_id, destination, message_id = FAKER.pystr(), FAKER.pystr(), FAKER.pystr() monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id)) @@ -418,7 +419,7 @@ async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack: @pytest.mark.parametrize("ok", [True, False]) -async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, ok: bool) -> None: # noqa: FBT001 +async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, *, ok: bool) -> None: subscription_id, destination, message_id = FAKER.pystr(), FAKER.pystr(), FAKER.pystr() monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id)) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6aba20c..2871a5c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -76,7 +76,6 @@ class MockWriter: HeartbeatFrame(), ConnectedFrame(headers={"heart-beat": "0,0", "version": "1.2", "server": "some server"}), ] - max_chunk_size = 1024 class MockReader: read = mock.AsyncMock(side_effect=read_bytes) @@ -100,7 +99,7 @@ async def take_frames(count: int) -> list[AnyServerFrame]: MockWriter.close.assert_called_once_with() MockWriter.wait_closed.assert_called_once_with() MockWriter.drain.assert_called_once_with() - MockReader.read.mock_calls = [mock.call(max_chunk_size)] * len(read_bytes) # type: ignore[assignment] + assert MockReader.read.mock_calls == [mock.call(connection.read_max_chunk_size)] * len(read_bytes) assert MockWriter.write.mock_calls == [mock.call(NEWLINE), mock.call(b"COMMIT\ntransaction:transaction\n\n\x00")] diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py index 1452556..cb2e3b2 100644 --- a/tests/test_connection_manager.py +++ b/tests/test_connection_manager.py @@ -29,15 +29,21 @@ async def test_connect_to_one_server_ok(ok_on_attempt: int, monkeypatch: pytest. class MockConnection(BaseMockConnection): @classmethod - async def connect( # noqa: PLR0913 - cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int + async def connect( + cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int ) -> Self | None: assert (host, port) == (manager.servers[0].host, manager.servers[0].port) nonlocal attempts attempts += 1 return ( - await super().connect(host, port, timeout, read_max_chunk_size, read_timeout) + await super().connect( + host=host, + port=port, + timeout=timeout, + read_max_chunk_size=read_max_chunk_size, + read_timeout=read_timeout, + ) if attempts == ok_on_attempt else None ) @@ -60,11 +66,17 @@ class MockConnection(BaseMockConnection): async def test_connect_to_any_server_ok() -> None: class MockConnection(BaseMockConnection): @classmethod - async def connect( # noqa: PLR0913 - cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int + async def connect( + cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int ) -> Self | None: return ( - await super().connect(host, port, timeout, read_max_chunk_size, read_timeout) + await super().connect( + host=host, + port=port, + timeout=timeout, + read_max_chunk_size=read_max_chunk_size, + read_timeout=read_timeout, + ) if port == successful_server.port else None ) @@ -224,19 +236,17 @@ class MockConnection(BaseMockConnection): async def test_read_frames_reconnecting_raises() -> None: - async def read_frames_mock(self: object) -> AsyncGenerator[AnyServerFrame, None]: - raise ConnectionLostError - yield - await asyncio.sleep(0) - class MockConnection(BaseMockConnection): - read_frames = read_frames_mock # type: ignore[assignment] + @staticmethod + async def read_frames() -> AsyncGenerator[AnyServerFrame, None]: + raise ConnectionLostError + yield + await asyncio.sleep(0) manager = EnrichedConnectionManager(connection_class=MockConnection) - with pytest.raises(RepeatedConnectionLostError): # noqa: PT012 - async for _ in manager.read_frames_reconnecting(): - pass # pragma: no cover + with pytest.raises(RepeatedConnectionLostError): + [_ async for _ in manager.read_frames_reconnecting()] SIDE_EFFECTS = [(None,), (ConnectionLostError(), None), (ConnectionLostError(), ConnectionLostError(), None)] @@ -279,18 +289,17 @@ async def test_read_frames_reconnecting_ok(side_effect: tuple[None | ConnectionL ] attempt = -1 - async def read_frames_mock(self: object) -> AsyncGenerator[AnyServerFrame, None]: - nonlocal attempt - attempt += 1 - current_effect = side_effect[attempt] - if isinstance(current_effect, ConnectionLostError): - raise ConnectionLostError - for frame in frames: - yield frame - await asyncio.sleep(0) - class MockConnection(BaseMockConnection): - read_frames = read_frames_mock # type: ignore[assignment] + @staticmethod + async def read_frames() -> AsyncGenerator[AnyServerFrame, None]: + nonlocal attempt + attempt += 1 + current_effect = side_effect[attempt] + if isinstance(current_effect, ConnectionLostError): + raise ConnectionLostError + for frame in frames: + yield frame + await asyncio.sleep(0) manager = EnrichedConnectionManager(connection_class=MockConnection)