Skip to content

Commit

Permalink
[FEATURE] add different error handling strategies to the log method (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
burtenshaw authored Sep 18, 2024
1 parent 5c1a5fc commit 7982a92
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
24 changes: 22 additions & 2 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union
from uuid import UUID
from enum import Enum

from tqdm import tqdm

Expand All @@ -32,6 +33,12 @@
from argilla.datasets import Dataset


class RecordErrorHandling(str, Enum):
    """Strategies for handling a record that fails to be ingested.

    Members also subclass ``str`` so that a plain string such as ``"warn"``
    compares equal to the matching member. Without the mixin, a caller
    passing the string value would match none of the members in an ``==``
    comparison.
    """

    # Abort ingestion by raising an error for the failing record.
    RAISE = "raise"
    # Skip the failing record and report it with `warnings.warn`.
    WARN = "warn"
    # Skip the failing record, reporting it only through an info-level log message.
    IGNORE = "ignore"


class DatasetRecordsIterator:
"""This class is used to iterate over records in a dataset"""

Expand Down Expand Up @@ -210,6 +217,7 @@ def log(
mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
user_id: Optional[UUID] = None,
batch_size: int = DEFAULT_BATCH_SIZE,
on_error: RecordErrorHandling = RecordErrorHandling.RAISE,
) -> "DatasetRecords":
"""Add or update records in a dataset on the server using the provided records.
If the record includes a known `id` field, the record will be updated.
Expand All @@ -228,7 +236,9 @@ def log(
Returns:
A list of Record objects representing the updated records.
"""
record_models = self._ingest_records(records=records, mapping=mapping, user_id=user_id or self.__client.me.id)
record_models = self._ingest_records(
records=records, mapping=mapping, user_id=user_id or self.__client.me.id, on_error=on_error
)
batch_size = self._normalize_batch_size(
batch_size=batch_size,
records_length=len(record_models),
Expand Down Expand Up @@ -380,6 +390,7 @@ def _ingest_records(
records: Union[List[Dict[str, Any]], List[Record], HFDataset],
mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
user_id: Optional[UUID] = None,
on_error: RecordErrorHandling = RecordErrorHandling.RAISE,
) -> List[RecordModel]:
"""Ingests records from a list of dictionaries, a Hugging Face Dataset, or a list of Record objects."""

Expand All @@ -405,7 +416,16 @@ def _ingest_records(
f"Found a record of type {type(record)}: {record}."
)
except Exception as e:
raise RecordsIngestionError(f"Failed to ingest record from dict {record}: {e}")
if on_error == RecordErrorHandling.IGNORE:
self._log_message(
message=f"Failed to ingest record from dict {record}: {e}",
level="info",
)
continue
elif on_error == RecordErrorHandling.WARN:
warnings.warn(f"Failed to ingest record from dict {record}: {e}")
continue
raise RecordsIngestionError(f"Failed to ingest record from dict {record}") from e
ingested_records.append(record.api_model())
return ingested_records

Expand Down
55 changes: 54 additions & 1 deletion argilla/tests/unit/test_record_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from uuid import uuid4

import pytest

import argilla as rg
from argilla._exceptions import RecordsIngestionError
from argilla.records._dataset_records import RecordErrorHandling


@pytest.fixture
Expand Down Expand Up @@ -207,3 +208,55 @@ def test_ingest_record_from_dict_with_mapping_multiple():
assert record.fields["prompt_field"] == "What is the capital of France?"
assert "positive" in suggestions
assert "What is the capital of France?" in suggestions


def test_ingest_records_on_error_raise(dataset):
    """With the RAISE strategy, an invalid record makes ingestion raise RecordsIngestionError."""
    payload = [
        {"prompt": "Valid record"},
        {"invalid_field": "This should raise an error"},
    ]

    with pytest.raises(RecordsIngestionError):
        dataset.records._ingest_records(records=payload, on_error=RecordErrorHandling.RAISE)


def test_ingest_records_on_error_warn(dataset):
    """With the WARN strategy, an invalid record is skipped and reported through a warning."""
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Make sure no active filter hides or de-duplicates the warnings under test.
        warnings.simplefilter("always")

        records = dataset.records._ingest_records(
            records=[
                {"prompt": "Valid record"},
                {"invalid_field": "This should warn"},
            ],
            on_error=RecordErrorHandling.WARN,
        )

    # Unrelated warnings may be captured as well, so only the ingestion-failure
    # warnings are counted instead of asserting on the total number captured.
    ingestion_warnings = [
        str(w.message) for w in caught_warnings if "Failed to ingest record" in str(w.message)
    ]

    # Exactly one record is invalid, so exactly one ingestion warning is expected,
    # and that same warning must identify the offending record.
    assert len(ingestion_warnings) == 1
    assert "invalid_field" in ingestion_warnings[0]

    # Check that only the valid record was ingested
    assert len(records) == 1
    assert records[0].fields["prompt"] == "Valid record"


def test_ingest_records_on_error_ignore(dataset, caplog):
    """With the IGNORE strategy, invalid records are dropped and valid ones keep their order."""
    payload = [
        {"prompt": "Valid record 1"},
        {"invalid_field": "This should be ignored"},
        {"prompt": "Valid record 2"},
    ]

    ingested = dataset.records._ingest_records(records=payload, on_error=RecordErrorHandling.IGNORE)

    # Only the two valid records survive, in their original order.
    prompts = [model.fields["prompt"] for model in ingested]
    assert prompts == ["Valid record 1", "Valid record 2"]
    # The failure message must not appear in the captured log text.
    assert "Failed to ingest record" not in caplog.text

0 comments on commit 7982a92

Please sign in to comment.