From 54d2e60305e5a175dd93a27c74a58c011ce2aec3 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Thu, 4 Jul 2024 12:41:30 +0200 Subject: [PATCH] [ENHANCEMENT/REFACTOR] `argilla`: lazy resolution for dataset workspaces (#5152) This PR removes the workspace resolution on `Dataset.__init__` and performs it lazily when creating datasets or accessing the `dataset.workspace` property. **Type of change** - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - My code follows the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../src/argilla/_helpers/_resource_repr.py | 2 +- argilla/src/argilla/_resource.py | 12 ++-- argilla/src/argilla/client.py | 6 +- argilla/src/argilla/datasets/_resource.py | 60 ++++++++++--------- argilla/src/argilla/workspaces/_resource.py | 14 +++-- .../integration/test_dataset_workspace.py | 8 +-- .../tests/unit/helpers/test_resource_repr.py | 11 ++-- 7 files changed, 61 insertions(+), 52 deletions(-) diff --git a/argilla/src/argilla/_helpers/_resource_repr.py b/argilla/src/argilla/_helpers/_resource_repr.py index c23178362c..1da5f67f89 100644 --- a/argilla/src/argilla/_helpers/_resource_repr.py +++ b/argilla/src/argilla/_helpers/_resource_repr.py @@ -32,7 +32,7 @@ class ResourceHTMLReprMixin: def _resource_to_table_row(self, resource) -> Dict[str, Any]: row = {} - 
dumped_resource_model = resource._model.model_dump() + dumped_resource_model = resource.api_model().model_dump() resource_name = resource.__class__.__name__ config = RESOURCE_REPR_CONFIG[resource_name].copy() len_column = config.pop("len_column", None) diff --git a/argilla/src/argilla/_resource.py b/argilla/src/argilla/_resource.py index 60bf3f11c7..2dd88c1917 100644 --- a/argilla/src/argilla/_resource.py +++ b/argilla/src/argilla/_resource.py @@ -78,28 +78,28 @@ def api_model(self): ############################ def create(self) -> "Resource": - response_model = self._api.create(self._model) + response_model = self._api.create(self.api_model()) self._model = response_model self._update_last_api_call() self._log_message(f"Resource created: {self}") return self def get(self) -> "Resource": - response_model = self._api.get(self._model.id) + response_model = self._api.get(self.api_model().id) self._model = response_model self._update_last_api_call() self._log_message(f"Resource fetched: {self}") return self def update(self) -> "Resource": - response_model = self._api.update(self._model) + response_model = self._api.update(self.api_model()) self._model = response_model self._update_last_api_call() self._log_message(f"Resource updated: {self}") return self def delete(self) -> None: - self._api.delete(self._model.id) + self._api.delete(self.api_model().id) self._update_last_api_call() self._log_message(f"Resource deleted: {self}") @@ -109,13 +109,13 @@ def delete(self) -> None: def serialize(self) -> dict[str, Any]: try: - return self._model.model_dump() + return self.api_model().model_dump() except Exception as e: raise ArgillaSerializeError(f"Failed to serialize the resource. {e.__class__.__name__}") from e def serialize_json(self) -> str: try: - return self._model.model_dump_json() + return self.api_model().model_dump_json() except Exception as e: raise ArgillaSerializeError(f"Failed to serialize the resource. 
{e.__class__.__name__}") from e diff --git a/argilla/src/argilla/client.py b/argilla/src/argilla/client.py index 5e45dbb40a..b6513d07e4 100644 --- a/argilla/src/argilla/client.py +++ b/argilla/src/argilla/client.py @@ -195,7 +195,7 @@ def __call__(self, name: str, **kwargs) -> "Workspace": for model in workspace_models: if model.name == name: - return Workspace(_model=model, client=self._client) + return self._from_model(model) warnings.warn( f"Workspace {name} not found. Creating a new workspace. Do `workspace.create()` to create the workspace." ) @@ -252,7 +252,7 @@ def _repr_html_(self) -> str: def _from_model(self, model: WorkspaceModel) -> "Workspace": from argilla.workspaces import Workspace - return Workspace(client=self._client, _model=model) + return Workspace.from_model(client=self._client, model=model) class Datasets(Sequence["Dataset"], ResourceHTMLReprMixin): @@ -325,4 +325,4 @@ def _repr_html_(self) -> str: def _from_model(self, model: DatasetModel) -> "Dataset": from argilla.datasets import Dataset - return Dataset(client=self._client, _model=model) + return Dataset.from_model(model=model, client=self._client) diff --git a/argilla/src/argilla/datasets/_resource.py b/argilla/src/argilla/datasets/_resource.py index 4cc9562b36..5237010f34 100644 --- a/argilla/src/argilla/datasets/_resource.py +++ b/argilla/src/argilla/datasets/_resource.py @@ -18,7 +18,6 @@ from argilla._api import DatasetsAPI from argilla._exceptions import NotFoundError, SettingsError -from argilla._helpers import UUIDUtilities from argilla._models import DatasetModel from argilla._resource import Resource from argilla.client import Argilla @@ -52,10 +51,9 @@ class Dataset(Resource, DiskImportExportMixin): def __init__( self, name: Optional[str] = None, - workspace: Optional[Union["Workspace", str]] = None, + workspace: Optional[Union["Workspace", str, UUID]] = None, settings: Optional[Settings] = None, client: Optional["Argilla"] = None, - _model: Optional[DatasetModel] = None, ) -> 
None: """Initializes a new Argilla Dataset object with the given parameters. @@ -64,7 +62,6 @@ def __init__( workspace (UUID): Workspace of the dataset. Default is the first workspace found in the server. settings (Settings): Settings class to be used to configure the dataset. client (Argilla): Instance of Argilla to connect with the server. Default is the default client. - _model (DatasetModel): Model of the dataset. Used to create the dataset from an existing model. """ client = client or Argilla._get_default() super().__init__(client=client, api=client.api.datasets) @@ -72,13 +69,8 @@ def __init__( name = f"dataset_{uuid4()}" self._log_message(f"Settings dataset name to unique UUID: {name}") - self.workspace_id = ( - _model.workspace_id if _model and _model.workspace_id else self._workspace_id_from_name(workspace=workspace) - ) - self._model = _model or DatasetModel( - name=name, - workspace_id=UUIDUtilities.convert_optional_uuid(uuid=self.workspace_id), - ) + self._workspace = workspace + self._model = DatasetModel(name=name) self._settings = settings or Settings(_dataset=self) self._settings.dataset = self self.__records = DatasetRecords(client=self._client, dataset=self) @@ -136,6 +128,11 @@ def allow_extra_metadata(self, value: bool) -> None: def schema(self) -> dict: return self.settings.schema + @property + def workspace(self) -> Workspace: + self._workspace = self._resolve_workspace() + return self._workspace + ##################### # Core methods # ##################### @@ -178,36 +175,45 @@ def update(self) -> "Dataset": @classmethod def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset": - return cls(client=client, _model=model) + instance = cls(client=client, workspace=model.workspace_id, name=model.name) + instance._model = model + + return instance ##################### # Utility methods # ##################### + def api_model(self) -> DatasetModel: + self._model.workspace_id = self.workspace.id + return self._model + def 
_publish(self) -> "Dataset": self._settings.create() self._api.publish(dataset_id=self._model.id) return self.get() - def _workspace_id_from_name(self, workspace: Optional[Union["Workspace", str]]) -> UUID: + def _resolve_workspace(self) -> Workspace: + workspace = self._workspace + if workspace is None: - available_workspaces = self._client.workspaces - ws = available_workspaces[0] # type: ignore - warnings.warn(f"Workspace not provided. Using default workspace: {ws.name} id: {ws.id}") + workspace = self._client.workspaces.default + warnings.warn(f"Workspace not provided. Using default workspace: {workspace.name} id: {workspace.id}") elif isinstance(workspace, str): - available_workspace_names = [ws.name for ws in self._client.workspaces] - ws = self._client.workspaces(workspace) - if not ws.exists(): - self._log_message( - message=f"Workspace with name {workspace} not found. \ - Available workspaces: {available_workspace_names}", - level="error", + workspace = self._client.workspaces(workspace) + if not workspace.exists(): + available_workspace_names = [ws.name for ws in self._client.workspaces] + raise NotFoundError( + f"Workspace with name { workspace} not found. 
Available workspaces: {available_workspace_names}" ) - raise NotFoundError() - else: - ws = workspace - return ws.id + elif isinstance(workspace, UUID): + ws_model = self._client.api.workspaces.get(workspace) + workspace = Workspace.from_model(ws_model, client=self._client) + elif not isinstance(workspace, Workspace): + raise ValueError(f"Wrong workspace value found {workspace}") + + return workspace def _rollback_dataset_creation(self): if self.exists() and not self._is_published(): diff --git a/argilla/src/argilla/workspaces/_resource.py b/argilla/src/argilla/workspaces/_resource.py index 8a1b6b3a3e..630750f588 100644 --- a/argilla/src/argilla/workspaces/_resource.py +++ b/argilla/src/argilla/workspaces/_resource.py @@ -45,7 +45,6 @@ def __init__( name: Optional[str] = None, id: Optional[UUID] = None, client: Optional["Argilla"] = None, - _model: Optional[WorkspaceModel] = None, ) -> None: """Initializes a Workspace object with a client and a name or id @@ -53,15 +52,13 @@ def __init__( client (Argilla): The client used to interact with Argilla name (str): The name of the workspace id (UUID): The id of the workspace - _model (WorkspaceModel): The internal Pydantic model of the workspace from/to the server Returns: Workspace: The initialized workspace object """ client = client or Argilla._get_default() super().__init__(client=client, api=client.api.workspaces) - if _model is None: - _model = WorkspaceModel(name=name, id=id) - self._model = _model + + self._model = WorkspaceModel(name=name, id=id) def exists(self) -> bool: """ @@ -103,6 +100,13 @@ def list_datasets(self) -> List["Dataset"]: self._log_message(f"Got {len(datasets)} datasets for workspace {self.id}") return [Dataset.from_model(model=dataset, client=self._client) for dataset in datasets] + @classmethod + def from_model(cls, model: WorkspaceModel, client: Argilla) -> "Workspace": + instance = cls(name=model.name, id=model.id, client=client) + instance._model = model + + return instance + 
############################ # Properties ############################ diff --git a/argilla/tests/integration/test_dataset_workspace.py b/argilla/tests/integration/test_dataset_workspace.py index 50eb5f7b7f..cca87f13d0 100644 --- a/argilla/tests/integration/test_dataset_workspace.py +++ b/argilla/tests/integration/test_dataset_workspace.py @@ -71,7 +71,7 @@ def test_dataset_with_workspace(client: rg.Argilla): assert isinstance(dataset, rg.Dataset) assert dataset.id is not None assert dataset.exists() - assert dataset.workspace_id == ws.id + assert dataset.workspace == ws def test_dataset_with_workspace_name(client: rg.Argilla): @@ -93,7 +93,7 @@ def test_dataset_with_workspace_name(client: rg.Argilla): assert isinstance(dataset, rg.Dataset) assert dataset.id is not None assert dataset.exists() - assert dataset.workspace_id == ws.id + assert dataset.workspace == ws def test_dataset_with_incorrect_workspace_name(client: rg.Argilla): @@ -110,7 +110,7 @@ def test_dataset_with_incorrect_workspace_name(client: rg.Argilla): ), workspace=f"non_existing_workspace_{random.randint(0, 1000)}", client=client, - ) + ).create() def test_dataset_with_default_workspace(client: rg.Argilla): @@ -130,7 +130,7 @@ def test_dataset_with_default_workspace(client: rg.Argilla): assert isinstance(dataset, rg.Dataset) assert dataset.id is not None assert dataset.exists() - assert dataset.workspace_id == client.workspaces[0].id + assert dataset.workspace == client.workspaces[0] def test_retrieving_dataset(client: rg.Argilla, dataset: rg.Dataset): diff --git a/argilla/tests/unit/helpers/test_resource_repr.py b/argilla/tests/unit/helpers/test_resource_repr.py index e6bdcbd5fe..ebf8b94767 100644 --- a/argilla/tests/unit/helpers/test_resource_repr.py +++ b/argilla/tests/unit/helpers/test_resource_repr.py @@ -39,14 +39,13 @@ def test_represent_workspaces_as_html(self): workspace = rg.Workspace(name="workspace1", id=uuid.uuid4()) datasets = [ - rg.Dataset.from_model( - DatasetModel(id=uuid.uuid4(), 
name="dataset1", workspace_id=workspace.id), client=client - ), - rg.Dataset.from_model( - DatasetModel(id=uuid.uuid4(), name="dataset2", workspace_id=workspace.id), client=client - ), + rg.Dataset(name="dataset1", workspace=workspace, client=client), + rg.Dataset(name="dataset2", workspace=workspace, client=client), ] + for dataset in datasets: + dataset.id = uuid.uuid4() + assert ( ResourceHTMLReprMixin()._represent_as_html(datasets) == "

Datasets

" ""