Skip to content

Commit

Permalink
[ENHANCEMENT/REFACTOR] argilla: lazy resolution for dataset workspa…
Browse files Browse the repository at this point in the history
…ces (#5152)

<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR removes the workspace resolution on `Dataset.__init__` and
performs it lazily when creating datasets or accessing the
`dataset.workspace` property

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- Refactor (change restructuring the codebase without changing
functionality)
- Improvement (change adding some improvement to an existing
functionality)
- Documentation update

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- follows the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 4, 2024
1 parent 3d25920 commit 54d2e60
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 52 deletions.
2 changes: 1 addition & 1 deletion argilla/src/argilla/_helpers/_resource_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
class ResourceHTMLReprMixin:
def _resource_to_table_row(self, resource) -> Dict[str, Any]:
row = {}
dumped_resource_model = resource._model.model_dump()
dumped_resource_model = resource.api_model().model_dump()
resource_name = resource.__class__.__name__
config = RESOURCE_REPR_CONFIG[resource_name].copy()
len_column = config.pop("len_column", None)
Expand Down
12 changes: 6 additions & 6 deletions argilla/src/argilla/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,28 @@ def api_model(self):
############################

def create(self) -> "Resource":
response_model = self._api.create(self._model)
response_model = self._api.create(self.api_model())
self._model = response_model
self._update_last_api_call()
self._log_message(f"Resource created: {self}")
return self

def get(self) -> "Resource":
response_model = self._api.get(self._model.id)
response_model = self._api.get(self.api_model().id)
self._model = response_model
self._update_last_api_call()
self._log_message(f"Resource fetched: {self}")
return self

def update(self) -> "Resource":
response_model = self._api.update(self._model)
response_model = self._api.update(self.api_model())
self._model = response_model
self._update_last_api_call()
self._log_message(f"Resource updated: {self}")
return self

def delete(self) -> None:
self._api.delete(self._model.id)
self._api.delete(self.api_model().id)
self._update_last_api_call()
self._log_message(f"Resource deleted: {self}")

Expand All @@ -109,13 +109,13 @@ def delete(self) -> None:

def serialize(self) -> dict[str, Any]:
try:
return self._model.model_dump()
return self.api_model().model_dump()
except Exception as e:
raise ArgillaSerializeError(f"Failed to serialize the resource. {e.__class__.__name__}") from e

def serialize_json(self) -> str:
try:
return self._model.model_dump_json()
return self.api_model().model_dump_json()
except Exception as e:
raise ArgillaSerializeError(f"Failed to serialize the resource. {e.__class__.__name__}") from e

Expand Down
6 changes: 3 additions & 3 deletions argilla/src/argilla/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __call__(self, name: str, **kwargs) -> "Workspace":

for model in workspace_models:
if model.name == name:
return Workspace(_model=model, client=self._client)
return self._from_model(model)
warnings.warn(
f"Workspace {name} not found. Creating a new workspace. Do `workspace.create()` to create the workspace."
)
Expand Down Expand Up @@ -252,7 +252,7 @@ def _repr_html_(self) -> str:
def _from_model(self, model: WorkspaceModel) -> "Workspace":
from argilla.workspaces import Workspace

return Workspace(client=self._client, _model=model)
return Workspace.from_model(client=self._client, model=model)


class Datasets(Sequence["Dataset"], ResourceHTMLReprMixin):
Expand Down Expand Up @@ -325,4 +325,4 @@ def _repr_html_(self) -> str:
def _from_model(self, model: DatasetModel) -> "Dataset":
from argilla.datasets import Dataset

return Dataset(client=self._client, _model=model)
return Dataset.from_model(model=model, client=self._client)
60 changes: 33 additions & 27 deletions argilla/src/argilla/datasets/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from argilla._api import DatasetsAPI
from argilla._exceptions import NotFoundError, SettingsError
from argilla._helpers import UUIDUtilities
from argilla._models import DatasetModel
from argilla._resource import Resource
from argilla.client import Argilla
Expand Down Expand Up @@ -52,10 +51,9 @@ class Dataset(Resource, DiskImportExportMixin):
def __init__(
self,
name: Optional[str] = None,
workspace: Optional[Union["Workspace", str]] = None,
workspace: Optional[Union["Workspace", str, UUID]] = None,
settings: Optional[Settings] = None,
client: Optional["Argilla"] = None,
_model: Optional[DatasetModel] = None,
) -> None:
"""Initializes a new Argilla Dataset object with the given parameters.
Expand All @@ -64,21 +62,15 @@ def __init__(
workspace (UUID): Workspace of the dataset. Default is the first workspace found in the server.
settings (Settings): Settings class to be used to configure the dataset.
client (Argilla): Instance of Argilla to connect with the server. Default is the default client.
_model (DatasetModel): Model of the dataset. Used to create the dataset from an existing model.
"""
client = client or Argilla._get_default()
super().__init__(client=client, api=client.api.datasets)
if name is None:
name = f"dataset_{uuid4()}"
self._log_message(f"Settings dataset name to unique UUID: {name}")

self.workspace_id = (
_model.workspace_id if _model and _model.workspace_id else self._workspace_id_from_name(workspace=workspace)
)
self._model = _model or DatasetModel(
name=name,
workspace_id=UUIDUtilities.convert_optional_uuid(uuid=self.workspace_id),
)
self._workspace = workspace
self._model = DatasetModel(name=name)
self._settings = settings or Settings(_dataset=self)
self._settings.dataset = self
self.__records = DatasetRecords(client=self._client, dataset=self)
Expand Down Expand Up @@ -136,6 +128,11 @@ def allow_extra_metadata(self, value: bool) -> None:
def schema(self) -> dict:
return self.settings.schema

@property
def workspace(self) -> Workspace:
self._workspace = self._resolve_workspace()
return self._workspace

#####################
# Core methods #
#####################
Expand Down Expand Up @@ -178,36 +175,45 @@ def update(self) -> "Dataset":

@classmethod
def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset":
return cls(client=client, _model=model)
instance = cls(client=client, workspace=model.workspace_id, name=model.name)
instance._model = model

return instance

#####################
# Utility methods #
#####################

def api_model(self) -> DatasetModel:
self._model.workspace_id = self.workspace.id
return self._model

def _publish(self) -> "Dataset":
self._settings.create()
self._api.publish(dataset_id=self._model.id)

return self.get()

def _workspace_id_from_name(self, workspace: Optional[Union["Workspace", str]]) -> UUID:
def _resolve_workspace(self) -> Workspace:
workspace = self._workspace

if workspace is None:
available_workspaces = self._client.workspaces
ws = available_workspaces[0] # type: ignore
warnings.warn(f"Workspace not provided. Using default workspace: {ws.name} id: {ws.id}")
workspace = self._client.workspaces.default
warnings.warn(f"Workspace not provided. Using default workspace: {workspace.name} id: {workspace.id}")
elif isinstance(workspace, str):
available_workspace_names = [ws.name for ws in self._client.workspaces]
ws = self._client.workspaces(workspace)
if not ws.exists():
self._log_message(
message=f"Workspace with name {workspace} not found. \
Available workspaces: {available_workspace_names}",
level="error",
workspace = self._client.workspaces(workspace)
if not workspace.exists():
available_workspace_names = [ws.name for ws in self._client.workspaces]
raise NotFoundError(
f"Workspace with name { workspace} not found. Available workspaces: {available_workspace_names}"
)
raise NotFoundError()
else:
ws = workspace
return ws.id
elif isinstance(workspace, UUID):
ws_model = self._client.api.workspaces.get(workspace)
workspace = Workspace.from_model(ws_model, client=self._client)
elif not isinstance(workspace, Workspace):
raise ValueError(f"Wrong workspace value found {workspace}")

return workspace

def _rollback_dataset_creation(self):
if self.exists() and not self._is_published():
Expand Down
14 changes: 9 additions & 5 deletions argilla/src/argilla/workspaces/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,20 @@ def __init__(
name: Optional[str] = None,
id: Optional[UUID] = None,
client: Optional["Argilla"] = None,
_model: Optional[WorkspaceModel] = None,
) -> None:
"""Initializes a Workspace object with a client and a name or id
Parameters:
client (Argilla): The client used to interact with Argilla
name (str): The name of the workspace
id (UUID): The id of the workspace
_model (WorkspaceModel): The internal Pydantic model of the workspace from/to the server
Returns:
Workspace: The initialized workspace object
"""
client = client or Argilla._get_default()
super().__init__(client=client, api=client.api.workspaces)
if _model is None:
_model = WorkspaceModel(name=name, id=id)
self._model = _model

self._model = WorkspaceModel(name=name, id=id)

def exists(self) -> bool:
"""
Expand Down Expand Up @@ -103,6 +100,13 @@ def list_datasets(self) -> List["Dataset"]:
self._log_message(f"Got {len(datasets)} datasets for workspace {self.id}")
return [Dataset.from_model(model=dataset, client=self._client) for dataset in datasets]

@classmethod
def from_model(cls, model: WorkspaceModel, client: Argilla) -> "Workspace":
instance = cls(name=model.name, id=model.id, client=client)
instance._model = model

return instance

############################
# Properties
############################
Expand Down
8 changes: 4 additions & 4 deletions argilla/tests/integration/test_dataset_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_dataset_with_workspace(client: rg.Argilla):
assert isinstance(dataset, rg.Dataset)
assert dataset.id is not None
assert dataset.exists()
assert dataset.workspace_id == ws.id
assert dataset.workspace == ws


def test_dataset_with_workspace_name(client: rg.Argilla):
Expand All @@ -93,7 +93,7 @@ def test_dataset_with_workspace_name(client: rg.Argilla):
assert isinstance(dataset, rg.Dataset)
assert dataset.id is not None
assert dataset.exists()
assert dataset.workspace_id == ws.id
assert dataset.workspace == ws


def test_dataset_with_incorrect_workspace_name(client: rg.Argilla):
Expand All @@ -110,7 +110,7 @@ def test_dataset_with_incorrect_workspace_name(client: rg.Argilla):
),
workspace=f"non_existing_workspace_{random.randint(0, 1000)}",
client=client,
)
).create()


def test_dataset_with_default_workspace(client: rg.Argilla):
Expand All @@ -130,7 +130,7 @@ def test_dataset_with_default_workspace(client: rg.Argilla):
assert isinstance(dataset, rg.Dataset)
assert dataset.id is not None
assert dataset.exists()
assert dataset.workspace_id == client.workspaces[0].id
assert dataset.workspace == client.workspaces[0]


def test_retrieving_dataset(client: rg.Argilla, dataset: rg.Dataset):
Expand Down
11 changes: 5 additions & 6 deletions argilla/tests/unit/helpers/test_resource_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@ def test_represent_workspaces_as_html(self):

workspace = rg.Workspace(name="workspace1", id=uuid.uuid4())
datasets = [
rg.Dataset.from_model(
DatasetModel(id=uuid.uuid4(), name="dataset1", workspace_id=workspace.id), client=client
),
rg.Dataset.from_model(
DatasetModel(id=uuid.uuid4(), name="dataset2", workspace_id=workspace.id), client=client
),
rg.Dataset(name="dataset1", workspace=workspace, client=client),
rg.Dataset(name="dataset2", workspace=workspace, client=client),
]

for dataset in datasets:
dataset.id = uuid.uuid4()

assert (
ResourceHTMLReprMixin()._represent_as_html(datasets) == "<h3>Datasets</h3>"
"<table>"
Expand Down

0 comments on commit 54d2e60

Please sign in to comment.