diff --git a/argilla-sdk/pyproject.toml b/argilla-sdk/pyproject.toml index 45bc717838..0380497cbf 100644 --- a/argilla-sdk/pyproject.toml +++ b/argilla-sdk/pyproject.toml @@ -13,6 +13,8 @@ dynamic = ["version"] dependencies = [ "httpx>=0.26.0", "pydantic>=2.6.0, <3.0.0", + "tqdm>=4.60.0", + "rich>=10.0.0", ] [project.optional-dependencies] diff --git a/argilla-sdk/src/argilla_sdk/_helpers/_log.py b/argilla-sdk/src/argilla_sdk/_helpers/_log.py index 166da0f429..87f3886a8a 100644 --- a/argilla-sdk/src/argilla_sdk/_helpers/_log.py +++ b/argilla-sdk/src/argilla_sdk/_helpers/_log.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from rich import print import logging @@ -36,10 +37,26 @@ def log_message(message: str, level: str = "info") -> None: logger.log(level=level_int, msg=message) +def log_interactive(message: str) -> None: + """Log a message to the console in an interactive environment. + Args: + message (str): The message to log. + """ + print(message) + + class LoggingMixin: """A utility mixin for logging from a `Resource` class.""" def _log_message(self, message: str, level: str = "info") -> None: class_name = self.__class__.__name__ message = f"{class_name}: {message}" - log_message(level=level, message=message) + if self._is_interactive() and level != "info": + log_interactive(message=message) + else: + log_message(level=level, message=message) + + def _is_interactive(self) -> bool: + import __main__ as main + + return not hasattr(main, "__file__") diff --git a/argilla-sdk/src/argilla_sdk/_helpers/_resource_repr.py b/argilla-sdk/src/argilla_sdk/_helpers/_resource_repr.py index f1f3e89baf..c23178362c 100644 --- a/argilla-sdk/src/argilla_sdk/_helpers/_resource_repr.py +++ b/argilla-sdk/src/argilla_sdk/_helpers/_resource_repr.py @@ -14,9 +14,6 @@ from typing import Any, Dict -from IPython.display import HTML - - RESOURCE_REPR_CONFIG = { "Dataset": { "columns": ["name", "id", "workspace_id", "updated_at"], @@ -53,7 +50,7 @@ def _resource_to_table_name(self, resource) -> str: resource_name = resource.__class__.__name__ return RESOURCE_REPR_CONFIG[resource_name]["table_name"] - def _represent_as_html(self, resources) -> HTML: + def _represent_as_html(self, resources) -> str: table_name = self._resource_to_table_name(resources[0]) table_rows = [self._resource_to_table_row(resource) for resource in resources] @@ -69,4 +66,4 @@ def _represent_as_html(self, resources) -> HTML: html_table += "" html_table += "" - return HTML(html_table)._repr_html_() + return html_table diff --git a/argilla-sdk/src/argilla_sdk/client.py b/argilla-sdk/src/argilla_sdk/client.py index 14b0e4c4f3..c243159715 100644 --- a/argilla-sdk/src/argilla_sdk/client.py +++ b/argilla-sdk/src/argilla_sdk/client.py @@ -28,7 +28,6 @@ from argilla_sdk import Dataset from argilla_sdk import User - from IPython.display import HTML __all__ = ["Argilla"] @@ -170,7 +169,7 @@ def list(self, workspace: Optional["Workspace"] = None) -> List["User"]: # Private methods ############################ - def _repr_html_(self) -> "HTML": + def _repr_html_(self) -> str: return self._represent_as_html(resources=self.list()) def _from_model(self, model: UserModel) -> "User": @@ -249,7 +248,7 @@ def default(self) -> "Workspace": # Private methods ############################ - def _repr_html_(self) -> "HTML": + def _repr_html_(self) -> str: return self._represent_as_html(resources=self.list()) def _from_model(self, model: WorkspaceModel) -> "Workspace": @@ -324,7 +323,7 @@ def list(self) -> List["Dataset"]: # Private methods ############################ - def _repr_html_(self) -> "HTML": + def _repr_html_(self) -> str: return self._represent_as_html(resources=self.list()) def _from_model(self, model: DatasetModel) -> "Dataset": diff --git a/argilla-sdk/src/argilla_sdk/records/_dataset_records.py b/argilla-sdk/src/argilla_sdk/records/_dataset_records.py index 11ecd88f10..dfb94f6d3c 100644 --- a/argilla-sdk/src/argilla_sdk/records/_dataset_records.py +++ b/argilla-sdk/src/argilla_sdk/records/_dataset_records.py @@ -17,6 +17,8 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union from uuid import UUID +from tqdm import tqdm + from argilla_sdk._api import RecordsAPI from argilla_sdk._helpers import LoggingMixin from argilla_sdk._models import RecordModel, MetadataValue @@ -224,7 +226,9 @@ def log( created_or_updated = [] records_updated = 0 - for batch in range(0, len(records), batch_size): + for batch in tqdm( + iterable=range(0, len(records), batch_size), desc="Adding and updating records", unit="batch" + ): self._log_message(message=f"Sending records from {batch} to {batch + batch_size}.") batch_records = record_models[batch : batch + batch_size] models, updated = self._api.bulk_upsert(dataset_id=self.__dataset.id, records=batch_records)