Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add server dataset hub import #5591

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b2594b9
feat: first iteration of background job to import datasets from hub
jfcalvo Oct 10, 2024
6108ffe
feat: improve import_dataset_from_hub_job to get dataset before insta…
jfcalvo Oct 11, 2024
3a875ee
feat: improve HubDataset batch processing
jfcalvo Oct 11, 2024
b10a92e
feat: use UpsertRecordsBulk of CreateRecordsBulk for importing datase…
jfcalvo Oct 11, 2024
2b47522
feat: transform dataset importing value columns with PIL images to da…
jfcalvo Oct 11, 2024
7f93e0b
feat: add support to map suggestions importing datasets from hub
jfcalvo Oct 14, 2024
07aeec6
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
frascuchon Oct 14, 2024
83237a7
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
frascuchon Oct 14, 2024
a38fdda
feat: add support for hub dataset mapping
jfcalvo Oct 14, 2024
15d694f
feat: set metadata and suggestions as optional for HubDatasetMapping
jfcalvo Oct 14, 2024
c3752e8
feat: when no external_id is mapped row_idx is used
jfcalvo Oct 14, 2024
469ebc4
feat: use streaming when loading the dataset
jfcalvo Oct 15, 2024
b665019
feat: refactor UpsertRecordsBulk to validate records individually
jfcalvo Oct 15, 2024
a2fbc10
feat: ignore invalid records when importing datasets from hub
jfcalvo Oct 15, 2024
f3e33bd
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
jfcalvo Oct 15, 2024
fde2603
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
frascuchon Oct 16, 2024
3d1f04f
feat: add a fixed number of rows to take importing dataset from Hub (…
jfcalvo Oct 17, 2024
e84fb45
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
jfcalvo Oct 17, 2024
08ebe28
feat: add support for class labels and casting rows (#5601)
jfcalvo Oct 18, 2024
312551a
feat: improve `HubDataset` image processing support (#5606)
jfcalvo Oct 18, 2024
f064267
feat: add support to `-1` no label values for `ClassLabel` dataset fe…
jfcalvo Oct 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
386 changes: 210 additions & 176 deletions argilla-server/pdm.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion argilla-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ dependencies = [
"typer >= 0.6.0, < 0.10.0", # spaCy only supports typer<0.10.0
"packaging>=23.2",
"psycopg2-binary>=2.9.9",
# For HF dataset import
"datasets>=3.0.1",
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
"pillow>=10.4.0",
# For Telemetry
"huggingface_hub>=0.13,<1",
]
Expand Down Expand Up @@ -100,7 +103,6 @@ test = [
"factory-boy~=3.2.1",
"httpx>=0.26.0",
# Required by tests/unit/utils/test_dependency.py, but we should take a look and probably remove them
"datasets > 1.17.0,!= 2.3.2",
"spacy>=3.5.0,<3.7.0",
"pytest-randomly>=3.15.0",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DatasetProgress,
Datasets,
DatasetUpdate,
HubDataset,
UsersProgress,
)
from argilla_server.api.schemas.v1.fields import Field, FieldCreate, Fields
Expand All @@ -38,9 +39,11 @@
MetadataPropertyCreate,
)
from argilla_server.api.schemas.v1.vector_settings import VectorSettings, VectorSettingsCreate, VectorsSettings
from argilla_server.api.schemas.v1.jobs import Job as JobSchema
from argilla_server.contexts import datasets
from argilla_server.database import get_async_db
from argilla_server.enums import DatasetStatus
from argilla_server.jobs import hub_jobs
from argilla_server.models import Dataset, User
from argilla_server.search_engine import (
SearchEngine,
Expand Down Expand Up @@ -301,3 +304,26 @@
await authorize(current_user, DatasetPolicy.update(dataset))

return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))


# TODO: Maybe change /import to /import-from-hub?
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
@router.post("/datasets/{dataset_id}/import", status_code=status.HTTP_202_ACCEPTED, response_model=JobSchema)
async def import_dataset_from_hub(
    *,
    db: AsyncSession = Depends(get_async_db),
    dataset_id: UUID,
    hub_dataset: HubDataset,
    current_user: User = Security(auth.get_current_user),
):
    """Enqueue a job that imports a Hub dataset into an existing dataset.

    The target dataset is fetched with ``Dataset.get_or_raise`` and the current
    user is authorized against ``DatasetPolicy.import_from_hub`` before the job
    is enqueued through ``import_dataset_from_hub_job.delay``. The response only
    carries the id and the current status of that job.
    """
    target_dataset = await Dataset.get_or_raise(db, dataset_id)

    await authorize(current_user, DatasetPolicy.import_from_hub(target_dataset))

    job_arguments = {
        "name": hub_dataset.name,
        "subset": hub_dataset.subset,
        "split": hub_dataset.split,
        "dataset_id": target_dataset.id,
    }
    enqueued_job = hub_jobs.import_dataset_from_hub_job.delay(**job_arguments)

    return JobSchema(id=enqueued_job.id, status=enqueued_job.get_status())

Check warning on line 329 in argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py#L329

Added line #L329 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def create_dataset_records_bulk(
async def upsert_dataset_records_bulk(
*,
dataset_id: UUID,
records_bulk_create: RecordsBulkUpsert,
records_bulk_upsert: RecordsBulkUpsert,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
current_user: User = Security(auth.get_current_user),
Expand All @@ -86,7 +86,7 @@ async def upsert_dataset_records_bulk(

await authorize(current_user, DatasetPolicy.upsert_records(dataset))

records_bulk = await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_create)
records_bulk = await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_upsert)

updated = len(records_bulk.updated_item_ids)
created = len(records_bulk.items) - updated
Expand Down
52 changes: 52 additions & 0 deletions argilla-server/src/argilla_server/api/handlers/v1/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fastapi import APIRouter, Depends, HTTPException, Security, status
from sqlalchemy.ext.asyncio import AsyncSession

from rq.job import Job
from rq.exceptions import NoSuchJobError

from argilla_server.database import get_async_db
from argilla_server.jobs.queues import REDIS_CONNECTION
from argilla_server.models import User
from argilla_server.api.policies.v1 import JobPolicy, authorize
from argilla_server.api.schemas.v1.jobs import Job as JobSchema
from argilla_server.security import auth

router = APIRouter(tags=["jobs"])


def _get_job(job_id: str) -> Job:
    """Fetch the job identified by ``job_id`` using ``REDIS_CONNECTION``.

    Raises:
        HTTPException: with status 404 when ``Job.fetch`` raises ``NoSuchJobError``.
    """
    try:
        fetched_job = Job.fetch(job_id, connection=REDIS_CONNECTION)
    except NoSuchJobError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Job with id `{job_id}` not found",
        )
    else:
        return fetched_job


