Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] argilla server: add user and server id on telemetry metrics #5445

Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,32 @@
from argilla_server.security.authentication.userinfo import UserInfo


def set_request_user(request: Request, user: User):
"""
Set the request user in the request state.

Parameters:
request: The request object.
user: The user.

"""

request.state.user = user


def get_request_user(request: Request) -> Optional[User]:
"""
Get the current user from the request.

Parameters:
request (Request): The request object.

Returns:
The user if available, None otherwise.
"""
return getattr(request.state, "user", None)


class AuthenticationProvider:
"""Authentication provider for the API requests."""

Expand Down Expand Up @@ -58,6 +84,7 @@ async def get_current_user(
if not user:
raise UnauthorizedError()

set_request_user(request, user)
return user

async def _authenticate_request_user(self, db: AsyncSession, request: Request) -> Optional[UserInfo]:
Expand Down
33 changes: 17 additions & 16 deletions argilla-server/src/argilla_server/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,31 @@
import json
import logging
import platform
import uuid
from typing import Union

from fastapi import Request, Response
from huggingface_hub.utils import send_telemetry

from argilla_server._version import __version__
from argilla_server.api.errors.v1.exception_handlers import get_request_error
from argilla_server.settings import settings
from argilla_server.utils._fastapi import resolve_endpoint_path_for_request
from argilla_server.utils._telemetry import (
is_running_on_docker_container,
server_deployment_type,
)
from argilla_server.security.authentication.provider import get_request_user

_LOGGER = logging.getLogger(__name__)


@dataclasses.dataclass
class TelemetryClient:
_server_id: int = uuid.getnode()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The uuid.getnode returns an integer value. We can use it "as is".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check that uuid.getnode() is returning the same value on HF Spaces even after a factory rebuild.


def __post_init__(self):
self._system_info = {
"server_id": self._server_id,
"system": platform.system(),
"machine": platform.machine(),
"platform": platform.platform(),
Expand All @@ -47,16 +52,11 @@ def __post_init__(self):
_LOGGER.info("System Info:")
_LOGGER.info(f"Context: {json.dumps(self._system_info, indent=2)}")

async def track_data(self, topic: str, data: dict, include_system_info: bool = True, count: int = 1):
library_name = "argilla/server"
topic = f"{library_name}/{topic}"

user_agent = {**data}
if include_system_info:
user_agent.update(self._system_info)
if count is not None:
user_agent["count"] = count
async def track_data(self, topic: str, data: dict):
library_name = "argilla-server"
Copy link
Member Author

@frascuchon frascuchon Sep 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This aligns the library name (argilla-server) to the package published in pypi.org

topic = f"argilla/server/{topic}"

user_agent = {**data, **self._system_info}
send_telemetry(topic=topic, library_name=library_name, library_version=__version__, user_agent=user_agent)

async def track_api_request(self, request: Request, response: Response) -> None:
Expand Down Expand Up @@ -85,15 +85,16 @@ async def track_api_request(self, request: Request, response: Response) -> None:
"response.status": str(response.status_code),
}

if "Server-Timing" in response.headers:
duration_in_ms = response.headers["Server-Timing"]
duration_in_ms = duration_in_ms.removeprefix("total;dur=")

if server_timing := response.headers.get("Server-Timing"):
duration_in_ms = server_timing.removeprefix("total;dur=")
data["duration_in_milliseconds"] = duration_in_ms

if user := get_request_user(request=request):
data["user.id"] = str(user.id)
data["user.role"] = user.role

if response.status_code >= 400:
argilla_error: Exception = get_request_error(request=request)
if argilla_error:
if argilla_error := get_request_error(request=request):
data["response.error_code"] = argilla_error.code # noqa

await self.track_data(topic=topic, data=data)
Expand Down
14 changes: 14 additions & 0 deletions argilla-server/tests/unit/test_api_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest.mock import MagicMock, ANY

import pytest
from pytest_mock import MockerFixture
from starlette.testclient import TestClient

from argilla_server._app import create_server_app
Expand Down Expand Up @@ -44,6 +45,19 @@ def test_track_api_request_call_on_error(self, test_telemetry: TelemetryClient):

test_telemetry.track_api_request.assert_called_once()

def test_track_api_request_with_unexpected_telemetry_error(
self, test_telemetry: TelemetryClient, mocker: "MockerFixture"
):
with mocker.patch.object(test_telemetry, "track_api_request", side_effect=Exception("mocked error")):
settings.enable_telemetry = True

client = TestClient(create_server_app())

response = client.get("/api/v1/version")

test_telemetry.track_api_request.assert_called_once()
assert response.status_code == 200

def test_not_track_api_request_call_when_disabled_telemetry(self, test_telemetry: TelemetryClient):
settings.enable_telemetry = False

Expand Down
Loading