Skip to content

Commit

Permalink
[ENHANCEMENT] argilla-server: List records endpoint using db (#5170)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR rewrites the current list dataset records endpoint to use the DB
instead of the search engine since no filtering is applied to the
endpoint.

~~The PR introduces a new abstraction layer to manage DB internals: the
repository. With this layer, all DB methods related to a resource live
in a single place, which improves maintainability and
reusability.~~

DB query details have been moved to the Record model class, simplifying
the context flows.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Refactor (change restructuring the codebase without changing
functionality)
- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: José Francisco Calvo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Damián Pumar <[email protected]>
Co-authored-by: José Francisco Calvo <[email protected]>
  • Loading branch information
5 people authored Sep 23, 2024
1 parent 496a8c3 commit aae998a
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from fastapi import APIRouter, Depends, Query, Security, status
Expand Down Expand Up @@ -43,17 +43,16 @@
SearchSuggestionsOptions,
SuggestionFilterScope,
)
from argilla_server.contexts import datasets, search
from argilla_server.contexts import datasets, search, records
from argilla_server.database import get_async_db
from argilla_server.enums import RecordSortField, ResponseStatusFilter
from argilla_server.enums import RecordSortField
from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
from argilla_server.models import Dataset, Field, Record, User, VectorSettings
from argilla_server.search_engine import (
AndFilter,
SearchEngine,
SearchResponses,
UserResponseStatusFilter,
get_search_engine,
)
from argilla_server.security import auth
Expand All @@ -72,42 +71,13 @@
router = APIRouter()


async def _filter_records_using_search_engine(
    db: "AsyncSession",
    search_engine: "SearchEngine",
    dataset: Dataset,
    limit: int,
    offset: int,
    user: Optional[User] = None,
    include: Optional[RecordIncludeParam] = None,
) -> Tuple[List[Record], int]:
    """Return one page of dataset records resolved through the search engine.

    The search engine is asked for the matching page first; the record ids it
    returns are then fetched from the database together with whatever
    `include` requests.

    Returns:
        A `(records, total)` tuple, where `total` is the total reported by the
        search engine response.
    """
    engine_page = await _get_search_responses(
        db=db,
        search_engine=search_engine,
        dataset=dataset,
        limit=limit,
        offset=offset,
        user=user,
    )

    matched_ids = [hit.record_id for hit in engine_page.items]

    # Without a user there is nothing to scope the fetch to.
    if user:
        requester_id = user.id
    else:
        requester_id = None

    fetched = await datasets.get_records_by_ids(
        db=db,
        dataset_id=dataset.id,
        user_id=requester_id,
        records_ids=matched_ids,
        include=include,
    )

    return fetched, engine_page.total


def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope:
if isinstance(scope, RecordFilterScope):
return search_engine.RecordFilterScope(property=scope.property)
elif isinstance(scope, MetadataFilterScope):
return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property)
elif isinstance(scope, SuggestionFilterScope):
return search_engine.SuggestionFilterScope(question=scope.question, property=scope.property)
return search_engine.SuggestionFilterScope(question=scope.question, property=str(scope.property))
elif isinstance(scope, ResponseFilterScope):
return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user)
else:
Expand Down Expand Up @@ -223,18 +193,6 @@ async def _get_search_responses(
return await search_engine.search(**search_params)


async def _build_response_status_filter_for_search(
    response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None
) -> Optional[UserResponseStatusFilter]:
    """Build the search-engine response status filter, if one was requested.

    Returns:
        `None` when `response_statuses` is missing or empty; otherwise a
        `UserResponseStatusFilter` carrying `user` and the given statuses.
    """
    if not response_statuses:
        return None

    # TODO(@frascuchon): user response and status responses should be split into different filter types
    return UserResponseStatusFilter(user=user, statuses=response_statuses)


async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset: Dataset):
try:
await search.validate_search_records_query(db, query, dataset)
Expand All @@ -246,27 +204,34 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord
async def list_dataset_records(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset))

records, total = await _filter_records_using_search_engine(
db,
search_engine,
dataset=dataset,
limit=limit,
include_args = (
dict(
with_responses=include.with_responses,
with_suggestions=include.with_suggestions,
with_vectors=include.with_all_vectors or include.vectors,
)
if include
else {}
)

dataset_records, total = await records.list_dataset_records(
db=db,
dataset_id=dataset.id,
offset=offset,
include=include,
limit=limit,
**include_args,
)

return Records(items=records, total=total)
return Records(items=dataset_records, total=total)


@router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT)
Expand Down
77 changes: 69 additions & 8 deletions argilla-server/src/argilla_server/contexts/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Sequence
from typing import Dict, Sequence, Union, List, Tuple, Optional
from uuid import UUID

from sqlalchemy import select
from sqlalchemy import select, and_, func, Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import selectinload, contains_eager

from argilla_server.models import Dataset, Record
from argilla_server.database import get_async_db
from argilla_server.models import Dataset, Record, VectorSettings, Vector


