Skip to content

Commit

Permalink
As ws client (#2)
Browse files Browse the repository at this point in the history
Co-authored-by: Joost Lekkerkerker <[email protected]>
  • Loading branch information
edenhaus and joostlek authored Oct 16, 2024
1 parent 047e12a commit e85e8fd
Show file tree
Hide file tree
Showing 13 changed files with 658 additions and 17 deletions.
5 changes: 3 additions & 2 deletions go2rtc_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""go2rtc client."""

from .client import Go2RtcClient
from . import ws
from .models import Stream, WebRTCSdpAnswer, WebRTCSdpOffer
from .rest import Go2RtcRestClient

__all__ = ["Go2RtcClient", "Stream", "WebRTCSdpAnswer", "WebRTCSdpOffer"]
__all__ = ["Go2RtcRestClient", "Stream", "WebRTCSdpAnswer", "WebRTCSdpOffer", "ws"]
7 changes: 7 additions & 0 deletions go2rtc_client/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Go2rtc client exceptions."""

from __future__ import annotations


class Go2RtcClientError(Exception):
"""Base exception for go2rtc client."""
4 changes: 2 additions & 2 deletions go2rtc_client/client.py → go2rtc_client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ async def add(self, name: str, source: str) -> None:
)


class Go2RtcClient:
"""Client for go2rtc server."""
class Go2RtcRestClient:
"""Rest client for go2rtc server."""

def __init__(self, websession: ClientSession, server_url: str) -> None:
"""Initialize Client."""
Expand Down
21 changes: 21 additions & 0 deletions go2rtc_client/ws/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Websocket module."""

from .client import Go2RtcWsClient
from .messages import (
ReceiveMessages,
SendMessages,
WebRTCAnswer,
WebRTCCandidate,
WebRTCOffer,
WsError,
)

__all__ = [
"ReceiveMessages",
"SendMessages",
"Go2RtcWsClient",
"WebRTCCandidate",
"WebRTCOffer",
"WebRTCAnswer",
"WsError",
]
152 changes: 152 additions & 0 deletions go2rtc_client/ws/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Websocket client for go2rtc server."""

import asyncio
from collections.abc import Callable
import logging
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

from aiohttp import (
ClientError,
ClientSession,
ClientWebSocketResponse,
WSMsgType,
WSServerHandshakeError,
)

from go2rtc_client.exceptions import Go2RtcClientError
from go2rtc_client.ws.messages import BaseMessage

_LOGGER = logging.getLogger(__name__)


class Go2RtcWsClient:
"""Websocket client for go2rtc server."""

def __init__(
self,
session: ClientSession,
server_url: str,
*,
source: str | None = None,
destination: str | None = None,
) -> None:
"""Initialize Client."""
if source:
if destination:
msg = "Source and destination cannot be set at the same time"
raise ValueError(msg)
params = {"src": source}
elif destination:
params = {"dst": destination}
else:
msg = "Source or destination must be set"
raise ValueError(msg)

self._server_url = server_url
self._session = session
self._params = params
self._client: ClientWebSocketResponse | None = None
self._rx_task: asyncio.Task[None] | None = None
self._subscribers: list[Callable[[BaseMessage], None]] = []
self._connect_lock = asyncio.Lock()

@property
def connected(self) -> bool:
"""Return if we're currently connected."""
return self._client is not None and not self._client.closed

async def connect(self) -> None:
"""Connect to device."""
async with self._connect_lock:
if self.connected:
return

_LOGGER.debug("Trying to connect to %s", self._server_url)
try:
self._client = await self._session.ws_connect(
urljoin(self._server_url, "/api/ws"), params=self._params
)
except (
WSServerHandshakeError,
ClientError,
) as err:
raise Go2RtcClientError(err) from err

self._rx_task = asyncio.create_task(self._receive_messages())
_LOGGER.info("Connected to %s", self._server_url)

async def close(self) -> None:
"""Close connection."""
if self.connected:
if TYPE_CHECKING:
assert self._client is not None
client = self._client
self._client = None
await client.close()

async def send(self, message: BaseMessage) -> None:
"""Send a message."""
if not self.connected:
await self.connect()

if TYPE_CHECKING:
assert self._client is not None

await self._client.send_str(message.to_json())

def _process_text_message(self, data: Any) -> None:
"""Process text message."""
try:
message = BaseMessage.from_json(data)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Invalid message received: %s", data)
else:
for subscriber in self._subscribers:
try:
subscriber(message)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error on subscriber callback")

async def _receive_messages(self) -> None:
"""Receive messages."""
if TYPE_CHECKING:
assert self._client

