-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Joost Lekkerkerker <[email protected]>
- Loading branch information
Showing
13 changed files
with
658 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""go2rtc client.""" | ||
|
||
from .client import Go2RtcClient | ||
from . import ws | ||
from .models import Stream, WebRTCSdpAnswer, WebRTCSdpOffer | ||
from .rest import Go2RtcRestClient | ||
|
||
__all__ = ["Go2RtcClient", "Stream", "WebRTCSdpAnswer", "WebRTCSdpOffer"] | ||
__all__ = ["Go2RtcRestClient", "Stream", "WebRTCSdpAnswer", "WebRTCSdpOffer", "ws"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Go2rtc client exceptions.""" | ||
|
||
from __future__ import annotations | ||
|
||
|
||
class Go2RtcClientError(Exception): | ||
"""Base exception for go2rtc client.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Websocket module.""" | ||
|
||
from .client import Go2RtcWsClient | ||
from .messages import ( | ||
ReceiveMessages, | ||
SendMessages, | ||
WebRTCAnswer, | ||
WebRTCCandidate, | ||
WebRTCOffer, | ||
WsError, | ||
) | ||
|
||
__all__ = [ | ||
"ReceiveMessages", | ||
"SendMessages", | ||
"Go2RtcWsClient", | ||
"WebRTCCandidate", | ||
"WebRTCOffer", | ||
"WebRTCAnswer", | ||
"WsError", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
"""Websocket client for go2rtc server.""" | ||
|
||
import asyncio | ||
from collections.abc import Callable | ||
import logging | ||
from typing import TYPE_CHECKING, Any | ||
from urllib.parse import urljoin | ||
|
||
from aiohttp import ( | ||
ClientError, | ||
ClientSession, | ||
ClientWebSocketResponse, | ||
WSMsgType, | ||
WSServerHandshakeError, | ||
) | ||
|
||
from go2rtc_client.exceptions import Go2RtcClientError | ||
from go2rtc_client.ws.messages import BaseMessage | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class Go2RtcWsClient: | ||
"""Websocket client for go2rtc server.""" | ||
|
||
def __init__( | ||
self, | ||
session: ClientSession, | ||
server_url: str, | ||
*, | ||
source: str | None = None, | ||
destination: str | None = None, | ||
) -> None: | ||
"""Initialize Client.""" | ||
if source: | ||
if destination: | ||
msg = "Source and destination cannot be set at the same time" | ||
raise ValueError(msg) | ||
params = {"src": source} | ||
elif destination: | ||
params = {"dst": destination} | ||
else: | ||
msg = "Source or destination must be set" | ||
raise ValueError(msg) | ||
|
||
self._server_url = server_url | ||
self._session = session | ||
self._params = params | ||
self._client: ClientWebSocketResponse | None = None | ||
self._rx_task: asyncio.Task[None] | None = None | ||
self._subscribers: list[Callable[[BaseMessage], None]] = [] | ||
self._connect_lock = asyncio.Lock() | ||
|
||
@property | ||
def connected(self) -> bool: | ||
"""Return if we're currently connected.""" | ||
return self._client is not None and not self._client.closed | ||
|
||
async def connect(self) -> None: | ||
"""Connect to device.""" | ||
async with self._connect_lock: | ||
if self.connected: | ||
return | ||
|
||
_LOGGER.debug("Trying to connect to %s", self._server_url) | ||
try: | ||
self._client = await self._session.ws_connect( | ||
urljoin(self._server_url, "/api/ws"), params=self._params | ||
) | ||
except ( | ||
WSServerHandshakeError, | ||
ClientError, | ||
) as err: | ||
raise Go2RtcClientError(err) from err | ||
|
||
self._rx_task = asyncio.create_task(self._receive_messages()) | ||
_LOGGER.info("Connected to %s", self._server_url) | ||
|
||
async def close(self) -> None: | ||
"""Close connection.""" | ||
if self.connected: | ||
if TYPE_CHECKING: | ||
assert self._client is not None | ||
client = self._client | ||
self._client = None | ||
await client.close() | ||
|
||
async def send(self, message: BaseMessage) -> None: | ||
"""Send a message.""" | ||
if not self.connected: | ||
await self.connect() | ||
|
||
if TYPE_CHECKING: | ||
assert self._client is not None | ||
|
||
await self._client.send_str(message.to_json()) | ||
|
||
def _process_text_message(self, data: Any) -> None: | ||
"""Process text message.""" | ||
try: | ||
message = BaseMessage.from_json(data) | ||
except Exception: # pylint: disable=broad-except | ||
_LOGGER.exception("Invalid message received: %s", data) | ||
else: | ||
for subscriber in self._subscribers: | ||
try: | ||
subscriber(message) | ||
except Exception: # pylint: disable=broad-except | ||
_LOGGER.exception("Error on subscriber callback") | ||
|
||
async def _receive_messages(self) -> None: | ||
"""Receive messages.""" | ||
if TYPE_CHECKING: | ||
assert self._client | ||
|
||
try: | ||
while self.connected: | ||
msg = await self._client.receive() | ||
match msg.type: | ||
case ( | ||
WSMsgType.CLOSE | ||
| WSMsgType.CLOSED | ||
| WSMsgType.CLOSING | ||
| WSMsgType.PING | ||
| WSMsgType.PONG | ||
): | ||
break | ||
case WSMsgType.ERROR: | ||
_LOGGER.error("Error received: %s", msg.data) | ||
case WSMsgType.TEXT: | ||
self._process_text_message(msg.data) | ||
case _: | ||
_LOGGER.warning("Received unknown message: %s", msg) | ||
except Exception: | ||
_LOGGER.exception("Unexpected error while receiving message") | ||
raise | ||
finally: | ||
_LOGGER.debug( | ||
"Websocket client connection from %s closed", self._server_url | ||
) | ||
|
||
if self.connected: | ||
await self.close() | ||
|
||
def subscribe(self, callback: Callable[[BaseMessage], None]) -> Callable[[], None]: | ||
"""Subscribe to messages.""" | ||
|
||
def _unsubscribe() -> None: | ||
self._subscribers.remove(callback) | ||
|
||
self._subscribers.append(callback) | ||
return _unsubscribe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Go2rtc websocket messages.""" | ||
|
||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass, field | ||
from typing import Any, ClassVar | ||
|
||
from mashumaro import field_options | ||
from mashumaro.config import BaseConfig | ||
from mashumaro.mixins.orjson import DataClassORJSONMixin | ||
from mashumaro.types import Discriminator | ||
|
||
|
||
@dataclass(frozen=True) | ||
class BaseMessage(DataClassORJSONMixin): | ||
"""Base message class.""" | ||
|
||
TYPE: ClassVar[str] | ||
|
||
class Config(BaseConfig): | ||
"""Config for BaseMessage.""" | ||
|
||
serialize_by_alias = True | ||
discriminator = Discriminator( | ||
field="type", | ||
include_subtypes=True, | ||
variant_tagger_fn=lambda cls: cls.TYPE, | ||
) | ||
|
||
def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]: | ||
"""Add type to serialized dict.""" | ||
# ClassVar will not serialize by default | ||
d["type"] = self.TYPE | ||
return d | ||
|
||
|
||
@dataclass(frozen=True) | ||
class WebRTCCandidate(BaseMessage): | ||
"""WebRTC ICE candidate message.""" | ||
|
||
TYPE = "webrtc/candidate" | ||
candidate: str = field(metadata=field_options(alias="value")) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class WebRTCOffer(BaseMessage): | ||
"""WebRTC offer message.""" | ||
|
||
TYPE = "webrtc/offer" | ||
offer: str = field(metadata=field_options(alias="value")) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class WebRTCAnswer(BaseMessage): | ||
"""WebRTC answer message.""" | ||
|
||
TYPE = "webrtc/answer" | ||
answer: str = field(metadata=field_options(alias="value")) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class WsError(BaseMessage): | ||
"""Error message.""" | ||
|
||
TYPE = "error" | ||
error: str = field(metadata=field_options(alias="value")) | ||
|
||
|
||
ReceiveMessages = WebRTCAnswer | WebRTCCandidate | WsError | ||
SendMessages = WebRTCCandidate | WebRTCOffer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.