async def list_dataset_records(
    db: AsyncSession,
    dataset_id: UUID,
    offset: int,
    limit: int,
    with_responses: bool = False,
    with_suggestions: bool = False,
    with_vectors: Union[bool, List[str]] = False,
) -> Tuple[Sequence[Record], int]:
    """Fetch one page of a dataset's records straight from the database.

    Args:
        db: Async session used to run both statements.
        dataset_id: Dataset whose records are listed.
        offset: Number of records to skip.
        limit: Maximum number of records to return.
        with_responses: Forwarded to the page query builder.
        with_suggestions: Forwarded to the page query builder.
        with_vectors: Forwarded to the page query builder; either a flag or a
            list of vector names.

    Returns:
        A `(records, total)` tuple, where `total` is the count of all records
        in the dataset regardless of `offset` and `limit`.
    """
    page_query = _record_by_dataset_id_query(
        dataset_id=dataset_id,
        offset=offset,
        limit=limit,
        with_responses=with_responses,
        with_suggestions=with_suggestions,
        with_vectors=with_vectors,
    )
    count_query = select(func.count(Record.id)).filter_by(dataset_id=dataset_id)

    page_result = await db.scalars(page_query)
    page = page_result.unique().all()

    total = await db.scalar(count_query)

    return page, total


async def list_dataset_records_by_ids(
db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID]
) -> Sequence[Record]:
query = select(Record).filter(Record.id.in_(record_ids), Record.dataset_id == dataset_id)
return (await db.execute(query)).unique().scalars().all()
query = select(Record).where(and_(Record.id.in_(record_ids), Record.dataset_id == dataset_id))
return (await db.scalars(query)).unique().all()


async def list_dataset_records_by_external_ids(
db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str]
) -> Sequence[Record]:
query = (
select(Record)
.filter(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id)
.where(and_(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id))
.options(selectinload(Record.dataset))
)
return (await db.execute(query)).unique().scalars().all()

return (await db.scalars(query)).unique().all()


async def fetch_records_by_ids_as_dict(
Expand All @@ -52,3 +78,38 @@ async def fetch_records_by_external_ids_as_dict(
) -> Dict[str, Record]:
records_by_external_ids = await list_dataset_records_by_external_ids(db, dataset.id, external_ids)
return {record.external_id: record for record in records_by_external_ids}


def _record_by_dataset_id_query(
    dataset_id,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    with_responses: bool = False,
    with_suggestions: bool = False,
    with_vectors: Union[bool, List[str]] = False,
) -> Select:
    """Build the SELECT that lists the records of a dataset.

    Args:
        dataset_id: Dataset whose records are selected.
        offset: Number of records to skip; no OFFSET is added when `None`.
        limit: Maximum number of records; no LIMIT is added when `None`.
        with_responses: Eager-load `Record.responses`.
        with_suggestions: Eager-load `Record.suggestions`.
        with_vectors: `True` eager-loads every vector of each record. A list
            of names eager-loads only the vectors whose vector settings belong
            to this dataset and have one of those names. Any other value loads
            no vectors eagerly.

    Returns:
        A `Select` over `Record`, ordered by `inserted_at` and then `id`.
    """
    query = select(Record).filter_by(dataset_id=dataset_id)

    if with_responses:
        query = query.options(selectinload(Record.responses))

    if with_suggestions:
        query = query.options(selectinload(Record.suggestions))

    if with_vectors is True:
        query = query.options(selectinload(Record.vectors))
    elif isinstance(with_vectors, list):
        vector_settings_ids = select(VectorSettings.id).where(
            and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors))
        )
        # Load the filtered vectors with a separate SELECT IN instead of an
        # OUTER JOIN on the main statement. With a join, OFFSET/LIMIT count
        # one row per (record, vector) pair, so a page could hold fewer
        # records than `limit` or split a record's vectors across pages.
        query = query.options(
            selectinload(Record.vectors.and_(Vector.vector_settings_id.in_(vector_settings_ids)))
        )

    # `inserted_at` alone is not a total order: records sharing a timestamp
    # could be duplicated or skipped between pages. The primary key breaks ties.
    query = query.order_by(Record.inserted_at, Record.id)

    if offset is not None:
        query = query.offset(offset)

    if limit is not None:
        query = query.limit(limit)

    return query
13 changes: 10 additions & 3 deletions argilla-server/src/argilla_server/models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,27 @@
from typing import Any, List, Optional, Union
from uuid import UUID

from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text
from sqlalchemy import Enum as SAEnum
from sqlalchemy import (
JSON,
ForeignKey,
String,
Text,
UniqueConstraint,
and_,
sql,
)
from sqlalchemy.engine.default import DefaultExecutionContext
from sqlalchemy.ext.asyncio import async_object_session
from sqlalchemy.ext.mutable import MutableDict, MutableList
from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property
from sqlalchemy.orm import Mapped, mapped_column, relationship

from argilla_server.api.schemas.v1.questions import QuestionSettings
from argilla_server.enums import (
DatasetStatus,
FieldType,
MetadataPropertyType,
QuestionType,
RecordStatus,
ResponseStatus,
SuggestionType,
UserRole,
Expand Down
Loading

0 comments on commit aae998a

Please sign in to comment.