Skip to content

Commit

Permalink
Force keyword arguments in some cases to avoid ambiguity (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Aug 1, 2024
1 parent 3db311a commit cbace92
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 67 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ line-length = 120
[tool.ruff.lint]
preview = true
select = ["ALL"]
ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001"]
ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001", "PLR0913", "PLC2801"]
extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG"] }

[tool.pytest.ini_options]
Expand Down
12 changes: 6 additions & 6 deletions stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


@asynccontextmanager
async def connection_lifespan( # noqa: PLR0913
async def connection_lifespan(
*,
connection: AbstractConnection,
connection_parameters: ConnectionParameters,
Expand Down Expand Up @@ -122,7 +122,7 @@ async def unsubscribe(self) -> None:
del self._active_subscriptions[self.id]
await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id}))

async def _run_handler(self, frame: MessageFrame) -> None:
async def _run_handler(self, *, frame: MessageFrame) -> None:
try:
await self.handler(frame)
except self.supressed_exception_classes as exception:
Expand Down Expand Up @@ -185,7 +185,7 @@ async def __aexit__(
self._active_transactions.remove(self)

async def send(
self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None
self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None
) -> None:
frame = SendFrame.build(
body=body, destination=destination, transaction=self.id, content_type=content_type, headers=headers
Expand Down Expand Up @@ -295,7 +295,7 @@ async def _listen_to_frames(self) -> None:
match frame:
case MessageFrame():
if subscription := self._active_subscriptions.get(frame.headers["subscription"]):
task_group.create_task(subscription._run_handler(frame)) # noqa: SLF001
task_group.create_task(subscription._run_handler(frame=frame)) # noqa: SLF001
elif self.on_unhandled_message_frame:
self.on_unhandled_message_frame(frame)
case ErrorFrame():
Expand All @@ -308,7 +308,7 @@ async def _listen_to_frames(self) -> None:
pass

async def send(
self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None
self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None
) -> None:
await self._connection_manager.write_frame_reconnecting(
SendFrame.build(
Expand All @@ -323,7 +323,7 @@ async def begin(self) -> AsyncGenerator[Transaction, None]:
) as transaction:
yield transaction

async def subscribe( # noqa: PLR0913
async def subscribe(
self,
destination: str,
handler: Callable[[MessageFrame], Coroutine[None, None, None]],
Expand Down
5 changes: 3 additions & 2 deletions stompman/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass, field
from typing import NamedTuple, Self, TypedDict
from typing import Self, TypedDict
from urllib.parse import unquote


class Heartbeat(NamedTuple):
@dataclass(frozen=True, slots=True)
class Heartbeat:
will_send_interval_ms: int
want_to_receive_interval_ms: int

Expand Down
8 changes: 4 additions & 4 deletions stompman/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
@dataclass(kw_only=True)
class AbstractConnection(Protocol):
@classmethod
async def connect( # noqa: PLR0913
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
) -> Self | None: ...
async def close(self) -> None: ...
def write_heartbeat(self) -> None: ...
Expand All @@ -38,8 +38,8 @@ class Connection(AbstractConnection):
read_timeout: int

@classmethod
async def connect( # noqa: PLR0913
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
) -> Self | None:
try:
reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout)
Expand Down
3 changes: 2 additions & 1 deletion stompman/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ConnectionManager:
connect_timeout: int
read_timeout: int
read_max_chunk_size: int

_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)

Expand Down Expand Up @@ -87,7 +88,7 @@ async def _get_active_connection_state(self) -> ActiveConnectionState:
self._active_connection_state = ActiveConnectionState(connection=connection, lifespan=lifespan)

try:
await lifespan.__aenter__() # noqa: PLC2801
await lifespan.__aenter__()
except ConnectionLostError:
self._clear_active_connection_state()
else:
Expand Down
9 changes: 5 additions & 4 deletions stompman/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ def __str__(self) -> str:
return self.__repr__()


@dataclass(frozen=True, kw_only=True, slots=True)
class ConnectionLostError(Error):
"""Raised in stompman.AbstractConnection—and handled in stompman.ConnectionManager, therefore is private."""


@dataclass(frozen=True, kw_only=True, slots=True)
class ConnectionConfirmationTimeoutError(Error):
timeout: int
Expand All @@ -36,7 +41,3 @@ class FailedAllConnectAttemptsError(Error):
@dataclass(frozen=True, kw_only=True, slots=True)
class RepeatedConnectionLostError(Error):
retry_attempts: int


@dataclass(frozen=True, kw_only=True, slots=True)
class ConnectionLostError(Error): ...
2 changes: 1 addition & 1 deletion stompman/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class SendFrame:
body: bytes = b""

@classmethod
def build( # noqa: PLR0913
def build(
cls,
*,
body: bytes,
Expand Down
13 changes: 5 additions & 8 deletions stompman/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def dump_frame(frame: AnyClientFrame | AnyRealServerFrame) -> bytes:
return b"".join(lines)


def unescape_byte(byte: bytes, previous_byte: bytes | None) -> bytes | None:
def unescape_byte(*, byte: bytes, previous_byte: bytes | None) -> bytes | None:
if previous_byte == BACKSLASH:
return HEADER_UNESCAPE_CHARS.get(byte)
if byte == BACKSLASH:
Expand All @@ -123,7 +123,7 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None:
just_escaped_line = False
if byte != BACKSLASH:
(value_buffer if key_parsed else key_buffer).extend(byte)
elif unescaped_byte := unescape_byte(byte, previous_byte):
elif unescaped_byte := unescape_byte(byte=byte, previous_byte=previous_byte):
just_escaped_line = True
(value_buffer if key_parsed else key_buffer).extend(unescaped_byte)

Expand All @@ -136,13 +136,10 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None:
return None


def make_frame_from_parts(command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame:
def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame:
frame_type = COMMANDS_TO_FRAMES[command]
return (
frame_type(headers=cast(Any, headers), body=body) # type: ignore[call-arg]
if frame_type in FRAMES_WITH_BODY
else frame_type(headers=cast(Any, headers)) # type: ignore[call-arg]
)
headers_ = cast(Any, headers)
return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg]


def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame:
Expand Down
7 changes: 3 additions & 4 deletions testing/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@


async def main() -> None:
async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client:

async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029
print(frame) # noqa: T201
async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029
print(frame) # noqa: T201

async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client:
await client.subscribe("DLQ", handler=handle_message, on_suppressed_exception=print)


Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from polyfactory.factories.dataclass_factory import DataclassFactory

import stompman
from stompman.frames import HeartbeatFrame


@pytest.fixture(
Expand All @@ -34,8 +35,8 @@ def noop_error_handler(exception: Exception, frame: stompman.MessageFrame) -> No

class BaseMockConnection(stompman.AbstractConnection):
@classmethod
async def connect( # noqa: PLR0913
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
) -> Self | None:
return cls()

Expand All @@ -45,7 +46,7 @@ async def write_frame(self, frame: stompman.AnyClientFrame) -> None: ...
@staticmethod
async def read_frames() -> AsyncGenerator[stompman.AnyServerFrame, None]: # pragma: no cover
await asyncio.Future()
yield # type: ignore[misc]
yield HeartbeatFrame()


@dataclass(kw_only=True, slots=True)
Expand Down
11 changes: 6 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
await asyncio.sleep(0)

with pytest.raises(ConnectionConfirmationTimeoutError) as exc_info:
await EnrichedClient( # noqa: PLC2801
await EnrichedClient(
connection_class=MockConnection, connection_confirmation_timeout=connection_confirmation_timeout
).__aenter__()

Expand All @@ -161,7 +161,7 @@ async def test_client_connection_lifespan_unsupported_protocol_version() -> None
given_version = FAKER.pystr()

with pytest.raises(UnsupportedProtocolVersionError) as exc_info:
await EnrichedClient( # noqa: PLC2801
await EnrichedClient(
connection_class=create_spying_connection(
[build_dataclass(ConnectedFrame, headers={"version": given_version})]
)[0]
Expand Down Expand Up @@ -279,7 +279,8 @@ async def test_client_subscribtions_lifespan_no_active_subs_in_aexit(monkeypatch
@pytest.mark.parametrize("direct_error", [True, False])
async def test_client_subscribtions_lifespan_with_active_subs_in_aexit(
monkeypatch: pytest.MonkeyPatch,
direct_error: bool, # noqa: FBT001
*,
direct_error: bool,
) -> None:
subscription_id, destination = FAKER.pystr(), FAKER.pystr()
monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id))
Expand Down Expand Up @@ -388,7 +389,7 @@ async def test_client_listen_unsubscribe_before_ack_or_nack(

@pytest.mark.parametrize("ok", [True, False])
@pytest.mark.parametrize("ack", ["client", "client-individual"])
async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack: AckMode, ok: bool) -> None: # noqa: FBT001
async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack: AckMode, *, ok: bool) -> None:
subscription_id, destination, message_id = FAKER.pystr(), FAKER.pystr(), FAKER.pystr()
monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id))

Expand Down Expand Up @@ -418,7 +419,7 @@ async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack:


@pytest.mark.parametrize("ok", [True, False])
async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, ok: bool) -> None: # noqa: FBT001
async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, *, ok: bool) -> None:
subscription_id, destination, message_id = FAKER.pystr(), FAKER.pystr(), FAKER.pystr()
monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id))

Expand Down
3 changes: 1 addition & 2 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class MockWriter:
HeartbeatFrame(),
ConnectedFrame(headers={"heart-beat": "0,0", "version": "1.2", "server": "some server"}),
]
max_chunk_size = 1024

