From 7982a9287952e7a2e23a4899027eead7595e6106 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 18 Sep 2024 13:46:37 +0200 Subject: [PATCH] [FEATURE] add different error handling strategies to the log method (#5463) (#5510) https://github.com/argilla-io/argilla/pull/5463 --- .../src/argilla/records/_dataset_records.py | 24 +++++++- argilla/tests/unit/test_record_ingestion.py | 55 ++++++++++++++++++- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index 01d2154120..3e36688902 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union from uuid import UUID +from enum import Enum from tqdm import tqdm @@ -32,6 +33,12 @@ from argilla.datasets import Dataset +class RecordErrorHandling(Enum): + RAISE = "raise" + WARN = "warn" + IGNORE = "ignore" + + class DatasetRecordsIterator: """This class is used to iterate over records in a dataset""" @@ -210,6 +217,7 @@ def log( mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None, user_id: Optional[UUID] = None, batch_size: int = DEFAULT_BATCH_SIZE, + on_error: RecordErrorHandling = RecordErrorHandling.RAISE, ) -> "DatasetRecords": """Add or update records in a dataset on the server using the provided records. If the record includes a known `id` field, the record will be updated. @@ -228,7 +236,9 @@ def log( Returns: A list of Record objects representing the updated records. 
""" - record_models = self._ingest_records(records=records, mapping=mapping, user_id=user_id or self.__client.me.id) + record_models = self._ingest_records( + records=records, mapping=mapping, user_id=user_id or self.__client.me.id, on_error=on_error + ) batch_size = self._normalize_batch_size( batch_size=batch_size, records_length=len(record_models), @@ -380,6 +390,7 @@ def _ingest_records( records: Union[List[Dict[str, Any]], List[Record], HFDataset], mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None, user_id: Optional[UUID] = None, + on_error: RecordErrorHandling = RecordErrorHandling.RAISE, ) -> List[RecordModel]: """Ingests records from a list of dictionaries, a Hugging Face Dataset, or a list of Record objects.""" @@ -405,7 +416,16 @@ def _ingest_records( f"Found a record of type {type(record)}: {record}." ) except Exception as e: - raise RecordsIngestionError(f"Failed to ingest record from dict {record}: {e}") + if on_error == RecordErrorHandling.IGNORE: + self._log_message( + message=f"Failed to ingest record from dict {record}: {e}", + level="info", + ) + continue + elif on_error == RecordErrorHandling.WARN: + warnings.warn(f"Failed to ingest record from dict {record}: {e}") + continue + raise RecordsIngestionError(f"Failed to ingest record from dict {record}") from e ingested_records.append(record.api_model()) return ingested_records diff --git a/argilla/tests/unit/test_record_ingestion.py b/argilla/tests/unit/test_record_ingestion.py index 1549a40d67..6d83f80e64 100644 --- a/argilla/tests/unit/test_record_ingestion.py +++ b/argilla/tests/unit/test_record_ingestion.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import warnings from uuid import uuid4 import pytest import argilla as rg from argilla._exceptions import RecordsIngestionError +from argilla.records._dataset_records import RecordErrorHandling @pytest.fixture @@ -207,3 +208,55 @@ def test_ingest_record_from_dict_with_mapping_multiple(): assert record.fields["prompt_field"] == "What is the capital of France?" assert "positive" in suggestions assert "What is the capital of France?" in suggestions + + +def test_ingest_records_on_error_raise(dataset): + with pytest.raises(RecordsIngestionError): + dataset.records._ingest_records( + records=[ + {"prompt": "Valid record"}, + {"invalid_field": "This should raise an error"}, + ], + on_error=RecordErrorHandling.RAISE, + ) + + +def test_ingest_records_on_error_warn(dataset): + with warnings.catch_warnings(record=True) as caught_warnings: + # Cause all warnings to always be triggered + warnings.simplefilter("always") + + records = dataset.records._ingest_records( + records=[ + {"prompt": "Valid record"}, + {"invalid_field": "This should warn"}, + ], + on_error=RecordErrorHandling.WARN, + ) + + # Check that exactly two warnings were recorded + assert len(caught_warnings) == 2 + # Check that the message matches + caught_warning_messages = [str(w.message) for w in caught_warnings] + assert any("Failed to ingest record" in message for message in caught_warning_messages) + assert any("invalid_field" in message for message in caught_warning_messages) + + # Check that only the valid record was ingested + assert len(records) == 1 + assert records[0].fields["prompt"] == "Valid record" + + +def test_ingest_records_on_error_ignore(dataset, caplog): + records = dataset.records._ingest_records( + records=[ + {"prompt": "Valid record 1"}, + {"invalid_field": "This should be ignored"}, + {"prompt": "Valid record 2"}, + ], + on_error=RecordErrorHandling.IGNORE, + ) + + assert len(records) == 2 + assert records[0].fields["prompt"] == "Valid record 1" + assert records[1].fields["prompt"] == "Valid record 
2" + assert "Failed to ingest record" not in caplog.text