@router.get("/jobs/{job_id}", response_model=JobSchema)
async def get_job(
    *,
    db: AsyncSession = Depends(get_async_db),
    job_id: str,
    current_user: User = Security(auth.get_current_user),
):
    """Return the id and refreshed status of the job identified by ``job_id``.

    The job is looked up first (a missing job results in a 404 from ``_get_job``)
    and only then is the current user authorized against ``JobPolicy.get``.
    """
    found_job = _get_job(job_id)

    await authorize(current_user, JobPolicy.get)

    # `refresh=True` is passed so the status is re-read rather than taken from the fetched object.
    current_status = found_job.get_status(refresh=True)

    return JobSchema(id=found_job.id, status=current_status)

Check warning on line 52 in argilla-server/src/argilla_server/api/handlers/v1/jobs.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/api/handlers/v1/jobs.py#L52

Added line #L52 was not covered by tests
2 changes: 2 additions & 0 deletions argilla-server/src/argilla_server/api/policies/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from argilla_server.api.policies.v1.vector_settings_policy import VectorSettingsPolicy
from argilla_server.api.policies.v1.workspace_policy import WorkspacePolicy
from argilla_server.api.policies.v1.workspace_user_policy import WorkspaceUserPolicy
from argilla_server.api.policies.v1.job_policy import JobPolicy

__all__ = [
"DatasetPolicy",
Expand All @@ -37,6 +38,7 @@
"VectorSettingsPolicy",
"WorkspacePolicy",
"WorkspaceUserPolicy",
"JobPolicy",
"authorize",
"is_authorized",
]
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,10 @@
return actor.is_owner or (actor.is_admin and await actor.is_member(dataset.workspace_id))

return is_allowed

@classmethod
def import_from_hub(cls, dataset: Dataset) -> PolicyAction:
    """Build the policy action deciding who may import Hub data into ``dataset``.

    The returned coroutine function allows owners, and admins that are members
    of the workspace the dataset belongs to.
    """

    async def is_allowed(actor: User) -> bool:
        owner = actor.is_owner
        if owner:
            return owner

        # Membership is only checked for admins, exactly as in the short-circuit expression.
        return actor.is_admin and await actor.is_member(dataset.workspace_id)

    return is_allowed

