Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] argilla server: add user and server id on telemetry metrics #5445

Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,32 @@
from argilla_server.security.authentication.userinfo import UserInfo


def set_request_user(request: Request, user: User):
    """
    Set the request user in the request state.

    Parameters:
        request: The request object.
        user: The user.

    """
    # `get_request_user` reads this same `state.user` attribute back.
    request.state.user = user


def get_request_user(request: Request) -> Optional[User]:
    """
    Get the current user from the request.

    Parameters:
        request: The request object.

    Returns:
        The user if available, None otherwise.
    """
    # `state.user` may never have been assigned on this request, so fall back to None
    # instead of raising AttributeError.
    return getattr(request.state, "user", None)


class AuthenticationProvider:
"""Authentication provider for the API requests."""

Expand Down Expand Up @@ -58,6 +84,7 @@ async def get_current_user(
if not user:
raise UnauthorizedError()

set_request_user(request, user)
return user

async def _authenticate_request_user(self, db: AsyncSession, request: Request) -> Optional[UserInfo]:
Expand Down
33 changes: 17 additions & 16 deletions argilla-server/src/argilla_server/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,31 @@
import json
import logging
import platform
import uuid
from typing import Union

from fastapi import Request, Response
from huggingface_hub.utils import send_telemetry

from argilla_server._version import __version__
from argilla_server.api.errors.v1.exception_handlers import get_request_error
from argilla_server.settings import settings
from argilla_server.utils._fastapi import resolve_endpoint_path_for_request
from argilla_server.utils._telemetry import (
is_running_on_docker_container,
server_deployment_type,
)
from argilla_server.security.authentication.provider import get_request_user

_LOGGER = logging.getLogger(__name__)


@dataclasses.dataclass
class TelemetryClient:
_server_id: int = uuid.getnode()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`uuid.getnode()` returns an integer value, so we can use it "as is".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check that uuid.getnode() is returning the same value on HF Spaces even after a factory rebuild.


def __post_init__(self):
self._system_info = {
"server_id": self._server_id,
"system": platform.system(),
"machine": platform.machine(),
"platform": platform.platform(),
Expand All @@ -47,16 +52,11 @@
_LOGGER.info("System Info:")
_LOGGER.info(f"Context: {json.dumps(self._system_info, indent=2)}")

async def track_data(self, topic: str, data: dict, include_system_info: bool = True, count: int = 1):
library_name = "argilla/server"
topic = f"{library_name}/{topic}"

user_agent = {**data}
if include_system_info:
user_agent.update(self._system_info)
if count is not None:
user_agent["count"] = count
async def track_data(self, topic: str, data: dict):
library_name = "argilla-server"
Copy link
Member Author

@frascuchon frascuchon Sep 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This aligns the library name (argilla-server) with the name of the package published on pypi.org.

topic = f"argilla/server/{topic}"

user_agent = {**data, **self._system_info}
send_telemetry(topic=topic, library_name=library_name, library_version=__version__, user_agent=user_agent)

async def track_api_request(self, request: Request, response: Response) -> None:
Expand Down Expand Up @@ -85,15 +85,16 @@
"response.status": str(response.status_code),
}

if "Server-Timing" in response.headers:
duration_in_ms = response.headers["Server-Timing"]
duration_in_ms = duration_in_ms.removeprefix("total;dur=")

if server_timing := response.headers.get("Server-Timing"):
duration_in_ms = server_timing.removeprefix("total;dur=")
data["duration_in_milliseconds"] = duration_in_ms

if user := get_request_user(request=request):
data["user.id"] = str(user.id)
data["user.role"] = user.role

Check warning on line 94 in argilla-server/src/argilla_server/telemetry.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/telemetry.py#L93-L94

Added lines #L93 - L94 were not covered by tests

if response.status_code >= 400:
argilla_error: Exception = get_request_error(request=request)
if argilla_error:
if argilla_error := get_request_error(request=request):
data["response.error_code"] = argilla_error.code # noqa

await self.track_data(topic=topic, data=data)
Expand Down
28 changes: 24 additions & 4 deletions argilla-server/tests/unit/commons/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from unittest.mock import MagicMock

import pytest
from fastapi import Request, APIRouter
from fastapi.routing import APIRoute
from pytest_mock import mocker, MockerFixture
from fastapi import Request
from pytest_mock import MockerFixture
from starlette.responses import JSONResponse

from argilla_server.api.errors.v1.exception_handlers import set_request_error
Expand All @@ -30,6 +28,28 @@

@pytest.mark.asyncio
class TestSuiteTelemetry:
async def test_create_client_with_server_id(self):
    """A freshly created client reports `uuid.getnode()` as `server_id` in its system info."""
    system_info = TelemetryClient()._system_info

    assert "server_id" in system_info
    assert system_info["server_id"] == uuid.getnode()

async def test_track_data(self, mocker: MockerFixture):
    """`track_data` sends the prefixed topic, library name, version and merged payload to `send_telemetry`."""
    from argilla_server._version import __version__ as version

    # Patch the name as imported by the telemetry module so no real telemetry is sent.
    send_telemetry_mock = mocker.patch("argilla_server.telemetry.send_telemetry")
    client = TelemetryClient()

    await client.track_data("test_topic", {"test": "test"})

    expected_kwargs = {
        "topic": "argilla/server/test_topic",
        "library_name": "argilla-server",
        "library_version": version,
        "user_agent": {"test": "test", **client._system_info},
    }
    send_telemetry_mock.assert_called_once_with(**expected_kwargs)

async def test_track_api_request(self, test_telemetry: TelemetryClient, mocker: MockerFixture):
mocker.patch("argilla_server.telemetry.resolve_endpoint_path_for_request", return_value="/api/test/endpoint")

Expand Down
14 changes: 14 additions & 0 deletions argilla-server/tests/unit/test_api_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest.mock import MagicMock, ANY

import pytest
from pytest_mock import MockerFixture
from starlette.testclient import TestClient

from argilla_server._app import create_server_app
Expand Down Expand Up @@ -44,6 +45,19 @@ def test_track_api_request_call_on_error(self, test_telemetry: TelemetryClient):

test_telemetry.track_api_request.assert_called_once()

def test_track_api_request_with_unexpected_telemetry_error(
    self, test_telemetry: TelemetryClient, mocker: "MockerFixture"
):
    """An exception raised while tracking telemetry must not break the API response."""
    # The `mocker` fixture undoes its patches at test teardown, so no `with` block is needed;
    # pytest-mock discourages using the objects it returns as context managers.
    mocker.patch.object(test_telemetry, "track_api_request", side_effect=Exception("mocked error"))
    settings.enable_telemetry = True

    client = TestClient(create_server_app())

    response = client.get("/api/v1/version")

    test_telemetry.track_api_request.assert_called_once()
    assert response.status_code == 200

def test_not_track_api_request_call_when_disabled_telemetry(self, test_telemetry: TelemetryClient):
settings.enable_telemetry = False

Expand Down