Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into async-webrtc-offer-wi…
Browse files Browse the repository at this point in the history
…th-ice-servers
  • Loading branch information
edenhaus committed Oct 28, 2024
2 parents 5173a36 + 140d2cb commit 20b5d18
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 41 deletions.
39 changes: 39 additions & 0 deletions go2rtc_client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,45 @@

from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any

from aiohttp import ClientError
from mashumaro.exceptions import (
ExtraKeysError,
InvalidFieldValue,
MissingDiscriminatorError,
MissingField,
SuitableVariantNotFoundError,
UnserializableDataError,
)

if TYPE_CHECKING:
from collections.abc import Callable, Coroutine


class Go2RtcClientError(Exception):
"""Base exception for go2rtc client."""


def handle_error[**_P, _R](
func: Callable[_P, Coroutine[Any, Any, _R]],
) -> Callable[_P, Coroutine[Any, Any, _R]]:
"""Wrap aiohttp and mashumaro errors."""

@wraps(func)
async def _func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return await func(*args, **kwargs)
except (
ClientError,
ExtraKeysError,
InvalidFieldValue,
MissingDiscriminatorError,
MissingField,
SuitableVariantNotFoundError,
UnserializableDataError,
) as exc:
raise Go2RtcClientError from exc

return _func
23 changes: 23 additions & 0 deletions go2rtc_client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,30 @@
from dataclasses import dataclass, field
from typing import Literal

from awesomeversion import AwesomeVersion
from mashumaro import field_options
from mashumaro.mixins.orjson import DataClassORJSONMixin
from mashumaro.types import SerializationStrategy


class _AwesomeVersionSerializer(SerializationStrategy):
def serialize(self, value: AwesomeVersion) -> str:
return str(value)

def deserialize(self, value: str) -> AwesomeVersion:
return AwesomeVersion(value)


@dataclass
class ApplicationInfo(DataClassORJSONMixin):
"""Application info model.
Currently only the server version is exposed.
"""

version: AwesomeVersion = field(
metadata=field_options(serialization_strategy=_AwesomeVersionSerializer())
)


@dataclass
Expand Down
37 changes: 29 additions & 8 deletions go2rtc_client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@

import logging
from typing import TYPE_CHECKING, Any, Final, Literal
from urllib.parse import urljoin

from aiohttp import ClientError, ClientResponse, ClientSession
from aiohttp.client import _RequestOptions
from awesomeversion import AwesomeVersion
from mashumaro.codecs.basic import BasicDecoder
from mashumaro.mixins.dict import DataClassDictMixin
from yarl import URL

from .models import Stream, WebRTCSdpAnswer, WebRTCSdpOffer
from .exceptions import handle_error
from .models import ApplicationInfo, Stream, WebRTCSdpAnswer, WebRTCSdpOffer

if TYPE_CHECKING:
from collections.abc import Mapping

_LOGGER = logging.getLogger(__name__)

_API_PREFIX = "/api"
_SUPPORTED_VERSION: Final = AwesomeVersion("1.9.4")


class _BaseClient:
Expand All @@ -27,7 +30,7 @@ class _BaseClient:
def __init__(self, websession: ClientSession, server_url: str) -> None:
"""Initialize Client."""
self._session = websession
self._base_url = server_url
self._base_url = URL(server_url)