Check warning on line 150 in argilla-server/src/argilla_server/api/policies/v1/dataset_policy.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/api/policies/v1/dataset_policy.py#L150

Added line #L150 was not covered by tests
21 changes: 21 additions & 0 deletions argilla-server/src/argilla_server/api/policies/v1/job_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla_server.models import User


class JobPolicy:
    """Authorization rules for job endpoints."""

    @classmethod
    async def get(cls, actor: User) -> bool:
        """Allow reading a job when ``actor`` is an owner or an admin."""
        owner = actor.is_owner
        if owner:
            return owner

        return actor.is_admin

Check warning on line 21 in argilla-server/src/argilla_server/api/policies/v1/job_policy.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/api/policies/v1/job_policy.py#L21

Added line #L21 was not covered by tests
2 changes: 2 additions & 0 deletions argilla-server/src/argilla_server/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from argilla_server.api.handlers.v1 import (
workspaces as workspaces_v1,
)
from argilla_server.api.handlers.v1 import jobs as jobs_v1
from argilla_server.errors.base_errors import __ALL__
from argilla_server.errors.error_handler import APIErrorHandler

Expand Down Expand Up @@ -92,6 +93,7 @@ def create_api_v1():
users_v1.router,
vectors_settings_v1.router,
workspaces_v1.router,
jobs_v1.router,
oauth2_v1.router,
settings_v1.router,
]:
Expand Down
6 changes: 6 additions & 0 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,9 @@ class DatasetUpdate(UpdateSchema):
distribution: Optional[DatasetDistributionUpdate]

__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}


class HubDataset(BaseModel):
    """Request body describing which Hub dataset to import.

    The import handler forwards these three values unchanged to the import job.
    """

    # Name of the dataset on the Hub.
    name: str
    # Subset of the dataset to import.
    subset: str
    # Split of the dataset to import.
    split: str
21 changes: 21 additions & 0 deletions argilla-server/src/argilla_server/api/schemas/v1/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rq.job import JobStatus
from pydantic import BaseModel


class Job(BaseModel):
    """API representation of a background job."""

    # Identifier of the job.
    id: str
    # Status of the job, expressed with rq's `JobStatus` type.
    status: JobStatus
142 changes: 142 additions & 0 deletions argilla-server/src/argilla_server/contexts/hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import base64

from typing import Union
from typing_extensions import Self

from PIL import Image
from datasets import load_dataset
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.models.database import Dataset
from argilla_server.search_engine import SearchEngine
from argilla_server.bulk.records_bulk import UpsertRecordsBulk
from argilla_server.api.schemas.v1.records import RecordUpsert as RecordUpsertSchema
from argilla_server.api.schemas.v1.records_bulk import RecordsBulkUpsert as RecordsBulkUpsertSchema
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate

BATCH_SIZE = 100


