def pil_to_data_uri(image_object: "Image.Image") -> str:
    """Convert a PIL image to a base64 data URI string.

    Parameters:
        image_object (Image.Image): The PIL image to convert to a base64 data URI.

    Returns:
        str: The data URI string, ``data:image/<format>;base64,<payload>``.

    Raises:
        ValueError: If ``image_object`` is not a PIL image, or if saving or
            base64-encoding the image fails.
    """
    if not isinstance(image_object, Image.Image):
        raise ValueError("The image_object must be a PIL Image object.")

    image_format = image_object.format
    if image_format is None:
        # Images without a format attribute still need a concrete encoder.
        image_format = "PNG"
        # stacklevel=2 attributes the warning to the caller, not to this helper.
        warnings.warn("The image format is not set. Defaulting to PNG.", UserWarning, stacklevel=2)

    try:
        buffered = io.BytesIO()
        image_object.save(buffered, format=image_format)
    except Exception as e:
        raise ValueError("An error occurred while saving the image binary to buffer") from e

    try:
        img_str = base64.b64encode(buffered.getvalue()).decode()
        mimetype = f"image/{image_format.lower()}"
        data_uri = f"data:{mimetype};base64,{img_str}"
    except Exception as e:
        raise ValueError("An error occurred while converting the image binary to base64") from e

    return data_uri