class MockReader:
read = mock.AsyncMock(side_effect=read_bytes)
Expand All @@ -100,7 +99,7 @@ async def take_frames(count: int) -> list[AnyServerFrame]:
MockWriter.close.assert_called_once_with()
MockWriter.wait_closed.assert_called_once_with()
MockWriter.drain.assert_called_once_with()
MockReader.read.mock_calls = [mock.call(max_chunk_size)] * len(read_bytes) # type: ignore[assignment]
assert MockReader.read.mock_calls == [mock.call(connection.read_max_chunk_size)] * len(read_bytes)
assert MockWriter.write.mock_calls == [mock.call(NEWLINE), mock.call(b"COMMIT\ntransaction:transaction\n\n\x00")]


Expand Down
61 changes: 35 additions & 26 deletions tests/test_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ async def test_connect_to_one_server_ok(ok_on_attempt: int, monkeypatch: pytest.

class MockConnection(BaseMockConnection):
@classmethod
async def connect( # noqa: PLR0913
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
) -> Self | None:
assert (host, port) == (manager.servers[0].host, manager.servers[0].port)
nonlocal attempts
attempts += 1

return (
await super().connect(host, port, timeout, read_max_chunk_size, read_timeout)
await super().connect(
host=host,
port=port,
timeout=timeout,
read_max_chunk_size=read_max_chunk_size,
read_timeout=read_timeout,
)
if attempts == ok_on_attempt
else None
)
Expand All @@ -60,11 +66,17 @@ class MockConnection(BaseMockConnection):
async def test_connect_to_any_server_ok() -> None:
class MockConnection(BaseMockConnection):
@classmethod
async def connect( # noqa: PLR0913
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
) -> Self | None:
return (
await super().connect(host, port, timeout, read_max_chunk_size, read_timeout)
await super().connect(
host=host,
port=port,
timeout=timeout,
read_max_chunk_size=read_max_chunk_size,
read_timeout=read_timeout,
)
if port == successful_server.port
else None
)
Expand Down Expand Up @@ -224,19 +236,17 @@ class MockConnection(BaseMockConnection):


async def test_read_frames_reconnecting_raises() -> None:
async def read_frames_mock(self: object) -> AsyncGenerator[AnyServerFrame, None]:
raise ConnectionLostError
yield
await asyncio.sleep(0)

class MockConnection(BaseMockConnection):
read_frames = read_frames_mock # type: ignore[assignment]
@staticmethod
async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
raise ConnectionLostError
yield
await asyncio.sleep(0)

manager = EnrichedConnectionManager(connection_class=MockConnection)

with pytest.raises(RepeatedConnectionLostError): # noqa: PT012
async for _ in manager.read_frames_reconnecting():
pass # pragma: no cover
with pytest.raises(RepeatedConnectionLostError):
[_ async for _ in manager.read_frames_reconnecting()]


SIDE_EFFECTS = [(None,), (ConnectionLostError(), None), (ConnectionLostError(), ConnectionLostError(), None)]
Expand Down Expand Up @@ -279,18 +289,17 @@ async def test_read_frames_reconnecting_ok(side_effect: tuple[None | ConnectionL
]
attempt = -1

async def read_frames_mock(self: object) -> AsyncGenerator[AnyServerFrame, None]:
nonlocal attempt
attempt += 1
current_effect = side_effect[attempt]
if isinstance(current_effect, ConnectionLostError):
raise ConnectionLostError
for frame in frames:
yield frame
await asyncio.sleep(0)

class MockConnection(BaseMockConnection):
read_frames = read_frames_mock # type: ignore[assignment]
@staticmethod
async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
nonlocal attempt
attempt += 1
current_effect = side_effect[attempt]
if isinstance(current_effect, ConnectionLostError):
raise ConnectionLostError
for frame in frames:
yield frame
await asyncio.sleep(0)

manager = EnrichedConnectionManager(connection_class=MockConnection)

Expand Down

0 comments on commit cbace92

Please sign in to comment.