class HubDataset:
    """Dataset loaded with ``datasets.load_dataset`` that can be imported into a ``Dataset``.

    Rows are read in batches of ``BATCH_SIZE`` and upserted as records. Each
    batch is handled as a mapping of column name to a list of values, one per row.
    """

    def __init__(self, name: str, subset: str, split: str):
        # `load_dataset` receives the dataset name as `path` and the subset as `name`.
        self.dataset = load_dataset(path=name, name=subset, split=split)
        self.iterable_dataset = self.dataset.to_iterable_dataset()

    @property
    def num_rows(self) -> int:
        """Number of rows reported by the loaded dataset."""
        return self.dataset.num_rows

    def take(self, n: int) -> Self:
        """Limit the rows that will be imported to the first ``n`` and return ``self``."""
        self.iterable_dataset = self.iterable_dataset.take(n)

        return self

    async def import_to(self, db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> None:
        """Upsert every row of the Hub dataset into ``dataset``, one batch at a time.

        Raises:
            Exception: if ``dataset.is_ready`` is false.
        """
        if not dataset.is_ready:
            raise Exception("it's not possible to import records to a non published dataset")

        for batch in self.iterable_dataset.batch(batch_size=BATCH_SIZE):
            await self._import_batch_to(db, search_engine, batch, dataset)

    async def _import_batch_to(
        self, db: AsyncSession, search_engine: SearchEngine, batch: dict, dataset: Dataset
    ) -> None:
        """Convert one batch into record schemas and upsert them in bulk."""
        # Every column holds one value per row, so any column gives the number of rows.
        batch_size = len(next(iter(batch.values())))

        items = [self._batch_row_to_record_schema(batch, index, dataset) for index in range(batch_size)]

        await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, RecordsBulkUpsertSchema(items=items))

    def _batch_row_to_record_schema(self, batch: dict, index: int, dataset: Dataset) -> RecordUpsertSchema:
        """Build the upsert schema for the row at ``index`` of ``batch``."""
        return RecordUpsertSchema(
            id=None,
            external_id=self._batch_row_external_id(batch, index),
            fields=self._batch_row_fields(batch, index, dataset),
            metadata=self._batch_row_metadata(batch, index, dataset),
            suggestions=self._batch_row_suggestions(batch, index, dataset),
            responses=None,
            vectors=None,
        )

    # NOTE: if there is a value with key "id" in the batch, we will use it as external_id
    def _batch_row_external_id(self, batch: dict, index: int) -> Union[str, None]:
        """Return the row's ``"id"`` value, or ``None`` when the batch has no ``"id"`` column."""
        if "id" not in batch:
            return None

        return batch["id"][index]

    def _batch_row_fields(self, batch: dict, index: int, dataset: Dataset) -> dict:
        """Collect the value of every dataset field for the row at ``index``.

        Text field values are converted with ``str`` and PIL images of image
        fields are converted to data URLs. A field without a matching column
        raises ``KeyError``.
        """
        fields = {}
        for field in dataset.fields:
            value = batch[field.name][index]

            if field.is_text:
                value = str(value)

            if field.is_image and isinstance(value, Image.Image):
                value = pil_image_to_data_url(value)

            fields[field.name] = value

        return fields

    def _batch_row_metadata(self, batch: dict, index: int, dataset: Dataset) -> dict:
        """Collect the metadata values available for the row at ``index``.

        Metadata properties without a matching column in the batch are skipped,
        in the same way suggestions are, instead of failing with ``KeyError``.
        """
        return {
            metadata_property.name: batch[metadata_property.name][index]
            for metadata_property in dataset.metadata_properties
            if metadata_property.name in batch
        }

    def _batch_row_suggestions(self, batch: dict, index: int, dataset: Dataset) -> list:
        """Build one suggestion per question that has a matching column in the batch.

        Values for text and label selection questions are converted with ``str``
        and values for rating questions with ``int``.
        """
        suggestions = []
        for question in dataset.questions:
            if question.name not in batch:
                continue

            value = batch[question.name][index]

            if question.is_text or question.is_label_selection:
                value = str(value)

            if question.is_rating:
                value = int(value)

            suggestions.append(
                SuggestionCreate(
                    question_id=question.id,
                    value=value,
                    type=None,
                    agent=None,
                    score=None,
                ),
            )

        return suggestions


def pil_image_to_data_url(image: Image.Image):
    """Encode a PIL image as a base64 ``data:`` URL.

    Images that were not loaded from a file (created or transformed in memory)
    have ``format`` set to ``None`` and may not provide ``get_format_mimetype()``.
    Those are encoded as PNG instead of failing. Images that carry a format
    keep being saved with it.

    Returns:
        A string of the form ``data:<mimetype>;base64,<payload>``.
    """
    # Saving to an in-memory buffer needs an explicit format; there is no file name to infer it from.
    image_format = image.format or "PNG"

    buffer = io.BytesIO()

    image.save(buffer, format=image_format)

    base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

    # Prefer the image's own MIME type when it offers one.
    get_mimetype = getattr(image, "get_format_mimetype", None)
    mimetype = get_mimetype() if callable(get_mimetype) else None
    if mimetype is None:
        # Looked up after `save` so the plugin for `image_format` has been loaded into the registry.
        mimetype = Image.MIME.get(image_format.upper(), f"image/{image_format.lower()}")

    return f"data:{mimetype};base64,{base64_image}"
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/jobs/dataset_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


@job(DEFAULT_QUEUE, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3))
async def update_dataset_records_status_job(dataset_id: UUID):
async def update_dataset_records_status_job(dataset_id: UUID) -> None:
"""This Job updates the status of all the records in the dataset when the distribution strategy changes."""

record_ids = []
Expand Down
Loading