[FEATURE] from_hub raise on existing dataset name (#5358)
# Description

This PR changes the behaviour of the `from_disk` method when a dataset
of the same name already exists. Currently, a new dataset is created with
the name plus a UUID suffix. With this change, the method will:

- check if the dataset name already exists, and if so
- warn that the dataset exists and that the `name` parameter can be
used to create a new one
- try to push the records to the existing dataset, wrapped in a
try/except that adds more context to any ingestion error

Closes #5346 
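For illustration, this is roughly how the new behaviour looks from the caller's side. This is a sketch, not part of the diff: the server URL, API key, `repo_id`, and dataset names are placeholders.

```python
import argilla as rg

client = rg.Argilla(api_url="http://localhost:6900", api_key="argilla.apikey")

# First import: the dataset does not exist yet, so it is created as before.
dataset = rg.Dataset.from_hub(repo_id="my-org/my-dataset", client=client)

# Second import of the same repo: instead of silently creating
# "my-dataset_<uuid>", a warning is emitted and the records are pushed
# to the already-existing dataset.
dataset = rg.Dataset.from_hub(repo_id="my-org/my-dataset", client=client)

# To get a genuinely new dataset, pass a unique name explicitly.
dataset_v2 = rg.Dataset.from_hub(
    repo_id="my-org/my-dataset", client=client, name="my-dataset-v2"
)
```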

**Type of change**
- Improvement (change adding some improvement to an existing functionality)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
burtenshaw and pre-commit-ci[bot] authored Aug 1, 2024
1 parent 5c12e32 commit 08d6ed5
Showing 2 changed files with 43 additions and 21 deletions.
47 changes: 29 additions & 18 deletions argilla/src/argilla/datasets/_export/_disk.py
```diff
@@ -19,8 +19,8 @@
 from abc import ABC
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional, Tuple, Type, Union
-from uuid import uuid4

+from argilla._exceptions import RecordsIngestionError, ArgillaError
 from argilla._models import DatasetModel
 from argilla.client import Argilla
 from argilla.settings import Settings
@@ -90,28 +90,39 @@ def from_disk(

         # Get the relevant workspace_id of the incoming dataset
         if isinstance(workspace, str):
-            workspace_id = client.workspaces(workspace).id
-        elif isinstance(workspace, Workspace):
-            workspace_id = workspace.id
+            workspace = client.workspaces(workspace)
+            if not workspace:
+                raise ArgillaError(f"Workspace {workspace} not found on the server.")
         else:
             warnings.warn("Workspace not provided. Using default workspace.")
-            workspace_id = client.workspaces.default.id
-        dataset_model.workspace_id = workspace_id
+            workspace = client.workspaces.default
+        dataset_model.workspace_id = workspace.id

         # Get a relevant and unique name for the incoming dataset.
-        if name:
-            logging.warning(f"Changing dataset name from {dataset_model.name} to {name}")
+        if name and (name != dataset_model.name):
+            logging.info(f"Changing dataset name from {dataset_model.name} to {name}")
             dataset_model.name = name
-        elif client.api.datasets.name_exists(name=dataset_model.name, workspace_id=workspace_id):
-            logging.warning(f"Loaded dataset name {dataset_model.name} already exists. Changing to unique UUID.")
-            dataset_model.name = f"{dataset_model.name}_{uuid4()}"

-        # Create the dataset and load the settings and records
-        dataset = cls.from_model(model=dataset_model, client=client)
-        dataset.settings = Settings.from_json(path=settings_path)
-        dataset.create()

+        if client.api.datasets.name_exists(name=dataset_model.name, workspace_id=workspace.id):
+            warnings.warn(
+                f"Loaded dataset name {dataset_model.name} already exists in the workspace {workspace.name} so using it. To create a new dataset, provide a unique name to the `name` parameter."
+            )
+            dataset_model = client.api.datasets.get_by_name_and_workspace_id(
+                name=dataset_model.name, workspace_id=workspace.id
+            )
+            dataset = cls.from_model(model=dataset_model, client=client)
+        else:
+            # Create a new dataset and load the settings and records
+            dataset = cls.from_model(model=dataset_model, client=client)
+            dataset.settings = Settings.from_json(path=settings_path)
+            dataset.create()

         if os.path.exists(records_path) and with_records:
-            dataset.records.from_json(path=records_path)
+            try:
+                dataset.records.from_json(path=records_path)
+            except RecordsIngestionError as e:
+                raise RecordsIngestionError(
+                    message="Error importing dataset records from disk. Records and datasets settings are not compatible."
+                ) from e
         return dataset

     ############################
```
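To make the new control flow concrete, here is a minimal sketch of importing the same export twice with `from_disk`. This is an illustration under assumptions, not part of the commit: the export paths are hypothetical and a local Argilla server is assumed to be running.

```python
import warnings

import argilla as rg
from argilla._exceptions import RecordsIngestionError

client = rg.Argilla(api_url="http://localhost:6900", api_key="argilla.apikey")

# First call hits the `else` branch: the dataset is created from the
# stored settings and the records are loaded.
dataset = rg.Dataset.from_disk("exports/my-dataset", client=client)

# Second call: `name_exists` is now true, so `from_disk` fetches the
# existing dataset model, warns, and ingests the records into it.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    same_dataset = rg.Dataset.from_disk("exports/my-dataset", client=client)
assert any("already exists" in str(w.message) for w in caught)

# If the records on disk are incompatible with the existing dataset's
# settings, the ingestion error is re-raised with added context.
try:
    rg.Dataset.from_disk("exports/incompatible-export", client=client)
except RecordsIngestionError as e:
    print(e)  # "Error importing dataset records from disk. ..."
```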
17 changes: 14 additions & 3 deletions argilla/tests/integration/test_export_dataset.py
```diff
@@ -22,6 +22,7 @@

 import argilla as rg
 import pytest
+from argilla._exceptions import ConflictError
 from huggingface_hub.utils._errors import BadRequestError, FileMetadataError, HfHubHTTPError

 _RETRIES = 5
@@ -125,7 +126,9 @@ def test_import_dataset_from_disk(

         with TemporaryDirectory() as temp_dir:
             output_dir = dataset.to_disk(path=temp_dir, with_records=with_records_export)
-            new_dataset = rg.Dataset.from_disk(output_dir, client=client, with_records=with_records_import)
+            new_dataset = rg.Dataset.from_disk(
+                output_dir, client=client, with_records=with_records_import, name=f"test_{uuid.uuid4()}"
+            )

         if with_records_export and with_records_import:
             for i, record in enumerate(new_dataset.records(with_suggestions=True)):
@@ -175,11 +178,19 @@ def test_import_dataset_from_hub(
                 match="Trying to load a dataset `with_records=True` but dataset does not contain any records.",
             ):
                 new_dataset = rg.Dataset.from_hub(
-                    repo_id=repo_id, client=client, with_records=with_records_import, token=token
+                    repo_id=repo_id,
+                    client=client,
+                    with_records=with_records_import,
+                    token=token,
+                    name=f"test_{uuid.uuid4()}",
                 )
         else:
             new_dataset = rg.Dataset.from_hub(
-                repo_id=repo_id, client=client, with_records=with_records_import, token=token
+                repo_id=repo_id,
+                client=client,
+                with_records=with_records_import,
+                token=token,
+                name=f"test_{uuid.uuid4()}",
             )

     if with_records_import and with_records_export:
```
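The tests now pass a UUID-suffixed `name` so repeated runs do not collide with datasets left over on the server. A complementary check, sketched here rather than taken from the commit (the `client` and `dataset` fixtures are assumptions, modelled on the existing integration tests), would assert that the reuse path actually warns:

```python
import uuid

import argilla as rg
import pytest


def test_import_existing_dataset_warns(client, dataset, tmp_path):
    # Export once, then import the same export twice under the same name:
    # the second import should warn instead of creating "<name>_<uuid>".
    output_dir = dataset.to_disk(path=str(tmp_path))
    name = f"test_{uuid.uuid4()}"
    rg.Dataset.from_disk(output_dir, client=client, name=name)
    with pytest.warns(UserWarning, match="already exists in the workspace"):
        rg.Dataset.from_disk(output_dir, client=client, name=name)
```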
