From b5acc70f8ce1ba3ed8616d07e4a2d14ea020cd6c Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 6 Aug 2024 18:20:56 +0200 Subject: [PATCH 1/3] Add retries argument --- argilla/src/argilla/_api/_client.py | 4 +++- argilla/src/argilla/_api/_http/_client.py | 18 ++++++++++-------- argilla/src/argilla/client.py | 23 ++++++++++++++++------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/argilla/src/argilla/_api/_client.py b/argilla/src/argilla/_api/_client.py index 7f496c5d11..ccbe9667a9 100644 --- a/argilla/src/argilla/_api/_client.py +++ b/argilla/src/argilla/_api/_client.py @@ -103,8 +103,9 @@ class APIClient: def __init__( self, api_url: Optional[str] = DEFAULT_HTTP_CONFIG.api_url, - api_key: str = DEFAULT_HTTP_CONFIG.api_key, + api_key: Optional[str] = DEFAULT_HTTP_CONFIG.api_key, timeout: int = DEFAULT_HTTP_CONFIG.timeout, + retries: int = DEFAULT_HTTP_CONFIG.retries, **http_client_args, ): if not api_url: @@ -118,6 +119,7 @@ def __init__( http_client_args = http_client_args or {} http_client_args["timeout"] = timeout + http_client_args["retries"] = retries self.http_client = create_http_client( api_url=self.api_url, # type: ignore diff --git a/argilla/src/argilla/_api/_http/_client.py b/argilla/src/argilla/_api/_http/_client.py index f30a06f9a5..b2767efb79 100644 --- a/argilla/src/argilla/_api/_http/_client.py +++ b/argilla/src/argilla/_api/_http/_client.py @@ -23,12 +23,8 @@ class HTTPClientConfig: api_url: str api_key: str - timeout: int = None - - def __post_init__(self): - self.api_url = self.api_url - self.api_key = self.api_key - self.timeout = self.timeout or 60 + timeout: int = 60 + retries: int = 5 def create_http_client(api_url: str, api_key: str, **client_args) -> httpx.Client: @@ -37,5 +33,11 @@ def create_http_client(api_url: str, api_key: str, **client_args) -> httpx.Clien headers = client_args.pop("headers", {}) headers["X-Argilla-Api-Key"] = api_key - - return httpx.Client(base_url=api_url, headers=headers, **client_args) + retries = client_args.pop("retries", 0) + + return httpx.Client( + base_url=api_url, + headers=headers, + transport=httpx.HTTPTransport(retries=retries), + **client_args, + ) diff --git a/argilla/src/argilla/client.py b/argilla/src/argilla/client.py index efff77107b..1f16bee08b 100644 --- a/argilla/src/argilla/client.py +++ b/argilla/src/argilla/client.py @@ -41,14 +41,8 @@ class Argilla(_api.APIClient): datasets: A collection of datasets. users: A collection of users. me: The current user. - """ - workspaces: "Workspaces" - datasets: "Datasets" - users: "Users" - me: "User" - # Default instance of Argilla _default_client: Optional["Argilla"] = None @@ -57,9 +51,24 @@ def __init__( api_url: Optional[str] = DEFAULT_HTTP_CONFIG.api_url, api_key: Optional[str] = DEFAULT_HTTP_CONFIG.api_key, timeout: int = DEFAULT_HTTP_CONFIG.timeout, + retries: int = DEFAULT_HTTP_CONFIG.retries, **http_client_args, ) -> None: - super().__init__(api_url=api_url, api_key=api_key, timeout=timeout, **http_client_args) + """Inits the `Argilla` client. + + Args: + api_url: the URL of the Argilla API. If not provided, then the value will try + to be set from `ARGILLA_API_URL` environment variable. Defaults to + `"http://localhost:6900"`. + api_key: the key to be used to authenticate in the Argilla API. If not provided, + then the value will try to be set from `ARGILLA_API_KEY` environment variable. + Defaults to `None`. + timeout: the maximum time in seconds to wait for a request to the Argilla API + to be completed before raising an exception. Defaults to `60`. + retries: the number of times to retry a failed HTTP request to the Argilla API + before raising an exception. Defaults to `5`. + """ + super().__init__(api_url=api_url, api_key=api_key, timeout=timeout, retries=retries, **http_client_args) self._set_default(self) From e1d4cb8269e0292b7267c5629cde8cf160ac0ac4 Mon Sep 17 00:00:00 2001 From: Ben Burtenshaw Date: Thu, 5 Sep 2024 10:30:11 +0200 Subject: [PATCH 2/3] test: add unit test for retries --- argilla/tests/unit/api/http/test_http_client.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/argilla/tests/unit/api/http/test_http_client.py b/argilla/tests/unit/api/http/test_http_client.py index 8be069c085..c5bd890cd4 100644 --- a/argilla/tests/unit/api/http/test_http_client.py +++ b/argilla/tests/unit/api/http/test_http_client.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from httpx import Timeout +from unittest.mock import MagicMock, patch +import pytest from argilla import Argilla +from httpx import Timeout class TestHTTPClient: @@ -62,3 +64,15 @@ def test_create_client_with_extra_cookies(self): assert http_client.base_url == "http://localhost:6900" assert http_client.headers["X-Argilla-Api-Key"] == "argilla.apikey" assert http_client.cookies["session"] == "session_id" + + @pytest.mark.parametrize("retries", [0, 1, 5, 10]) + def test_create_client_with_various_retries(self, retries): + with patch("argilla._api._client.create_http_client") as mock_create_http_client: + mock_http_client = MagicMock() + mock_create_http_client.return_value = mock_http_client + + Argilla(api_url="http://test.com", api_key="test_key", retries=retries) + + mock_create_http_client.assert_called_once_with( + api_url="http://test.com", api_key="test_key", timeout=60, retries=retries + ) From 75eec2001b5655a5a792ec872be9acb7fb2b702b Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 12 Sep 2024 14:52:51 +0200 Subject: [PATCH 3/3] Update argilla/src/argilla/client.py Co-authored-by: Paco Aranda --- argilla/src/argilla/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/argilla/src/argilla/client.py b/argilla/src/argilla/client.py index 1f16bee08b..94a997b7ce 100644 --- a/argilla/src/argilla/client.py +++ b/argilla/src/argilla/client.py @@ -65,7 +65,7 @@ def __init__( Defaults to `None`. timeout: the maximum time in seconds to wait for a request to the Argilla API to be completed before raising an exception. Defaults to `60`. - retries: the number of times to retry a failed HTTP request to the Argilla API + retries: the number of times to retry the HTTP connection to the Argilla API before raising an exception. Defaults to `5`. """ super().__init__(api_url=api_url, api_key=api_key, timeout=timeout, retries=retries, **http_client_args)