Skip to content

Commit

Permalink
[FEATURE] cast datasets pil objects to base64 in SDK (#5433)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->
This PR introduces a utility into the SDK so that it automatically casts
PIL objects to base64 based on the datasets `Features`

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- Refactor (change restructuring the codebase without changing
functionality)
- Improvement (change adding some improvement to an existing
functionality)
- Documentation update

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
burtenshaw and pre-commit-ci[bot] authored Sep 2, 2024
1 parent 7cd1e1b commit de6c3fa
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 5 deletions.
1 change: 1 addition & 0 deletions argilla/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"tqdm>=4.60.0",
"rich>=10.0.0",
"datasets>=2.0.0",
"pillow>=9.5.0",
]

legacy = [
Expand Down
107 changes: 107 additions & 0 deletions argilla/src/argilla/_helpers/_media.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import io
import warnings
from pathlib import Path
from typing import Union

from PIL import Image


def pil_to_data_uri(image_object: "Image") -> str:
"""Convert a PIL image to a base64 data URI string.
Parameters:
image_object (Image): The PIL image to convert to a base64 data URI.
Returns:
str: The data URI string.
"""
if not isinstance(image_object, Image.Image):
raise ValueError("The image_object must be a PIL Image object.")

image_format = image_object.format
if image_format is None:
image_format = "PNG"
warnings.warn("The image format is not set. Defaulting to PNG.", UserWarning)

try:
buffered = io.BytesIO()
image_object.save(buffered, format=image_format)
except Exception as e:
raise ValueError("An error occurred while saving the image binary to buffer") from e

try:
img_str = base64.b64encode(buffered.getvalue()).decode()
mimetype = f"image/{image_format.lower()}"
data_uri = f"data:{mimetype};base64,{img_str}"
except Exception as e:
raise ValueError("An error occurred while converting the image binary to base64") from e

return data_uri


def filepath_to_data_uri(file_path: "Path") -> str:
"""Convert an image file to a base64 data URI string."""
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "rb") as image_file:
img_str = base64.b64encode(image_file.read()).decode()
mimetype = f"image/{file_path.suffix[1:]}"
data_uri = f"data:{mimetype};base64,{img_str}"
else:
raise FileNotFoundError(f"File not found at {file_path}")

return data_uri


def cast_image(image: Union["Image", str, Path]) -> str:
"""Convert a PIL image to a base64 data URI string.
Parameters:
image_object (Image): The PIL image to convert to a base64 data URI.
Returns:
str: The data URI string.
"""
if isinstance(image, str):
if image.startswith("data:") or image.startswith("http"):
return image
else:
return filepath_to_data_uri(image)
elif isinstance(image, Path):
return filepath_to_data_uri(image)
elif isinstance(image, Image.Image):
return pil_to_data_uri(image)
else:
raise ValueError("The image must be a data URI string, a file path, or a PIL Image object.")


def uncast_image(image: str) -> "Image":
"""Convert a base64 data URI string to a PIL image."""
if isinstance(image, Image.Image):
return image
elif not isinstance(image, str):
raise ValueError("The image must be a data URI string.")
elif image.startswith("data:image"):
try:
image_data = base64.b64decode(image.split(",")[1])
image = Image.open(io.BytesIO(image_data))
except Exception as e:
raise ValueError("An error occurred while converting the data URI to a PIL image.") from e
return image
elif image.startswith("http"):
return image
elif Path(image).exists():
return Image.open(image)
else:
raise ValueError("The image must be a data URI string.")
52 changes: 50 additions & 2 deletions argilla/src/argilla/records/_io/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Union
from uuid import uuid4

from datasets import Dataset as HFDataset
from datasets import IterableDataset
from datasets import IterableDataset, Image

from argilla.records._io._generic import GenericIO
from argilla._helpers._media import pil_to_data_uri