async def request(
self,
Expand All @@ -38,7 +41,7 @@ async def request(
data: DataClassDictMixin | dict[str, Any] | None = None,
) -> ClientResponse:
"""Make a request to the server."""
url = self._request_url(path)
url = self._base_url.with_path(path)
_LOGGER.debug("request[%s] %s", method, url)
if isinstance(data, DataClassDictMixin):
data = data.to_dict()
Expand All @@ -56,9 +59,18 @@ async def request(
resp.raise_for_status()
return resp

def _request_url(self, path: str) -> str:
"""Return a request url for the specific path."""
return urljoin(self._base_url, path)

class _ApplicationClient:
PATH: Final = _API_PREFIX

def __init__(self, client: _BaseClient) -> None:
"""Initialize Client."""
self._client = client

async def get_info(self) -> ApplicationInfo:
"""Get application info."""
resp = await self._client.request("GET", self.PATH)
return ApplicationInfo.from_dict(await resp.json())


class _WebRTCClient:
Expand All @@ -82,6 +94,7 @@ async def _forward_sdp_offer(
)
return WebRTCSdpAnswer.from_dict(await resp.json())

@handle_error
async def forward_whep_sdp_offer(
self, source_name: str, offer: WebRTCSdpOffer
) -> WebRTCSdpAnswer:
Expand All @@ -103,17 +116,19 @@ def __init__(self, client: _BaseClient) -> None:
"""Initialize Client."""
self._client = client

@handle_error
async def list(self) -> dict[str, Stream]:
"""List streams registered with the server."""
resp = await self._client.request("GET", self.PATH)
return _GET_STREAMS_DECODER.decode(await resp.json())

@handle_error
async def add(self, name: str, source: str) -> None:
"""Add a stream to the server."""
await self._client.request(
"PUT",
self.PATH,
params={"name": name, "src": source},
params={"name": name, "src": [source, f"ffmpeg:{name}#audio=opus"]},
)


Expand All @@ -123,5 +138,11 @@ class Go2RtcRestClient:
def __init__(self, websession: ClientSession, server_url: str) -> None:
"""Initialize Client."""
self._client = _BaseClient(websession, server_url)
self.application: Final = _ApplicationClient(self._client)
self.streams: Final = _StreamClient(self._client)
self.webrtc: Final = _WebRTCClient(self._client)

async def validate_server_version(self) -> bool:
"""Validate the server version is compatible."""
application_info = await self.application.get_info()
return application_info.version == _SUPPORTED_VERSION
25 changes: 8 additions & 17 deletions go2rtc_client/ws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,9 @@
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

from aiohttp import (
ClientError,
ClientSession,
ClientWebSocketResponse,
WSMsgType,
WSServerHandshakeError,
)
from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType

from go2rtc_client.exceptions import Go2RtcClientError
from go2rtc_client.exceptions import handle_error

from .messages import BaseMessage, SendMessages, WebRTC, WsMessage

Expand Down Expand Up @@ -57,26 +51,22 @@ def connected(self) -> bool:
"""Return if we're currently connected."""
return self._client is not None and not self._client.closed

@handle_error
async def connect(self) -> None:
"""Connect to device."""
async with self._connect_lock:
if self.connected:
return

_LOGGER.debug("Trying to connect to %s", self._server_url)
try:
self._client = await self._session.ws_connect(
urljoin(self._server_url, "/api/ws"), params=self._params
)
except (
WSServerHandshakeError,
ClientError,
) as err:
raise Go2RtcClientError(err) from err
self._client = await self._session.ws_connect(
urljoin(self._server_url, "/api/ws"), params=self._params
)

self._rx_task = asyncio.create_task(self._receive_messages())
_LOGGER.info("Connected to %s", self._server_url)

@handle_error
async def close(self) -> None:
"""Close connection."""
if self.connected:
Expand All @@ -86,6 +76,7 @@ async def close(self) -> None:
self._client = None
await client.close()

@handle_error
async def send(self, message: SendMessages) -> None:
"""Send a message."""
if not self.connected:
Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
requires-python = ">=3.12.0"
dependencies = [
"aiohttp~=3.10",
"awesomeversion>=24.6.0",
"mashumaro~=3.13",
"orjson>=3.10.7",
"webrtc-models>=0.1.0",
Expand All @@ -34,7 +35,7 @@ version = "0.0.0"
dev-dependencies = [
"aioresponses>=0.7.6",
"covdefaults>=2.3.0",
"mypy==1.11.2",
"mypy-dev==1.13.0a1",
"pre-commit==3.8.0",
"pylint-per-file-ignores>=1.3.2",
"pylint==3.2.7",
Expand Down Expand Up @@ -96,6 +97,7 @@ warn_unused_ignores = true
"D100", # Missing docstring in public module
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"FBT001", # boolean-type-hint-positional-argument
"N802", # Function name {name} should be lowercase
"N816", # Variable {name} in global scope should not be mixedCase
"PLR0913", # Too many arguments in function definition
Expand All @@ -107,6 +109,8 @@ warn_unused_ignores = true
[tool.pylint.BASIC]
good-names = [
"_",
"_P",
"_R",
"ex",
"fp",
"i",
Expand All @@ -115,6 +119,8 @@ good-names = [
"k",
"on",
"Run",
"P",
"R",
"T",
]

Expand Down
10 changes: 10 additions & 0 deletions tests/__snapshots__/test_rest.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# serializer version: 1
# name: test_application_info
dict({
'version': <AwesomeVersion SemVer '1.9.4'>,
})
# ---
# name: test_application_info.1
dict({
'version': '1.9.4',
})
# ---
# name: test_streams_get[empty]
dict({
})
Expand Down
7 changes: 7 additions & 0 deletions tests/fixtures/application_info_answer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"config_path": "/home/erik/go2rtc/go2rtc.yaml",
"host": "127.0.0.1:1984",
"revision": "a4885c2",
"rtsp": { "listen": ":8554", "default_query": "video\u0026audio" },
"version": "1.9.4"
}
Loading

0 comments on commit 20b5d18

Please sign in to comment.