Skip to content

Commit

Permalink
[BUGFIX] argilla: update dataset for settings properties (#5305)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

Fields, vectors, and metadata, do not update the inner dataset/client
when fetching data remotely. This can cause errors when working with
different clients.

This PR fixes this by updating the dataset when assigning setting
properties.

The internals of this should be improved by simplifying relationships
between resources, clients, and APIs.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 24, 2024
1 parent b4e8bad commit 122a17a
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 11 deletions.
11 changes: 10 additions & 1 deletion argilla/src/argilla/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from datetime import datetime
from typing import Any, TYPE_CHECKING, Optional
from uuid import UUID

try:
from typing import Self
except ImportError:
from typing_extensions import Self

from argilla._exceptions import ArgillaSerializeError
from argilla._helpers import LoggingMixin

Expand Down Expand Up @@ -125,3 +130,7 @@ def _update_last_api_call(self):
def _seconds_from_last_api_call(self) -> Optional[float]:
if self._last_api_call:
return (datetime.utcnow() - self._last_api_call).total_seconds()

@abstractmethod
def _with_client(self, client: "Argilla") -> "Self":
pass
17 changes: 16 additions & 1 deletion argilla/src/argilla/settings/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from argilla.settings._metadata import MetadataField, MetadataType
from argilla.settings._vector import VectorField

try:
from typing import Self
except ImportError:
from typing_extensions import Self

if TYPE_CHECKING:
from argilla.datasets import Dataset

Expand All @@ -33,7 +38,7 @@ class TextField(SettingsPropertyBase):
_model: FieldModel
_api: FieldsAPI

_dataset: "Dataset"
_dataset: Optional["Dataset"]

def __init__(
self,
Expand Down Expand Up @@ -64,6 +69,8 @@ def __init__(
settings=TextFieldSettings(use_markdown=use_markdown),
)

self._dataset = None

@classmethod
def from_model(cls, model: FieldModel) -> "TextField":
instance = cls(name=model.name)
Expand Down Expand Up @@ -92,6 +99,14 @@ def dataset(self) -> "Dataset":
def dataset(self, value: "Dataset") -> None:
self._dataset = value
self._model.dataset_id = self._dataset.id
self._with_client(self._dataset._client)

def _with_client(self, client: "Argilla") -> "Self":
# TODO: Review and simplify. Maybe only one of them is required
self._client = client
self._api = self._client.api.fields

return self


def field_from_dict(data: dict) -> Union[TextField, VectorField, MetadataType]:
Expand Down
17 changes: 16 additions & 1 deletion argilla/src/argilla/settings/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
from argilla._resource import Resource
from argilla.client import Argilla

try:
from typing import Self
except ImportError:
from typing_extensions import Self

if TYPE_CHECKING:
from argilla import Dataset

Expand All @@ -41,12 +46,14 @@ class MetadataPropertyBase(Resource):
_model: MetadataFieldModel
_api: MetadataAPI

_dataset: "Dataset"
_dataset: Optional["Dataset"]

def __init__(self, client: Optional[Argilla] = None) -> None:
client = client or Argilla._get_default()
super().__init__(client=client, api=client.api.metadata)

self._dataset = None

@property
def name(self) -> str:
return self._model.name
Expand Down Expand Up @@ -79,12 +86,20 @@ def dataset(self) -> Optional["Dataset"]:
def dataset(self, value: "Dataset") -> None:
self._dataset = value
self._model.dataset_id = value.id
self._with_client(self._dataset._client)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(name={self.name}, title={self.title}, dimensions={self.visible_for_annotators})"
)

def _with_client(self, client: "Argilla") -> "Self":
# TODO: Review and simplify. Maybe only one of them is required
self._client = client
self._api = self._client.api.metadata

return self


class TermsMetadataProperty(MetadataPropertyBase):
def __init__(
Expand Down
12 changes: 12 additions & 0 deletions argilla/src/argilla/settings/_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import List, Literal, Optional, Union, Dict

from argilla import Argilla
from argilla._models._settings._questions import (
LabelQuestionModel,
LabelQuestionSettings,
Expand All @@ -31,6 +32,11 @@
)
from argilla.settings._common import SettingsPropertyBase

try:
from typing import Self
except ImportError:
from typing_extensions import Self

__all__ = [
"LabelQuestion",
"MultiLabelQuestion",
Expand Down Expand Up @@ -71,6 +77,12 @@ def _render_options_as_labels(cls, options: List[Dict[str, str]]) -> List[str]:
"""Render values as labels for the question so that they can be returned as a list of strings"""
return list(cls._render_options_as_values(options=options).keys())

def _with_client(self, client: "Argilla") -> "Self":
self._client = client
self._api = client.api.questions

return self


class LabelQuestion(QuestionPropertyBase):
_model: LabelQuestionModel
Expand Down
14 changes: 7 additions & 7 deletions argilla/src/argilla/settings/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,16 @@ def __init__(
"""
super().__init__(client=_dataset._client if _dataset else None)

self._dataset = _dataset
self._distribution = distribution
self.__guidelines = self.__process_guidelines(guidelines)
self.__allow_extra_metadata = allow_extra_metadata

self.__questions = QuestionsProperties(self, questions)
self.__fields = SettingsProperties(self, fields)
self.__vectors = SettingsProperties(self, vectors)
self.__metadata = SettingsProperties(self, metadata)

self.__guidelines = self.__process_guidelines(guidelines)
self.__allow_extra_metadata = allow_extra_metadata

self._distribution = distribution

self._dataset = _dataset

#####################
# Properties #
#####################
Expand Down Expand Up @@ -392,6 +390,8 @@ def __init__(self, settings: "Settings", properties: List[Property]):
self._settings = settings

for property in properties or []:
if self._settings.dataset and hasattr(property, "dataset"):
property.dataset = self._settings.dataset
self.add(property)

def __getitem__(self, key: Union[UUID, str, int]) -> Optional[Property]:
Expand Down
10 changes: 9 additions & 1 deletion argilla/src/argilla/settings/_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class VectorField(Resource):

_model: VectorFieldModel
_api: VectorsAPI
_dataset: "Dataset"
_dataset: Optional["Dataset"]

def __init__(
self,
Expand Down Expand Up @@ -83,6 +83,7 @@ def dataset(self) -> "Dataset":
def dataset(self, value: "Dataset") -> None:
self._dataset = value
self._model.dataset_id = self._dataset.id
self._with_client(self._dataset._client)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name}, title={self.title}, dimensions={self.dimensions})"
Expand All @@ -98,3 +99,10 @@ def from_model(cls, model: VectorFieldModel) -> "VectorField":
def from_dict(cls, data: dict) -> "VectorField":
model = VectorFieldModel(**data)
return cls.from_model(model=model)

def _with_client(self, client: "Argilla") -> "VectorField":
# TODO: Review and simplify. Maybe only one of them is required
self._client = client
self._api = self._client.api.vectors

return self
5 changes: 5 additions & 0 deletions argilla/tests/integration/test_create_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def test_create_a_dataset_copy(self, client: Argilla, dataset_name: str):
settings=dataset.settings,
).create()

for properties in [new_dataset.settings.fields, new_dataset.settings.vectors, new_dataset.settings.metadata]:
for property in properties:
assert property.dataset == new_dataset
assert property._client == new_dataset._client

records = list(dataset.records(with_vectors=True))
new_dataset.records.log(records)

Expand Down

0 comments on commit 122a17a

Please sign in to comment.