def filepath_to_data_uri(file_path: Union[str, Path]) -> str:
    """Convert an image file to a base64 data URI string.

    The MIME type is looked up from the file name with ``mimetypes.guess_type``
    so that, for example, ``.jpg`` yields ``image/jpeg`` and ``.svg`` yields
    ``image/svg+xml``. When the lookup gives no ``image/*`` type, the type
    falls back to ``image/<suffix without the dot>``.

    Parameters:
        file_path (Union[str, Path]): Path of the image file to read.

    Returns:
        str: The data URI string, ``data:<mimetype>;base64,<payload>``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    import mimetypes  # local import: keeps the module-level import block untouched

    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"File not found at {file_path}")

    # Guess from the bare file name so drive letters or directories cannot be
    # mistaken for a URL scheme by the guesser.
    mimetype, _ = mimetypes.guess_type(file_path.name)
    if mimetype is None or not mimetype.startswith("image/"):
        mimetype = f"image/{file_path.suffix[1:]}"

    with open(file_path, "rb") as image_file:
        img_str = base64.b64encode(image_file.read()).decode()

    return f"data:{mimetype};base64,{img_str}"


def cast_image(image: Union["Image.Image", str, Path]) -> str:
    """Convert an image value to a string: a data URI or a URL.

    Parameters:
        image (Union[Image.Image, str, Path]): One of
            - a string starting with ``data:``, ``http://`` or ``https://``,
              which is returned unchanged;
            - any other string, or a ``Path``, which is read as a file path;
            - a PIL image, which is encoded.

    Returns:
        str: The unchanged data URI / URL, or a newly built base64 data URI.

    Raises:
        ValueError: If ``image`` is none of the supported types.
        FileNotFoundError: If ``image`` is a file path that does not exist.
    """
    if isinstance(image, str):
        # Require the full scheme separator so a local file such as
        # "http_logo.png" is read from disk instead of being treated as a URL.
        if image.startswith(("data:", "http://", "https://")):
            return image
        return filepath_to_data_uri(image)
    if isinstance(image, Path):
        return filepath_to_data_uri(image)
    if isinstance(image, Image.Image):
        return pil_to_data_uri(image)
    raise ValueError("The image must be a data URI string, a file path, or a PIL Image object.")


def uncast_image(image: Union[str, "Image.Image"]) -> Union[str, "Image.Image"]:
    """Convert a stored image value back to a PIL image where possible.

    Parameters:
        image (Union[str, Image.Image]): A PIL image, a ``data:image`` URI,
            an ``http://`` / ``https://`` URL, or a path to an existing file.

    Returns:
        Union[str, Image.Image]: The decoded PIL image for data URIs and
            existing file paths; the input itself for PIL images and URLs.

    Raises:
        ValueError: If ``image`` is neither a string nor a PIL image, if the
            data URI cannot be decoded, or if the string is not a data URI,
            a URL, or an existing file path.
    """
    if not isinstance(image, str):
        if isinstance(image, Image.Image):
            return image
        raise ValueError("The image must be a data URI string.")

    if image.startswith("data:image"):
        try:
            # Split once: everything after the first comma is the payload.
            image_data = base64.b64decode(image.split(",", 1)[1])
            return Image.open(io.BytesIO(image_data))
        except Exception as e:
            raise ValueError("An error occurred while converting the data URI to a PIL image.") from e

    if image.startswith(("http://", "https://")):
        return image

    try:
        is_local_file = Path(image).exists()
    except (OSError, ValueError):
        # Very long or otherwise invalid strings can make the filesystem
        # lookup itself fail; such a value is simply not a usable path.
        is_local_file = False

    if is_local_file:
        return Image.open(image)
    raise ValueError("The image must be a data URI string.")
@staticmethod
def _get_image_features(dataset: "HFDataset") -> List[str]:
    """Return the names of the dataset columns whose feature is an ``Image``.

    Parameters:
        dataset (HFDataset): The Hugging Face dataset to inspect.

    Returns:
        List[str]: The names of the image columns. Empty when the dataset has
            no image columns or exposes no ``features``.
    """
    # `features` may be missing or None; treat that as "no image columns".
    features = getattr(dataset, "features", None) or {}
    return [name for name, feature in features.items() if isinstance(feature, Image)]


@staticmethod
def _cast_images_as_urls(hf_dataset: "HFDataset", columns: List[str]) -> "HFDataset":
    """Replace the values of the given image columns with data URI strings.

    Each column is mapped through ``pil_to_data_uri`` into a temporary column,
    the original column is removed, and the temporary column is renamed back
    to the original name.

    Parameters:
        hf_dataset (HFDataset): The Hugging Face dataset to convert.
        columns (List[str]): The names of the image columns to convert.

    Returns:
        HFDataset: The dataset with the given columns holding data URI strings.
    """
    # Random temporary column name, reused for every column because it is
    # renamed away before the next column is processed.
    unique_identifier = uuid4().hex

    def batch_fn(batch):
        # With `input_columns=[column]` and `batched=True`, `batch` is the
        # list of values of that single column.
        return {unique_identifier: [pil_to_data_uri(sample) for sample in batch]}

    for column in columns:
        hf_dataset = hf_dataset.map(
            function=batch_fn,
            with_indices=False,
            batched=True,
            input_columns=[column],
            remove_columns=[column],
        )
        hf_dataset = hf_dataset.rename_column(original_column_name=unique_identifier, new_column_name=column)

    return hf_dataset


class RecordFields(dict):
    """Dictionary of a record's field values, keyed by field name.

    Values are stored as given. For fields whose type in the dataset settings
    schema is ``"image"``, item access (``fields[name]``) returns
    ``uncast_image(value)`` and ``to_dict`` returns ``cast_image(value)``.
    All other fields, and all fields of a record without a dataset, are
    returned unchanged.
    """

    def __init__(self, record: "Record", fields: Optional[Dict[str, "FieldValue"]] = None) -> None:
        """Store the field values and keep a reference to the owning record.

        Parameters:
            record (Record): The record these fields belong to; its dataset
                settings decide which fields are images.
            fields (Optional[Dict[str, FieldValue]]): The initial field values.
        """
        super().__init__(fields or {})
        self.record = record

    def to_dict(self) -> dict:
        """Return a plain dict of the fields, with image fields passed through ``cast_image``."""
        return {key: cast_image(value) if self._is_image(key) else value for key, value in self.items()}

    def __getitem__(self, key: str) -> "FieldValue":
        """Return the value for ``key``, passing image fields through ``uncast_image``."""
        value = super().__getitem__(key)
        return uncast_image(value) if self._is_image(key) else value

    def _is_image(self, key: str) -> bool:
        """Tell whether ``key`` names a field of type ``"image"`` in the dataset schema.

        Returns ``False`` when the record has no dataset or when ``key`` is not
        part of the schema, so unknown fields are passed through untouched
        instead of failing with a ``KeyError``.
        """
        dataset = self.record.dataset
        if not dataset:
            return False
        try:
            field = dataset.settings.schema[key]
        except KeyError:
            return False
        return getattr(field, "type", None) == "image"
@pytest.fixture
def pil_image():
    """Return a 100x100 solid red RGB image."""
    return Image.new("RGB", (100, 100), color="red")


@pytest.fixture
def data_uri_image(pil_image):
    """Return the data URI produced by ``pil_to_data_uri`` for ``pil_image``."""
    return pil_to_data_uri(pil_image)


@pytest.fixture
def path_to_image(pil_image, tmp_path):
    """Save ``pil_image`` as a ``.jpg`` file and return its path as a string.

    Uses pytest's ``tmp_path`` instead of an open ``NamedTemporaryFile``:
    writing to a temporary file by name while it is still open does not work
    on Windows.
    """
    image_path = tmp_path / "image.jpg"
    pil_image.save(image_path)
    return str(image_path)


def test_cast_image_with_pil_image(pil_image):
    result = cast_image(pil_image)
    uncasted = uncast_image(result)

    assert isinstance(result, str)
    assert result.startswith("data:image")
    assert "base64" in result

    assert isinstance(uncasted, Image.Image)
    assert uncasted.size == pil_image.size
    assert uncasted.mode == pil_image.mode
    assert uncasted.getcolors() == pil_image.getcolors()


def test_cast_image_with_file_path(path_to_image):
    result = cast_image(path_to_image)
    uncasted = uncast_image(result)

    assert isinstance(result, str)
    assert result.startswith("data:image")
    assert "base64" in result
    assert isinstance(uncasted, Image.Image)

    # Compare against the file as decoded from disk; the context manager
    # closes the file handle once the comparison is done.
    with Image.open(path_to_image) as expected:
        assert uncasted.size == expected.size
        assert uncasted.mode == expected.mode
        assert uncasted.getcolors() == expected.getcolors()


def test_cast_image_with_data_uri(data_uri_image):
    result = cast_image(data_uri_image)
    uncasted = uncast_image(result)

    assert result == data_uri_image
    assert isinstance(uncasted, Image.Image)


def test_cast_image_with_invalid_input():
    invalid_input = 123
    with pytest.raises(ValueError):
        cast_image(invalid_input)


def test_uncast_image_with_url():
    image_url = "https://example.com/image.jpg"
    result = uncast_image(image_url)
    assert result == image_url
@pytest.fixture
def pil_image():
    """Return a 100x100 solid red RGB image."""
    return Image.new("RGB", (100, 100), color="red")


@pytest.fixture
def path_to_image(pil_image, tmp_path):
    """Save ``pil_image`` as a ``.jpg`` file and return its path as a string.

    Uses pytest's ``tmp_path`` instead of an open ``NamedTemporaryFile``:
    writing to a temporary file by name while it is still open does not work
    on Windows.
    """
    image_path = tmp_path / "image.jpg"
    pil_image.save(image_path)
    return str(image_path)


@pytest.fixture
def dataset():
    """Return a dataset whose settings declare a single image field named ``image``."""
    return Dataset(
        name=f"test_dataset_{random.randint(1, 1000)}",
        settings=Settings(
            fields=[ImageField(name="image")],
        ),
    )


class TestRecordFields:
    def test_create_record_fields(self):
        record = Record(fields={"name": "John Doe"}, metadata={"age": 30})

        fields = record.fields
        assert fields["name"] == "John Doe"
        assert record.metadata["age"] == 30

    def test_create_record_image_path(self):
        # No dataset is attached, so the path string is returned unchanged.
        record = Record(fields={"image": "path/to/image.jpg"})

        fields = record.fields
        assert fields["image"] == "path/to/image.jpg"

    def test_create_dataset_with_local_image(self, path_to_image, pil_image, dataset):
        record = Record(fields={"image": path_to_image}, _dataset=dataset)

        # Read the field once: every access to an image field stored as a
        # path opens the file again.
        image = record.fields["image"]
        assert isinstance(image, Image.Image)
        assert image.size == pil_image.size
        assert image.mode == pil_image.mode

    def test_create_record_image_pil(self, pil_image, dataset):
        record = Record(fields={"image": pil_image}, _dataset=dataset)

        fields = record.fields
        assert isinstance(fields["image"], Image.Image)
        assert fields["image"].size == pil_image.size
        assert fields["image"].mode == pil_image.mode

    def test_create_record_with_wrong_image_type(self, dataset):
        record = Record(fields={"image": 123}, _dataset=dataset)
        with pytest.raises(ValueError):
            record.fields.to_dict()
        with pytest.raises(ValueError):
            record.fields["image"]