From 22c088dc7d9acb1e6e1daecc8e98d979424aed88 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 11:50:15 +0200 Subject: [PATCH 01/10] chore: Remove unused methods --- .../argilla_server/errors/error_handler.py | 4 - .../tests/unit/errors/test_api_errors.py | 87 ------------------- 2 files changed, 91 deletions(-) delete mode 100644 argilla-server/tests/unit/errors/test_api_errors.py diff --git a/argilla-server/src/argilla_server/errors/error_handler.py b/argilla-server/src/argilla_server/errors/error_handler.py index 0bc76f7425..ebbff42c8e 100644 --- a/argilla-server/src/argilla_server/errors/error_handler.py +++ b/argilla-server/src/argilla_server/errors/error_handler.py @@ -52,10 +52,6 @@ def __init__(self, error: ServerError): class APIErrorHandler: - @classmethod - async def track_error(cls, error: ServerError, request: Request): - await get_telemetry_client().track_error(error=error, request=request) - @classmethod async def common_exception_handler(cls, request: Request, error: Exception): """Wraps errors as custom generic error""" diff --git a/argilla-server/tests/unit/errors/test_api_errors.py b/argilla-server/tests/unit/errors/test_api_errors.py deleted file mode 100644 index b85d7791f7..0000000000 --- a/argilla-server/tests/unit/errors/test_api_errors.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import MagicMock - -import pytest -from argilla_server.api.schemas.v1.datasets import Dataset -from argilla_server.errors.base_errors import ( - EntityAlreadyExistsError, - EntityNotFoundError, - GenericServerError, - ServerError, -) -from argilla_server.errors.error_handler import APIErrorHandler -from fastapi import Request - -mock_request = Request(scope={"type": "http", "headers": {}}) - - -@pytest.mark.asyncio -class TestAPIErrorHandler: - @pytest.mark.skip - @pytest.mark.asyncio - @pytest.mark.parametrize( - ["error", "expected_event"], - [ - ( - EntityNotFoundError(name="mock-name", type="MockType"), - { - "accept-language": None, - "code": "argilla.api.errors::EntityNotFoundError", - "type": "MockType", - "user-agent": None, - }, - ), - ( - EntityAlreadyExistsError(name="mock-name", type=Dataset, workspace="mock-workspace"), - { - "accept-language": None, - "code": "argilla.api.errors::EntityAlreadyExistsError", - "type": "Dataset", - "user-agent": None, - }, - ), - ( - GenericServerError(RuntimeError("This is a mock error")), - { - "accept-language": None, - "code": "argilla.api.errors::GenericServerError", - "type": "builtins.RuntimeError", - "user-agent": None, - }, - ), - ( - ServerError(), - { - "accept-language": None, - "code": "argilla.api.errors::ServerError", - "user-agent": None, - }, - ), - ], - ) - async def test_track_error(self, test_telemetry: MagicMock, error, expected_event): - await APIErrorHandler.track_error(error, request=mock_request) - - user_agent = { - "code": error.code, - "user-agent": mock_request.headers.get("user-agent"), - "accept-language": mock_request.headers.get("accept-language"), - "type": error.__class__.__name__, - "count": 1, - } - user_agent.update(test_telemetry._system_info) - - test_telemetry.track_data.assert_called_once_with(topic="error/server", user_agent=user_agent) From c8ac5a9b78f40af011fc76645e4e834ce2325915 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 11:50:46 +0200 Subject: [PATCH 02/10] chore: Remove specific user me telemetry tracking --- argilla-server/src/argilla_server/api/handlers/v1/users.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/users.py b/argilla-server/src/argilla_server/api/handlers/v1/users.py index ab819e7599..557ad944b1 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/users.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/users.py @@ -32,12 +32,8 @@ @router.get("/me", response_model=UserSchema) async def get_current_user( - request: Request, current_user: User = Security(auth.get_current_user), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), ): - await telemetry_client.track_user_login(request=request, user=current_user) - return current_user From 66db650d9ac27174c2fdbe44d12f36522353366d Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 11:52:07 +0200 Subject: [PATCH 03/10] refactor: Add authenticated user to the request --- .../security/authentication/provider.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/argilla-server/src/argilla_server/security/authentication/provider.py b/argilla-server/src/argilla_server/security/authentication/provider.py index 002faa73cc..de33a81a18 100644 --- a/argilla-server/src/argilla_server/security/authentication/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/provider.py @@ -28,6 +28,32 @@ from argilla_server.security.authentication.userinfo import UserInfo +def set_request_user(request: Request, user: User): + """ + Set the request user in the request state. + + Parameters: + request: The request object. + user: The user. + + """ + + request.state.user = user + + +def get_request_user(request: Request) -> Optional[User]: + """ + Get the current user from the request. + + Parameters: + request (Request): The request object. + + Returns: + The user if available, None otherwise. + """ + return getattr(request.state, "user", None) + + class AuthenticationProvider: """Authentication provider for the API requests.""" @@ -58,6 +84,7 @@ async def get_current_user( if not user: raise UnauthorizedError() + set_request_user(request, user) return user async def _authenticate_request_user(self, db: AsyncSession, request: Request) -> Optional[UserInfo]: From 857916f94e7b94ec0a3efacf20108044eb9d5197 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 11:54:05 +0200 Subject: [PATCH 04/10] feat: Track request with user id and role and add server id to the system context Also, removed unnecessary methods and attributes. --- .../src/argilla_server/telemetry.py | 41 ++++++------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/argilla-server/src/argilla_server/telemetry.py b/argilla-server/src/argilla_server/telemetry.py index 75d6027c59..d3c431f692 100644 --- a/argilla-server/src/argilla_server/telemetry.py +++ b/argilla-server/src/argilla_server/telemetry.py @@ -16,6 +16,7 @@ import json import logging import platform +import uuid from typing import Union from fastapi import Request, Response @@ -24,9 +25,6 @@ from argilla_server._version import __version__ from argilla_server.api.errors.v1.exception_handlers import get_request_error from argilla_server.constants import DEFAULT_USERNAME -from argilla_server.errors.base_errors import ( - ServerError, -) from argilla_server.models import ( Dataset, Field, @@ -41,6 +39,7 @@ VectorSettings, Workspace, ) +from argilla_server.security.authentication.provider import get_request_user from argilla_server.settings import settings from argilla_server.utils._fastapi import resolve_endpoint_path_for_request from argilla_server.utils._telemetry import ( @@ -53,10 +52,11 @@ @dataclasses.dataclass class TelemetryClient: - enable_telemetry: dataclasses.InitVar[bool] = settings.enable_telemetry + _server_id: str = str(uuid.UUID(int=uuid.getnode())) - def __post_init__(self, enable_telemetry: bool): + def __post_init__(self): self._system_info = { + "server_id": self._server_id, "system": platform.system(), "machine": platform.machine(), "platform": platform.platform(), @@ -67,7 +67,6 @@ def __post_init__(self, enable_telemetry: bool): _LOGGER.info("System Info:") _LOGGER.info(f"Context: {json.dumps(self._system_info, indent=2)}") - self.enable_telemetry = enable_telemetry @staticmethod def _process_request_info(request: Request): @@ -191,25 +190,20 @@ async def track_api_request(self, request: Request, response: Response) -> None: "response.status": str(response.status_code), } - if "Server-Timing" in response.headers: - duration_in_ms = response.headers["Server-Timing"] - duration_in_ms = duration_in_ms.removeprefix("total;dur=") - + if server_timing := response.headers.get("Server-Timing"): + duration_in_ms = server_timing.removeprefix("total;dur=") data["duration_in_milliseconds"] = duration_in_ms + if user := get_request_user(request=request): + data["user.id"] = str(user.id) + data["user.role"] = user.role + if response.status_code >= 400: - argilla_error: Exception = get_request_error(request=request) - if argilla_error: + if argilla_error := get_request_error(request=request): data["response.error_code"] = argilla_error.code # noqa await self.track_data(topic=topic, data=data) - async def track_user_login(self, request: Request, user: User): - topic = "user/login" - user_agent = self._process_user_model(user=user) - user_agent.update(**self._process_request_info(request)) - await self.track_data(topic=topic, data=user_agent) - async def track_crud_user( self, action: str, @@ -308,17 +302,6 @@ async def track_crud_records_suggestions( user_agent["record_id"] = record_id await self.track_data(topic=topic, data=user_agent, count=count) - async def track_error(self, error: ServerError, request: Request): - topic = "error/server" - user_agent = { - "code": error.code, - "user-agent": request.headers.get("user-agent"), - "accept-language": request.headers.get("accept-language"), - "type": error.__class__.__name__, - } - - await self.track_data(topic=topic, data=user_agent) - _TELEMETRY_CLIENT = TelemetryClient() From 5ae206966f8b868e0fbfcff48fc52ae99044ba68 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 12:10:26 +0200 Subject: [PATCH 05/10] chore: Add missing test --- argilla-server/tests/unit/test_api_telemetry.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/argilla-server/tests/unit/test_api_telemetry.py b/argilla-server/tests/unit/test_api_telemetry.py index 9f10aa14ae..7f0984cfcd 100644 --- a/argilla-server/tests/unit/test_api_telemetry.py +++ b/argilla-server/tests/unit/test_api_telemetry.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, ANY import pytest +from pytest_mock import MockerFixture from starlette.testclient import TestClient from argilla_server._app import create_server_app @@ -44,6 +45,19 @@ def test_track_api_request_call_on_error(self, test_telemetry: TelemetryClient): test_telemetry.track_api_request.assert_called_once() + def test_track_api_request_with_unexpected_telemetry_error( + self, test_telemetry: TelemetryClient, mocker: "MockerFixture" + ): + with mocker.patch.object(test_telemetry, "track_api_request", side_effect=Exception("mocked error")): + settings.enable_telemetry = True + + client = TestClient(create_server_app()) + + response = client.get("/api/v1/version") + + test_telemetry.track_api_request.assert_called_once() + assert response.status_code == 200 + def test_not_track_api_request_call_when_disabled_telemetry(self, test_telemetry: TelemetryClient): settings.enable_telemetry = False From 56ebfbcea14c44e3b091523b2a072c0cd5da91e1 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 14:44:06 +0200 Subject: [PATCH 06/10] refactor: using server-id int value Also, review and simplify the track_data method --- argilla-server/src/argilla_server/telemetry.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/argilla-server/src/argilla_server/telemetry.py b/argilla-server/src/argilla_server/telemetry.py index 6bcc0ea586..6662a169fc 100644 --- a/argilla-server/src/argilla_server/telemetry.py +++ b/argilla-server/src/argilla_server/telemetry.py @@ -36,7 +36,7 @@ @dataclasses.dataclass class TelemetryClient: - _server_id: str = str(uuid.UUID(int=uuid.getnode())) + _server_id: int = uuid.getnode() def __post_init__(self): self._system_info = { @@ -52,16 +52,11 @@ def __post_init__(self): _LOGGER.info("System Info:") _LOGGER.info(f"Context: {json.dumps(self._system_info, indent=2)}") - async def track_data(self, topic: str, data: dict, include_system_info: bool = True, count: int = 1): - library_name = "argilla/server" - topic = f"{library_name}/{topic}" - - user_agent = {**data} - if include_system_info: - user_agent.update(self._system_info) - if count is not None: - user_agent["count"] = count + async def track_data(self, topic: str, data: dict): + library_name = "argilla-server" + topic = f"argilla/server/{topic}" + user_agent = {**data, **self._system_info} send_telemetry(topic=topic, library_name=library_name, library_version=__version__, user_agent=user_agent) async def track_api_request(self, request: Request, response: Response) -> None: From 8b936e596012f6c07e46aef05f812dafef302848 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 14:47:17 +0200 Subject: [PATCH 07/10] tests: Add more tests --- .../tests/unit/commons/test_telemetry.py | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/argilla-server/tests/unit/commons/test_telemetry.py b/argilla-server/tests/unit/commons/test_telemetry.py index ca864b850b..d3532ee14e 100644 --- a/argilla-server/tests/unit/commons/test_telemetry.py +++ b/argilla-server/tests/unit/commons/test_telemetry.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import uuid from unittest.mock import MagicMock import pytest -from fastapi import Request, APIRouter -from fastapi.routing import APIRoute -from pytest_mock import mocker, MockerFixture +from fastapi import Request +from pytest_mock import MockerFixture from starlette.responses import JSONResponse from argilla_server.api.errors.v1.exception_handlers import set_request_error @@ -30,6 +28,28 @@ @pytest.mark.asyncio class TestSuiteTelemetry: + async def test_create_client_with_server_id(self): + test_telemetry = TelemetryClient() + + assert "server_id" in test_telemetry._system_info + assert test_telemetry._system_info["server_id"] == uuid.getnode() + + async def test_track_data(self, mocker: MockerFixture): + from argilla_server._version import __version__ as version + + mock = mocker.patch("argilla_server.telemetry.send_telemetry") + + telemetry = TelemetryClient() + + await telemetry.track_data("test_topic", {"test": "test"}) + + mock.assert_called_once_with( + topic="argilla/server/test_topic", + library_name="argilla-server", + library_version=version, + user_agent={"test": "test", **telemetry._system_info}, + ) + async def test_track_api_request(self, test_telemetry: TelemetryClient, mocker: MockerFixture): mocker.patch("argilla_server.telemetry.resolve_endpoint_path_for_request", return_value="/api/test/endpoint") From 692d24426f94d3a4025441bef5195c6c0fb190be Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 17:33:15 +0200 Subject: [PATCH 08/10] refactor: Store server id in file --- .../src/argilla_server/telemetry.py | 9 +++--- .../src/argilla_server/utils/_telemetry.py | 32 +++++++++++++++++++ .../tests/unit/commons/test_telemetry.py | 7 ++-- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/argilla-server/src/argilla_server/telemetry.py b/argilla-server/src/argilla_server/telemetry.py index 6662a169fc..91b8ec5cbb 100644 --- a/argilla-server/src/argilla_server/telemetry.py +++ b/argilla-server/src/argilla_server/telemetry.py @@ -17,30 +17,31 @@ import logging import platform import uuid -from typing import Union from fastapi import Request, Response from huggingface_hub.utils import send_telemetry from argilla_server._version import __version__ from argilla_server.api.errors.v1.exception_handlers import get_request_error +from argilla_server.security.authentication.provider import get_request_user from argilla_server.utils._fastapi import resolve_endpoint_path_for_request from argilla_server.utils._telemetry import ( is_running_on_docker_container, server_deployment_type, + get_server_id, ) -from argilla_server.security.authentication.provider import get_request_user _LOGGER = logging.getLogger(__name__) @dataclasses.dataclass class TelemetryClient: - _server_id: int = uuid.getnode() + _server_id: uuid.UUID = dataclasses.field(init=False) def __post_init__(self): + self._server_id = get_server_id() self._system_info = { - "server_id": self._server_id, + "server_id": self._server_id.urn, "system": platform.system(), "machine": platform.machine(), "platform": platform.platform(), diff --git a/argilla-server/src/argilla_server/utils/_telemetry.py b/argilla-server/src/argilla_server/utils/_telemetry.py index 8015ba5319..95a4986bf5 100644 --- a/argilla-server/src/argilla_server/utils/_telemetry.py +++ b/argilla-server/src/argilla_server/utils/_telemetry.py @@ -13,11 +13,43 @@ # limitations under the License. import logging import os +import uuid +from uuid import UUID from argilla_server.integrations.huggingface.spaces import HUGGINGFACE_SETTINGS +from argilla_server.settings import settings _LOGGER = logging.getLogger(__name__) +_SERVER_ID_DAT_FILE = "server_id.dat" + + +def get_server_id() -> UUID: + """ + Returns the server ID. If it is not set, it generates a new one and stores it + in $ARGILLA_HOME/server_id.dat + + Returns: + UUID: The server ID + + """ + + server_id_file = os.path.join(settings.home_path, _SERVER_ID_DAT_FILE) + + if os.path.exists(server_id_file): + with open(server_id_file, "r") as f: + server_id = f.read().strip() + try: + return UUID(server_id) + except ValueError: + _LOGGER.warning(f"Invalid server ID in {server_id_file}. Generating a new one.") + + server_id = uuid.uuid4() + with open(server_id_file, "w") as f: + f.write(str(server_id)) + + return server_id + def server_deployment_type() -> str: """Returns the type of deployment of the server.""" diff --git a/argilla-server/tests/unit/commons/test_telemetry.py b/argilla-server/tests/unit/commons/test_telemetry.py index d3532ee14e..3fc6f81349 100644 --- a/argilla-server/tests/unit/commons/test_telemetry.py +++ b/argilla-server/tests/unit/commons/test_telemetry.py @@ -28,11 +28,14 @@ @pytest.mark.asyncio class TestSuiteTelemetry: - async def test_create_client_with_server_id(self): + async def test_create_client_with_server_id(self, mocker: MockerFixture): + mock_server_id = uuid.uuid4() + mocker.patch("argilla_server.telemetry.get_server_id", return_value=mock_server_id) + test_telemetry = TelemetryClient() assert "server_id" in test_telemetry._system_info - assert test_telemetry._system_info["server_id"] == uuid.getnode() + assert test_telemetry._system_info["server_id"] == mock_server_id.urn async def test_track_data(self, mocker: MockerFixture): from argilla_server._version import __version__ as version From bc43b437682c67c858c4f441b3553f1dd2963016 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 17:42:40 +0200 Subject: [PATCH 09/10] chore: Move telemetry python module --- .../src/argilla_server/telemetry/__init__.py | 16 ++++++++++++++++ .../{telemetry.py => telemetry/_client.py} | 2 +- .../_telemetry.py => telemetry/_helpers.py} | 0 .../tests/unit/commons/test_telemetry.py | 16 +++++++++++----- argilla-server/tests/unit/conftest.py | 2 +- 5 files changed, 29 insertions(+), 7 deletions(-) create mode 100644 argilla-server/src/argilla_server/telemetry/__init__.py rename argilla-server/src/argilla_server/{telemetry.py => telemetry/_client.py} (98%) rename argilla-server/src/argilla_server/{utils/_telemetry.py => telemetry/_helpers.py} (100%) diff --git a/argilla-server/src/argilla_server/telemetry/__init__.py b/argilla-server/src/argilla_server/telemetry/__init__.py new file mode 100644 index 0000000000..64f5f3cc31 --- /dev/null +++ b/argilla-server/src/argilla_server/telemetry/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._client import TelemetryClient, get_telemetry_client # noqa +from ._helpers import * # noqa diff --git a/argilla-server/src/argilla_server/telemetry.py b/argilla-server/src/argilla_server/telemetry/_client.py similarity index 98% rename from argilla-server/src/argilla_server/telemetry.py rename to argilla-server/src/argilla_server/telemetry/_client.py index 91b8ec5cbb..de8835701a 100644 --- a/argilla-server/src/argilla_server/telemetry.py +++ b/argilla-server/src/argilla_server/telemetry/_client.py @@ -25,7 +25,7 @@ from argilla_server.api.errors.v1.exception_handlers import get_request_error from argilla_server.security.authentication.provider import get_request_user from argilla_server.utils._fastapi import resolve_endpoint_path_for_request -from argilla_server.utils._telemetry import ( +from argilla_server.telemetry._helpers import ( is_running_on_docker_container, server_deployment_type, get_server_id, diff --git a/argilla-server/src/argilla_server/utils/_telemetry.py b/argilla-server/src/argilla_server/telemetry/_helpers.py similarity index 100% rename from argilla-server/src/argilla_server/utils/_telemetry.py rename to argilla-server/src/argilla_server/telemetry/_helpers.py diff --git a/argilla-server/tests/unit/commons/test_telemetry.py b/argilla-server/tests/unit/commons/test_telemetry.py index 3fc6f81349..ca89a57109 100644 --- a/argilla-server/tests/unit/commons/test_telemetry.py +++ b/argilla-server/tests/unit/commons/test_telemetry.py @@ -30,7 +30,7 @@ class TestSuiteTelemetry: async def test_create_client_with_server_id(self, mocker: MockerFixture): mock_server_id = uuid.uuid4() - mocker.patch("argilla_server.telemetry.get_server_id", return_value=mock_server_id) + mocker.patch("argilla_server.telemetry._client.get_server_id", return_value=mock_server_id) test_telemetry = TelemetryClient() @@ -40,7 +40,7 @@ async def test_create_client_with_server_id(self, mocker: MockerFixture): async def test_track_data(self, mocker: MockerFixture): from argilla_server._version import __version__ as version - mock = mocker.patch("argilla_server.telemetry.send_telemetry") + mock = mocker.patch("argilla_server.telemetry._client.send_telemetry") telemetry = TelemetryClient() @@ -54,7 +54,9 @@ async def test_track_data(self, mocker: MockerFixture): ) async def test_track_api_request(self, test_telemetry: TelemetryClient, mocker: MockerFixture): - mocker.patch("argilla_server.telemetry.resolve_endpoint_path_for_request", return_value="/api/test/endpoint") + mocker.patch( + "argilla_server.telemetry._client.resolve_endpoint_path_for_request", return_value="/api/test/endpoint" + ) request = Request( scope={ @@ -83,7 +85,9 @@ async def test_track_api_request(self, test_telemetry: TelemetryClient, mocker: ) async def test_track_api_request_call_with_error(self, test_telemetry: TelemetryClient, mocker: MockerFixture): - mocker.patch("argilla_server.telemetry.resolve_endpoint_path_for_request", return_value="/api/test/endpoint") + mocker.patch( + "argilla_server.telemetry._client.resolve_endpoint_path_for_request", return_value="/api/test/endpoint" + ) request = Request( scope={ @@ -110,7 +114,9 @@ async def test_track_api_request_call_with_error(self, test_telemetry: Telemetry async def test_track_api_request_call_with_error_and_exception( self, test_telemetry: TelemetryClient, mocker: MockerFixture ): - mocker.patch("argilla_server.telemetry.resolve_endpoint_path_for_request", return_value="/api/test/endpoint") + mocker.patch( + "argilla_server.telemetry._client.resolve_endpoint_path_for_request", return_value="/api/test/endpoint" + ) request = Request( scope={ diff --git a/argilla-server/tests/unit/conftest.py b/argilla-server/tests/unit/conftest.py index ea1713cf8d..a702be6c36 100644 --- a/argilla-server/tests/unit/conftest.py +++ b/argilla-server/tests/unit/conftest.py @@ -118,7 +118,7 @@ def test_telemetry(mocker: "MockerFixture") -> "TelemetryClient": setattr(real_telemetry, attr_name, wrapped) # Patch the _TELEMETRY_CLIENT to use the real_telemetry - mocker.patch("argilla_server.telemetry._TELEMETRY_CLIENT", new=real_telemetry) + mocker.patch("argilla_server.telemetry._client._TELEMETRY_CLIENT", new=real_telemetry) return real_telemetry From e047a04ede3a33dad1612ad27c3c1e0c554b4961 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 2 Sep 2024 17:55:56 +0200 Subject: [PATCH 10/10] tests: add more tests --- .../tests/unit/telemetry/__init__.py | 14 ++++++ .../unit/telemetry/test_telemetry_helpers.py | 43 +++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 argilla-server/tests/unit/telemetry/__init__.py create mode 100644 argilla-server/tests/unit/telemetry/test_telemetry_helpers.py diff --git a/argilla-server/tests/unit/telemetry/__init__.py b/argilla-server/tests/unit/telemetry/__init__.py new file mode 100644 index 0000000000..4b6cecae7f --- /dev/null +++ b/argilla-server/tests/unit/telemetry/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/argilla-server/tests/unit/telemetry/test_telemetry_helpers.py b/argilla-server/tests/unit/telemetry/test_telemetry_helpers.py new file mode 100644 index 0000000000..a6753a06be --- /dev/null +++ b/argilla-server/tests/unit/telemetry/test_telemetry_helpers.py @@ -0,0 +1,43 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import patch, mock_open +from uuid import UUID + +import pytest +from pytest_mock import MockerFixture + +from argilla_server.settings import settings +from argilla_server.telemetry import get_server_id + + +class TestTelemetryHelpers: + def test_get_server_id_without_existing_file(self, mocker: MockerFixture): + mocker.patch.object(os.path, "exists", return_value=False) + + with patch("builtins.open", mock_open()) as mock: + server_id = get_server_id() + another_server_id = get_server_id() + + assert server_id != another_server_id + assert mock.call_count == 2 + mock.assert_called_with(os.path.join(settings.home_path, "server_id.dat"), "w") + + def test_get_server_id_with_existing_file(self, mocker: MockerFixture): + mocker.patch.object(os.path, "exists", return_value=True) + + with patch("builtins.open", mock_open(read_data="00000000-0000-0000-0000-000000000000")) as mock: + server_id = get_server_id() + assert server_id == UUID(int=0)