if TYPE_CHECKING:
from argilla.records import Record
Expand Down Expand Up @@ -58,11 +60,57 @@ def _record_dicts_from_datasets(dataset: HFDataset) -> List[Dict[str, Union[str,
Returns:
Generator[Dict[str, Union[str, float, int, list]], None, None]: A generator of dictionaries to be passed to DatasetRecords.add or DatasetRecords.update.
"""
record_dicts = []
media_features = HFDatasetsIO._get_image_features(dataset)
if media_features:
dataset = HFDatasetsIO._cast_images_as_urls(hf_dataset=dataset, columns=media_features)
try:
dataset: IterableDataset = dataset.to_iterable_dataset()
except AttributeError:
pass
record_dicts = []
for example in dataset:
record_dicts.append(example)
return record_dicts

@staticmethod
def _get_image_features(dataset: "HFDataset") -> List[str]:
"""Check if the Hugging Face dataset contains image features.
Parameters:
hf_dataset (HFDataset): The Hugging Face dataset to check.
Returns:
bool: True if the Hugging Face dataset contains image features, False otherwise.
"""
media_features = [name for name, feature in dataset.features.items() if isinstance(feature, Image)]
return media_features

@staticmethod
def _cast_images_as_urls(hf_dataset: "HFDataset", columns: List[str]) -> "HFDataset":
"""Cast the image features in the Hugging Face dataset as URLs.
Parameters:
hf_dataset (HFDataset): The Hugging Face dataset to cast.
repo_id (str): The ID of the Hugging Face Hub repo.
Returns:
HFDataset: The Hugging Face dataset with image features cast as URLs.
"""

unique_identifier = uuid4().hex

def batch_fn(batch):
data_uris = [pil_to_data_uri(sample) for sample in batch]
return {unique_identifier: data_uris}

for column in columns:
hf_dataset = hf_dataset.map(
function=batch_fn,
with_indices=False,
batched=True,
input_columns=[column],
remove_columns=[column],
)
hf_dataset = hf_dataset.rename_column(original_column_name=unique_identifier, new_column_name=column)

return hf_dataset
17 changes: 14 additions & 3 deletions argilla/src/argilla/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from uuid import UUID

from argilla._exceptions import ArgillaError
from argilla._helpers._media import cast_image, uncast_image
from argilla._models import (
FieldValue,
MetadataModel,
Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(

self._dataset = _dataset
self._model = RecordModel(external_id=id, id=_server_id)
self.__fields = RecordFields(fields=fields)
self.__fields = RecordFields(fields=fields, record=self)
self.__vectors = RecordVectors(vectors=vectors)
self.__metadata = RecordMetadata(metadata=metadata)
self.__responses = RecordResponses(responses=responses, record=self)
Expand Down Expand Up @@ -272,11 +273,21 @@ class RecordFields(dict):
It allows for accessing fields by attribute and key name.
"""

def __init__(self, fields: Optional[Dict[str, FieldValue]] = None) -> None:
def __init__(self, record: Record, fields: Optional[Dict[str, FieldValue]] = None) -> None:
super().__init__(fields or {})
self.record = record

def to_dict(self) -> dict:
return dict(self.items())
return {key: cast_image(value) if self._is_image(key) else value for key, value in self.items()}

def __getitem__(self, key: str) -> FieldValue:
value = super().__getitem__(key)
return uncast_image(value) if self._is_image(key) else value

def _is_image(self, key: str) -> bool:
if not self.record.dataset:
return False
return self.record.dataset.settings.schema[key].type == "image"


class RecordMetadata(dict):
Expand Down
86 changes: 86 additions & 0 deletions argilla/tests/unit/test_media.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tempfile import NamedTemporaryFile

import pytest
from PIL import Image
from argilla._helpers._media import cast_image, pil_to_data_uri, uncast_image


@pytest.fixture
def pil_image():
image = Image.new("RGB", (100, 100), color="red")
return image


@pytest.fixture
def data_uri_image(pil_image):
data_uri = pil_to_data_uri(pil_image)
return data_uri


@pytest.fixture
def path_to_image(pil_image):
with NamedTemporaryFile(suffix=".jpg") as f:
pil_image.save(f.name)
yield f.name


def test_cast_image_with_pil_image(pil_image):
result = cast_image(pil_image)
uncasted = uncast_image(result)

assert isinstance(result, str)
assert result.startswith("data:image")
assert "base64" in result

assert isinstance(uncasted, Image.Image)
assert uncasted.size == pil_image.size
assert uncasted.mode == pil_image.mode
assert uncasted.getcolors() == pil_image.getcolors()


def test_cast_image_with_file_path(path_to_image):
result = cast_image(path_to_image)
uncasted = uncast_image(result)
pil_image = Image.open(path_to_image)

assert isinstance(result, str)
assert result.startswith("data:image")
assert "base64" in result

assert isinstance(uncasted, Image.Image)
assert uncasted.size == pil_image.size
assert uncasted.mode == pil_image.mode
assert uncasted.getcolors() == pil_image.getcolors()


def test_cast_image_with_data_uri(data_uri_image):
result = cast_image(data_uri_image)
uncasted = uncast_image(result)

assert result == data_uri_image
assert isinstance(uncasted, Image.Image)


def test_cast_image_with_invalid_input():
invalid_input = 123
with pytest.raises(ValueError):
cast_image(invalid_input)


def test_uncast_image_with_url():
image_url = "https://example.com/image.jpg"
result = uncast_image(image_url)
assert result == image_url
Loading

0 comments on commit de6c3fa

Please sign in to comment.