feat: add server metadata support for datasets
jfcalvo committed Oct 8, 2024
1 parent 1427abb commit da43321
Showing 7 changed files with 202 additions and 29 deletions.
@@ -0,0 +1,39 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add metadata column to datasets table
Revision ID: 660d6c6b3360
Revises: 237f7c674d74
Create Date: 2024-10-04 16:47:21.611404
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "660d6c6b3360"
down_revision = "237f7c674d74"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column("datasets", sa.Column("metadata", sa.JSON(), server_default="{}", nullable=False))


def downgrade() -> None:
op.drop_column("datasets", "metadata")
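For reference, a minimal verification sketch (not part of the commit): after running `alembic upgrade head`, the `server_default="{}"` means dataset rows that existed before the migration also satisfy the new non-nullable column. The database URL below is an assumption.

import sqlalchemy as sa

# Assumed local database URL; adjust to the deployment's connection string.
engine = sa.create_engine("sqlite:///argilla.db")

with engine.connect() as conn:
    row = conn.execute(sa.text("SELECT metadata FROM datasets LIMIT 1")).first()
    # For a dataset created before the migration, the server default applies,
    # so this prints ('{}',) rather than (None,).
    print(row)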
17 changes: 15 additions & 2 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
@@ -13,12 +13,13 @@
# limitations under the License.

from datetime import datetime
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional, Union, Dict, Any
from uuid import UUID

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.pydantic_v1 import BaseModel, Field, constr
from argilla_server.pydantic_v1.utils import GetterDict

try:
from typing import Annotated
@@ -102,20 +103,30 @@ class UsersProgress(BaseModel):
users: List[UserProgress]


class DatasetGetterDict(GetterDict):
def get(self, key: str, default: Any) -> Any:
if key == "metadata":
return getattr(self._obj, "metadata_", None)

return super().get(key, default)


class Dataset(BaseModel):
id: UUID
name: str
guidelines: Optional[str]
allow_extra_metadata: bool
status: DatasetStatus
distribution: DatasetDistribution
metadata: Dict[str, Any]
workspace_id: UUID
last_activity_at: datetime
inserted_at: datetime
updated_at: datetime

class Config:
orm_mode = True
getter_dict = DatasetGetterDict


class Datasets(BaseModel):
@@ -130,6 +141,7 @@ class DatasetCreate(BaseModel):
strategy=DatasetDistributionStrategy.overlap,
min_submitted=1,
)
metadata: Dict[str, Any] = {}
workspace_id: UUID


@@ -138,5 +150,6 @@ class DatasetUpdate(UpdateSchema):
guidelines: Optional[DatasetGuidelines]
allow_extra_metadata: Optional[bool]
distribution: Optional[DatasetDistributionUpdate]
metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")

__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution", "metadata"}
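The new field also explains the `metadata_`/`metadata` split: SQLAlchemy reserves the `metadata` attribute on declarative models, so the model stores the dict as `metadata_`, and the schema layer aliases it back in both directions — `DatasetGetterDict` when serializing a model, `Field(None, alias="metadata")` when parsing an update payload. A standalone sketch of the serialization side, using a plain pydantic v1 import (the commit itself imports `GetterDict` via `argilla_server.pydantic_v1.utils`); the fake model and schema names here are illustrative only:

from typing import Any, Dict

from pydantic import BaseModel          # assumed plain pydantic v1
from pydantic.utils import GetterDict   # re-exported in the commit via argilla_server.pydantic_v1.utils


class FakeDatasetModel:
    # Stands in for the SQLAlchemy model, which exposes the dict as `metadata_`.
    metadata_ = {"source": "demo"}


class DatasetGetterDict(GetterDict):
    def get(self, key: str, default: Any) -> Any:
        if key == "metadata":
            return getattr(self._obj, "metadata_", None)
        return super().get(key, default)


class DatasetSchema(BaseModel):
    metadata: Dict[str, Any]

    class Config:
        orm_mode = True
        getter_dict = DatasetGetterDict


print(DatasetSchema.from_orm(FakeDatasetModel()).metadata)  # {'source': 'demo'}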
3 changes: 2 additions & 1 deletion argilla-server/src/argilla_server/contexts/datasets.py
@@ -131,12 +131,13 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) ->
return result.scalars().all()


async def create_dataset(db: AsyncSession, dataset_attrs: dict):
async def create_dataset(db: AsyncSession, dataset_attrs: dict) -> Dataset:
dataset = Dataset(
name=dataset_attrs["name"],
guidelines=dataset_attrs["guidelines"],
allow_extra_metadata=dataset_attrs["allow_extra_metadata"],
distribution=dataset_attrs["distribution"],
metadata_=dataset_attrs["metadata"],
workspace_id=dataset_attrs["workspace_id"],
)

1 change: 1 addition & 0 deletions argilla-server/src/argilla_server/models/database.py
@@ -344,6 +344,7 @@ class Dataset(DatabaseModel):
allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true())
status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True)
distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON))
metadata_: Mapped[dict] = mapped_column("metadata", JSON, default={}, server_default="{}")
workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True)
inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow)
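The same reserved-name constraint shows up on the model: the first positional argument to `mapped_column` keeps the SQL column named "metadata" while the Python attribute is `metadata_`, with `default={}` covering ORM inserts and `server_default="{}"` covering rows written outside the ORM (including existing rows during the migration). A self-contained sketch of the pattern, assuming SQLAlchemy 2.0-style declarative mapping and an in-memory SQLite database; the demo class is hypothetical:

from sqlalchemy import JSON, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DemoDataset(Base):
    __tablename__ = "demo_datasets"

    id: Mapped[int] = mapped_column(primary_key=True)
    # The column is named "metadata" in SQL; `metadata` itself is reserved on
    # declarative classes (Base.metadata), hence the trailing underscore.
    metadata_: Mapped[dict] = mapped_column("metadata", JSON, default={}, server_default="{}")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DemoDataset())
    session.commit()
    print(session.scalar(select(DemoDataset.metadata_)))  # {} from the Python-side default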
@@ -13,12 +13,15 @@
# limitations under the License.

import pytest
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.models import Dataset

from typing import Any
from httpx import AsyncClient
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.models import Dataset

from tests.factories import WorkspaceFactory


@@ -54,6 +57,7 @@ async def test_create_dataset_with_default_distribution(
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": {},
"workspace_id": str(workspace.id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
@@ -74,6 +78,7 @@ async def test_create_dataset_with_overlap_distribution(
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 4,
},
"metadata": {},
"workspace_id": str(workspace.id),
},
)
@@ -91,6 +96,7 @@ async def test_create_dataset_with_overlap_distribution(
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 4,
},
"metadata": {},
"workspace_id": str(workspace.id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
@@ -137,3 +143,63 @@ async def test_create_dataset_with_invalid_distribution_strategy(

assert response.status_code == 422
assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0

async def test_create_dataset_with_default_metadata(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
workspace = await WorkspaceFactory.create()

response = await async_client.post(
self.url(),
headers=owner_auth_header,
json={
"name": "Dataset Name",
"workspace_id": str(workspace.id),
},
)

assert response.status_code == 201
assert response.json()["metadata"] == {}

dataset = (await db.execute(select(Dataset))).scalar_one()
assert dataset.metadata_ == {}

async def test_create_dataset_with_custom_metadata(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
workspace = await WorkspaceFactory.create()

response = await async_client.post(
self.url(),
headers=owner_auth_header,
json={
"name": "Dataset Name",
"metadata": {"key": "value"},
"workspace_id": str(workspace.id),
},
)

assert response.status_code == 201
assert response.json()["metadata"] == {"key": "value"}

dataset = (await db.execute(select(Dataset))).scalar_one()
assert dataset.metadata_ == {"key": "value"}

@pytest.mark.parametrize("invalid_metadata", ["invalid_metadata", None, 123])
async def test_create_dataset_with_invalid_metadata(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict, invalid_metadata: Any
):
workspace = await WorkspaceFactory.create()

response = await async_client.post(
self.url(),
headers=owner_auth_header,
json={
"name": "Dataset Name",
"metadata": invalid_metadata,
"workspace_id": str(workspace.id),
},
)

assert response.status_code == 422
assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0
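These tests map directly onto client calls. A hypothetical request against a running server — the base URL, API key header name and value, and workspace id below are placeholders, not values taken from the commit:

import httpx

response = httpx.post(
    "http://localhost:6900/api/v1/datasets",            # assumed base URL
    headers={"X-Argilla-Api-Key": "owner.apikey"},       # assumed header name and key
    json={
        "name": "Dataset Name",
        "metadata": {"key": "value"},                    # optional; omitting it defaults to {}
        "workspace_id": "00000000-0000-0000-0000-000000000000",  # placeholder UUID
    },
)

assert response.status_code == 201
print(response.json()["metadata"])  # {'key': 'value'}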
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from uuid import UUID

import pytest
@@ -72,29 +73,6 @@ async def test_update_dataset_without_distribution(self, async_client: AsyncClie
"min_submitted": 1,
}

async def test_update_dataset_without_distribution_for_published_dataset(
self, async_client: AsyncClient, owner_auth_header: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={"name": "Dataset updated name"},
)

assert response.status_code == 200
assert response.json()["distribution"] == {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

assert dataset.name == "Dataset updated name"
assert dataset.distribution == {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

async def test_update_dataset_distribution_with_invalid_strategy(
self, async_client: AsyncClient, owner_auth_header: dict
):
@@ -152,3 +130,69 @@ async def test_update_dataset_distribution_as_none(self, async_client: AsyncClie
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

async def test_update_dataset_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
dataset = await DatasetFactory.create(metadata_={"key-a": "value-a", "key-b": "value-b"})

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={
"metadata": {
"key-a": "value-a-updated",
"key-c": "value-c",
},
},
)

assert response.status_code == 200
assert response.json()["metadata"] == {
"key-a": "value-a-updated",
"key-b": "value-b",
"key-c": "value-c",
}

assert dataset.metadata_ == {
"key-a": "value-a-updated",
"key-b": "value-b",
"key-c": "value-c",
}

async def test_update_dataset_without_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
dataset = await DatasetFactory.create(metadata_={"key": "value"})

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={"name": "Dataset updated name"},
)

assert response.status_code == 200
assert response.json()["metadata"] == {"key": "value"}

assert dataset.name == "Dataset updated name"
assert dataset.metadata_ == {"key": "value"}

async def test_update_dataset_with_invalid_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
dataset = await DatasetFactory.create(metadata_={"key": "value"})

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={"metadata": "invalid_metadata"},
)

assert response.status_code == 422
assert dataset.metadata_ == {"key": "value"}

async def test_update_dataset_metadata_as_none(self, async_client: AsyncClient, owner_auth_header: dict):
dataset = await DatasetFactory.create(metadata_={"key": "value"})

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={"metadata": None},
)

assert response.status_code == 422
assert dataset.metadata_ == {"key": "value"}
11 changes: 10 additions & 1 deletion argilla-server/tests/unit/api/handlers/v1/test_datasets.py
@@ -123,6 +123,7 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": {},
"workspace_id": str(dataset_a.workspace_id),
"last_activity_at": dataset_a.last_activity_at.isoformat(),
"inserted_at": dataset_a.inserted_at.isoformat(),
@@ -138,6 +139,7 @@
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": {},
"workspace_id": str(dataset_b.workspace_id),
"last_activity_at": dataset_b.last_activity_at.isoformat(),
"inserted_at": dataset_b.inserted_at.isoformat(),
@@ -153,6 +155,7 @@
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": {},
"workspace_id": str(dataset_c.workspace_id),
"last_activity_at": dataset_c.last_activity_at.isoformat(),
"inserted_at": dataset_c.inserted_at.isoformat(),
@@ -684,6 +687,7 @@ async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header:
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": {},
"workspace_id": str(dataset.workspace_id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
@@ -890,6 +894,7 @@ async def test_create_dataset(self, async_client: "AsyncClient", db: "AsyncSessi
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": {},
"workspace_id": str(workspace.id),
"last_activity_at": datetime.fromisoformat(response_body["last_activity_at"]).isoformat(),
"inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(),
@@ -4760,6 +4765,7 @@ async def test_update_dataset(self, async_client: "AsyncClient", db: "AsyncSessi
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": {},
"workspace_id": str(dataset.workspace_id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
@@ -4846,7 +4852,10 @@ async def test_update_dataset_as_annotator(self, async_client: "AsyncClient"):
response = await async_client.patch(
f"/api/v1/datasets/{dataset.id}",
headers={API_KEY_HEADER_NAME: user.api_key},
json={"name": "New Name", "guidelines": "New Guidelines"},
json={
"name": "New Name",
"guidelines": "New Guidelines",
},
)

assert response.status_code == 403