Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENHANCEMENT] argilla-server: List records endpoint using db #5170

Merged
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
f62d58a
feat: add dataset support to be created using distribution settings (…
jfcalvo Jul 1, 2024
017001f
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 1, 2024
f084ab7
✨ Remove unused method
damianpumar Jul 4, 2024
c8ef4c6
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 4, 2024
6df5256
feat: improve Records `responses_submitted` relationship to be view o…
jfcalvo Jul 4, 2024
dbae135
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 4, 2024
cf3408c
feat: change metrics to support new distribution task logic (#5140)
jfcalvo Jul 4, 2024
267811c
[REFACTOR] `argilla-server`: Remove list current user records endpoin…
frascuchon Jul 4, 2024
89f9bde
[BREAKING- REFACTOR] `argilla-server`: remove metadata filter query p…
frascuchon Jul 4, 2024
0404465
[BREAKING - REFACTOR] `argilla-server`: remove user response status s…
frascuchon Jul 4, 2024
20d4ab8
refactor: Remove sort_by argument
frascuchon Jul 4, 2024
5f4e5b0
[breaking] refactor: Remove sort_by query param
frascuchon Jul 4, 2024
c885392
tests: Adapt tests
frascuchon Jul 4, 2024
28b2998
chore: Update changelog
frascuchon Jul 5, 2024
209d64d
feat: Define new repositories
frascuchon Jul 5, 2024
a350b0c
chore: Rewrite list endpoint using repositories
frascuchon Jul 5, 2024
3537941
tests: Enable skip tests for list dataset records
frascuchon Jul 5, 2024
8e8b116
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
frascuchon Jul 5, 2024
808c837
[ENHANCEMENT]: `argilla-server`: allow update distribution for non an…
frascuchon Jul 8, 2024
ba417dc
[BREAKING - REFACTOR] `argilla-server`: remove `sort_by` query param …
frascuchon Jul 8, 2024
f241e41
fix: wrong filter naming after merge from develop
frascuchon Jul 9, 2024
67d4ee3
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 9, 2024
3e06890
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 9, 2024
b15de8f
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
frascuchon Jul 11, 2024
f497140
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 11, 2024
bec0b0d
feat: add session helper with serializable isolation level (#5165)
jfcalvo Jul 12, 2024
8bf8abb
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 12, 2024
85e847f
[REFACTOR] `argilla-server`: remove deprecated records endpoint (#5206)
frascuchon Jul 12, 2024
22263d8
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 12, 2024
c219764
[ENHANCEMENT] `argilla`: add record `status` property (#5184)
frascuchon Jul 12, 2024
ced0220
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 12, 2024
aa9bf1f
Merge branch 'feat/add-dataset-automatic-task-distribution' into refa…
frascuchon Jul 12, 2024
0b73b3f
Merge branch 'refactor/cleaning-list-records-endpoints' into refactor…
frascuchon Jul 12, 2024
dcfbfaf
Merge branch 'feat/add-dataset-automatic-task-distribution' into refa…
frascuchon Jul 12, 2024
4d3f668
Merge branch 'refactor/cleaning-list-records-endpoints' into refactor…
frascuchon Jul 12, 2024
2941072
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 18, 2024
9c9aa26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
11ef168
Update argilla-frontend/components/features/datasets/dataset-progress…
frascuchon Jul 19, 2024
0e525b4
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 25, 2024
1526e33
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 29, 2024
bca45ff
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 31, 2024
7356451
chore: Remove repositories
frascuchon Jul 31, 2024
39e6bd7
refactor: Moving logic to contexts
frascuchon Jul 31, 2024
e4eb17f
refactor: using contexts
frascuchon Jul 31, 2024
6326a54
tests: Mock db for contexts
frascuchon Jul 31, 2024
3ce1f84
refactor: Reusing depends get_dataset
frascuchon Jul 31, 2024
82e306e
refactor: Moving query builder to models
frascuchon Jul 31, 2024
423466a
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
jfcalvo Jul 31, 2024
7205842
chore: Update CHANGELOG
frascuchon Aug 1, 2024
59d05c5
chore: Change order
frascuchon Aug 1, 2024
61bc08f
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Aug 1, 2024
d656bbc
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Aug 2, 2024
d8aa03e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
ee3fa63
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Aug 26, 2024
6d9ecfb
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
jfcalvo Aug 30, 2024
d4f70de
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 9, 2024
4b25753
chore: Apply PR comments
frascuchon Sep 9, 2024
c8d3be8
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 9, 2024
31390ab
chore: Revert newline
frascuchon Sep 9, 2024
a7b6713
Merge branch 'refactor/argilla-server/list-records-endpoint-using-db'…
frascuchon Sep 9, 2024
a7dc205
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 17, 2024
b00f404
chore: Apply suggestions
frascuchon Sep 17, 2024
a9e6795
revert code changes
frascuchon Sep 17, 2024
acfe981
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 18, 2024
031a407
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
22539d6
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ These are the section headers that we use:
- [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
- Change search index mapping for responses (reindex is required). ([#5228](https://github.com/argilla-io/argilla/pull/5228))

### Changed

- Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126))
- [breaking] Change `GET /api/v1/datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
- [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))

frascuchon marked this conversation as resolved.
Show resolved Hide resolved
### Fixed

- Fixed SQLite connection settings not working correctly due to an outdated conditional. ([#5149](https://github.com/argilla-io/argilla/pull/5149))
Expand All @@ -48,6 +54,15 @@ These are the section headers that we use:

### Removed

- [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
- [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
- [breaking] Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153))
- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163))
- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156))
- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166))

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)

- [breaking] Removed deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
- [breaking] Removed deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
- [breaking] Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from fastapi import APIRouter, Depends, Query, Security, status
from fastapi import APIRouter, Depends, Query, Security, status, Path
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down Expand Up @@ -44,17 +43,16 @@
SearchSuggestionsOptions,
SuggestionFilterScope,
)
from argilla_server.contexts import datasets, search
from argilla_server.contexts import datasets, search, records
from argilla_server.database import get_async_db
from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder
from argilla_server.enums import RecordSortField
from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
from argilla_server.models import Dataset, Field, Record, User, VectorSettings
from argilla_server.search_engine import (
AndFilter,
SearchEngine,
SearchResponses,
UserResponseStatusFilter,
get_search_engine,
)
from argilla_server.security import auth
Expand All @@ -73,42 +71,13 @@
router = APIRouter()


async def _filter_records_using_search_engine(
db: "AsyncSession",
search_engine: "SearchEngine",
dataset: Dataset,
limit: int,
offset: int,
user: Optional[User] = None,
include: Optional[RecordIncludeParam] = None,
) -> Tuple[List[Record], int]:
search_responses = await _get_search_responses(
db=db,
search_engine=search_engine,
dataset=dataset,
limit=limit,
offset=offset,
user=user,
)

record_ids = [response.record_id for response in search_responses.items]
user_id = user.id if user else None

return (
await datasets.get_records_by_ids(
db=db, dataset_id=dataset.id, user_id=user_id, records_ids=record_ids, include=include
),
search_responses.total,
)


def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope:
if isinstance(scope, RecordFilterScope):
return search_engine.RecordFilterScope(property=scope.property)
elif isinstance(scope, MetadataFilterScope):
return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property)
elif isinstance(scope, SuggestionFilterScope):
return search_engine.SuggestionFilterScope(question=scope.question, property=scope.property)
return search_engine.SuggestionFilterScope(question=scope.question, property=str(scope.property))
elif isinstance(scope, ResponseFilterScope):
return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user)
else:
Expand Down Expand Up @@ -224,63 +193,57 @@ async def _get_search_responses(
return await search_engine.search(**search_params)


async def _build_response_status_filter_for_search(
response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None
) -> Optional[UserResponseStatusFilter]:
user_response_status_filter = None

if response_statuses:
# TODO(@frascuchon): user response and status responses should be split into different filter types
user_response_status_filter = UserResponseStatusFilter(user=user, statuses=response_statuses)

return user_response_status_filter


async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID):
try:
await search.validate_search_records_query(db, query, dataset_id)
except (ValueError, NotFoundError) as e:
raise UnprocessableEntityError(str(e))


async def get_dataset_or_raise(dataset_id: UUID = Path) -> Dataset:
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
return await datasets.get_or_raise(dataset_id)


@router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True)
async def list_dataset_records(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
dataset: Dataset = Depends(get_dataset_or_raise),
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset))

records, total = await _filter_records_using_search_engine(
db,
search_engine,
dataset=dataset,
limit=limit,
include_args = (
dict(
with_responses=include.with_responses,
with_suggestions=include.with_suggestions,
with_vectors=include.with_all_vectors or include.vectors,
)
if include
else {}
)

dataset_records, total = await records.list_records_by_dataset_id(
dataset_id=dataset.id,
offset=offset,
include=include,
limit=limit,
**include_args,
)

return Records(items=records, total=total)
return Records(items=dataset_records, total=total)


@router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT)
async def delete_dataset_records(
*,
dataset: Dataset = Depends(get_dataset_or_raise),
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
current_user: User = Security(auth.get_current_user),
ids: str = Query(..., description="A comma separated list with the IDs of the records to be removed"),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.delete_records(dataset))

record_ids = parse_uuids(ids)
Expand Down Expand Up @@ -427,12 +390,10 @@ async def search_dataset_records(
)
async def list_dataset_records_search_suggestions_options(
*,
dataset: Dataset = Depends(get_dataset_or_raise),
db: AsyncSession = Depends(get_async_db),
dataset_id: UUID,
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)
frascuchon marked this conversation as resolved.
Show resolved Hide resolved

await authorize(current_user, DatasetPolicy.search_records(dataset))

suggestion_agents_by_question = await search.get_dataset_suggestion_agents_by_question(db, dataset.id)
Expand Down
7 changes: 7 additions & 0 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
)
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.contexts import accounts, distribution
from argilla_server.database import get_async_db
from argilla_server.enums import DatasetStatus, UserRole, RecordStatus
from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
from argilla_server.models import (
Expand Down Expand Up @@ -114,6 +115,12 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) ->
return result.scalars().all()


async def get_or_raise(dataset_id: UUID) -> Dataset:
"""Get a dataset by ID or raise a NotFoundError"""
async for db in get_async_db():
return await Dataset.get_or_raise(db, id=dataset_id)


async def create_dataset(db: AsyncSession, dataset_attrs: dict):
dataset = Dataset(
name=dataset_attrs["name"],
Expand Down
44 changes: 35 additions & 9 deletions argilla-server/src/argilla_server/contexts/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Sequence
from typing import Dict, Sequence, Union, List, Tuple
from uuid import UUID

from sqlalchemy import select
from sqlalchemy import select, and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import selectinload, contains_eager

from argilla_server.models import Dataset, Record
from argilla_server.database import get_async_db
from argilla_server.models import Dataset, Record, VectorSettings, Vector


async def list_records_by_dataset_id(
dataset_id: UUID,
offset: int,
limit: int,
with_responses: bool = False,
with_suggestions: bool = False,
with_vectors: Union[bool, List[str]] = False,
) -> Tuple[Sequence[Record], int]:
query = Record.Select.by_dataset_id(
dataset_id=dataset_id,
offset=offset,
limit=limit,
with_responses=with_responses,
with_suggestions=with_suggestions,
with_vectors=with_vectors,
)

async for db in get_async_db():
records = (await db.scalars(query)).unique().all()
total = await db.scalar(Record.Select.count(dataset_id=dataset_id))

return records, total


async def list_dataset_records_by_ids(
db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID]
) -> Sequence[Record]:
query = select(Record).filter(Record.id.in_(record_ids), Record.dataset_id == dataset_id)
return (await db.execute(query)).unique().scalars().all()
query = Record.Select.by_dataset_id(dataset_id=dataset_id).where(Record.id.in_(record_ids))
return (await db.scalars(query)).unique().all()


async def list_dataset_records_by_external_ids(
db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str]
) -> Sequence[Record]:
query = (
select(Record)
.filter(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id)
Record.Select.by_dataset_id(dataset_id=dataset_id)
.where(Record.external_id.in_(external_ids))
.options(selectinload(Record.dataset))
)
return (await db.execute(query)).unique().scalars().all()

return (await db.scalars(query)).unique().all()


async def fetch_records_by_ids_as_dict(
Expand Down
58 changes: 56 additions & 2 deletions argilla-server/src/argilla_server/models/database.py
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,25 @@
from typing import Any, List, Optional, Union
from uuid import UUID

from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text
from sqlalchemy import (
JSON,
ForeignKey,
String,
Text,
UniqueConstraint,
and_,
sql,
select,
func,
text,
Select,
ColumnExpressionArgument,
)
from sqlalchemy import Enum as SAEnum
from sqlalchemy.engine.default import DefaultExecutionContext
from sqlalchemy.ext.asyncio import async_object_session
from sqlalchemy.ext.mutable import MutableDict, MutableList
from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property
from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property, selectinload, contains_eager

from argilla_server.api.schemas.v1.questions import QuestionSettings
from argilla_server.enums import (
Expand Down Expand Up @@ -236,6 +249,47 @@
f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})"
)

class Select:
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
def count(cls, **filters) -> Select:
return select(func.count(Record.id)).filter_by(**filters)

@classmethod
def by_dataset_id(
cls,
dataset_id: UUID,
offset: Optional[int] = None,
limit: Optional[int] = None,
with_responses: bool = False,
with_suggestions: bool = False,
with_vectors: Union[bool, List[str]] = False,
) -> Select:
query = select(Record).filter_by(dataset_id=dataset_id)

if with_responses:
query = query.options(selectinload(Record.responses))

if with_suggestions:
query = query.options(selectinload(Record.suggestions))

Check warning on line 273 in argilla-server/src/argilla_server/models/database.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/models/database.py#L273

Added line #L273 was not covered by tests

if with_vectors is True:
query = query.options(selectinload(Record.vectors))
elif isinstance(with_vectors, list):
subquery = select(VectorSettings.id).filter(
and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors))
)
query = query.outerjoin(
Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery))
).options(contains_eager(Record.vectors))

if offset is not None:
query = query.offset(offset)

if limit is not None:
query = query.limit(limit)

return query.order_by(Record.inserted_at)


class Question(DatabaseModel):
__tablename__ = "questions"
Expand Down
Loading