try:
while self.connected:
msg = await self._client.receive()
match msg.type:
case (
WSMsgType.CLOSE
| WSMsgType.CLOSED
| WSMsgType.CLOSING
| WSMsgType.PING
| WSMsgType.PONG
):
break
case WSMsgType.ERROR:
_LOGGER.error("Error received: %s", msg.data)
case WSMsgType.TEXT:
self._process_text_message(msg.data)
case _:
_LOGGER.warning("Received unknown message: %s", msg)
except Exception:
_LOGGER.exception("Unexpected error while receiving message")
raise
finally:
_LOGGER.debug(
"Websocket client connection from %s closed", self._server_url
)

if self.connected:
await self.close()

def subscribe(self, callback: Callable[[BaseMessage], None]) -> Callable[[], None]:
"""Subscribe to messages."""

def _unsubscribe() -> None:
self._subscribers.remove(callback)

self._subscribers.append(callback)
return _unsubscribe
70 changes: 70 additions & 0 deletions go2rtc_client/ws/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Go2rtc websocket messages."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, ClassVar

from mashumaro import field_options
from mashumaro.config import BaseConfig
from mashumaro.mixins.orjson import DataClassORJSONMixin
from mashumaro.types import Discriminator


@dataclass(frozen=True)
class BaseMessage(DataClassORJSONMixin):
"""Base message class."""

TYPE: ClassVar[str]

class Config(BaseConfig):
"""Config for BaseMessage."""

serialize_by_alias = True
discriminator = Discriminator(
field="type",
include_subtypes=True,
variant_tagger_fn=lambda cls: cls.TYPE,
)

def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
"""Add type to serialized dict."""
# ClassVar will not serialize by default
d["type"] = self.TYPE
return d


@dataclass(frozen=True)
class WebRTCCandidate(BaseMessage):
"""WebRTC ICE candidate message."""

TYPE = "webrtc/candidate"
candidate: str = field(metadata=field_options(alias="value"))


@dataclass(frozen=True)
class WebRTCOffer(BaseMessage):
"""WebRTC offer message."""

TYPE = "webrtc/offer"
offer: str = field(metadata=field_options(alias="value"))


@dataclass(frozen=True)
class WebRTCAnswer(BaseMessage):
"""WebRTC answer message."""

TYPE = "webrtc/answer"
answer: str = field(metadata=field_options(alias="value"))


@dataclass(frozen=True)
class WsError(BaseMessage):
"""Error message."""

TYPE = "error"
error: str = field(metadata=field_options(alias="value"))


ReceiveMessages = WebRTCAnswer | WebRTCCandidate | WsError
SendMessages = WebRTCCandidate | WebRTCOffer
15 changes: 14 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
version = "0.0.0"

[project.urls]
"Homepage" = "https://deebot.readthedocs.io/"
"Homepage" = "https://pypi.org/project/go2rtc-client"
"Source Code" = "https://github.com/home-assistant-libs/python-go2rtc-client"
"Bug Reports" = "https://github.com/home-assistant-libs/python-go2rtc-client/issues"

Expand All @@ -35,7 +35,9 @@ dev-dependencies = [
"covdefaults>=2.3.0",
"mypy==1.11.2",
"pre-commit==3.8.0",
"pylint-per-file-ignores>=1.3.2",
"pylint==3.2.7",
"pytest-aiohttp>=1.0.5",
"pytest-asyncio==0.24.0",
"pytest-cov==5.0.0",
"pytest-timeout==2.3.1",
Expand Down Expand Up @@ -118,6 +120,11 @@ good-names = [
[tool.pylint.DESIGN]
max-attributes = 8

[tool.pylint.MASTER]
load-plugins=[
"pylint_per_file_ignores",
]

[tool.pylint."MESSAGES CONTROL"]
disable = [
"duplicate-code",
Expand All @@ -130,6 +137,12 @@ disable = [
"wrong-import-order",
]

per-file-ignores = [
# redefined-outer-name: Tests reference fixtures in the test function
"/tests/:redefined-outer-name",
]


[tool.pylint.SIMILARITIES]
ignore-imports = true

Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from syrupy import SnapshotAssertion

from go2rtc_client import Go2RtcClient
from go2rtc_client import Go2RtcRestClient

from . import URL
from .syrupy import Go2RtcSnapshotExtension
Expand All @@ -20,12 +20,12 @@ def snapshot_assertion(snapshot: SnapshotAssertion) -> SnapshotAssertion:


@pytest.fixture
async def client() -> AsyncGenerator[Go2RtcClient, None]:
"""Return a go2rtc client."""
async def rest_client() -> AsyncGenerator[Go2RtcRestClient, None]:
"""Return a go2rtc rest client."""
async with (
aiohttp.ClientSession() as session,
):
client_ = Go2RtcClient(
client_ = Go2RtcRestClient(
session,
URL,
)
Expand Down
Loading

0 comments on commit e85e8fd

Please sign in to comment.