From 8a90dccd12f83a64d546d761098ee499d85021a5 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 2 Oct 2024 16:44:27 +0200 Subject: [PATCH] tests: Fixing tests after merge --- .../test_create_dataset_records_bulk.py | 205 +++++++++++++++--- .../test_create_dataset_records_bulk.py | 145 ------------- 2 files changed, 173 insertions(+), 177 deletions(-) delete mode 100644 argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py index 155b04641a..64d99656e7 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py @@ -12,20 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - from uuid import UUID -from httpx import AsyncClient + +import pytest from fastapi.encoders import jsonable_encoder +from httpx import AsyncClient from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from argilla_server.enums import ( + DatasetStatus, + QuestionType, + ResponseStatus, + SuggestionType, + RecordStatus, + DatasetDistributionStrategy, +) from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.models.database import Record, Response, Suggestion, User from argilla_server.webhooks.v1.enums import RecordEvent from argilla_server.webhooks.v1.records import build_record_event -from argilla_server.models.database import Record, Response, Suggestion, User -from argilla_server.enums import DatasetStatus, QuestionType, ResponseStatus, SuggestionType - from tests.factories import ( DatasetFactory, LabelSelectionQuestionFactory, @@ -39,6 +45,7 @@ ChatFieldFactory, CustomFieldFactory, WebhookFactory, + AnnotatorFactory, ) @@ -590,15 +597,6 @@ async def test_create_dataset_records_bulk_with_chat_field_without_content_key( } assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 0 - async def test_create_dataset_records_bulk_enqueue_webhook_record_created_events( - self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict - ): - dataset = await DatasetFactory.create(status=DatasetStatus.ready) - await TextFieldFactory.create(name="prompt", dataset=dataset) - await TextQuestionFactory.create(name="text-question", dataset=dataset) - - webhook = await WebhookFactory.create(events=[RecordEvent.created]) - async def test_create_dataset_records_bulk_with_custom_field_values( self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict ): @@ -631,23 +629,6 @@ async def test_create_dataset_records_bulk_with_custom_field_values( }, ) - assert response.status_code == 201 - - records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all() - - event_a = await build_record_event(db, RecordEvent.created, records[0]) - event_b = await build_record_event(db, RecordEvent.created, records[1]) - - assert HIGH_QUEUE.count == 2 - - assert HIGH_QUEUE.jobs[0].args[0] == webhook.id - assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created - assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data) - - assert HIGH_QUEUE.jobs[1].args[0] == webhook.id - assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created - assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data) - assert response.status_code == 201, response.json() records = (await db.execute(select(Record))).scalars().all() assert len(records) == 3 @@ -679,3 +660,163 @@ async def test_create_dataset_records_bulk_with_wrong_custom_field_value( assert response.status_code == 422 assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 0 + + async def test_create_dataset_records_bulk_updates_records_status( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + status=DatasetStatus.ready, + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + user = await AnnotatorFactory.create(workspaces=[dataset.workspace]) + + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextFieldFactory.create(name="response", dataset=dataset) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + }, + ], + }, + ) + + assert response.status_code == 201 + + response_items = response.json()["items"] + assert response_items[0]["status"] == RecordStatus.completed + assert response_items[1]["status"] == RecordStatus.pending + assert response_items[2]["status"] == RecordStatus.pending + assert response_items[3]["status"] == RecordStatus.pending + + assert (await Record.get(db, UUID(response_items[0]["id"]))).status == RecordStatus.completed + assert (await Record.get(db, UUID(response_items[1]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[2]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[3]["id"]))).status == RecordStatus.pending + + async def test_create_dataset_records_bulk_enqueue_webhook_record_created_events( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.created]) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "prompt": "You should exercise more.", + }, + }, + { + "fields": { + "prompt": "Do you like to exercise?", + }, + }, + ], + }, + ) + + assert response.status_code == 201, response.json() + + records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all() + + event_a = await build_record_event(db, RecordEvent.created, records[0]) + event_b = await build_record_event(db, RecordEvent.created, records[1]) + + assert HIGH_QUEUE.count == 2 + + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data) + + assert HIGH_QUEUE.jobs[1].args[0] == webhook.id + assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created + assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py deleted file mode 100644 index 1aae133535..0000000000 --- a/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from uuid import UUID -from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession - -from argilla_server.models import User, Record -from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus, DatasetStatus - -from tests.factories import AnnotatorFactory, DatasetFactory, TextFieldFactory, TextQuestionFactory - - -@pytest.mark.asyncio -class TestCreateDatasetRecordsBulk: - def url(self, dataset_id: UUID) -> str: - return f"/api/v1/datasets/{dataset_id}/records/bulk" - - async def test_create_dataset_records_bulk_updates_records_status( - self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict - ): - dataset = await DatasetFactory.create( - status=DatasetStatus.ready, - distribution={ - "strategy": DatasetDistributionStrategy.overlap, - "min_submitted": 2, - }, - ) - - user = await AnnotatorFactory.create(workspaces=[dataset.workspace]) - - await TextFieldFactory.create(name="prompt", dataset=dataset) - await TextFieldFactory.create(name="response", dataset=dataset) - - await TextQuestionFactory.create(name="text-question", dataset=dataset) - - response = await async_client.post( - self.url(dataset.id), - headers=owner_auth_header, - json={ - "items": [ - { - "fields": { - "prompt": "Does exercise help reduce stress?", - "response": "Exercise can definitely help reduce stress.", - }, - "responses": [ - { - "user_id": str(owner.id), - "status": ResponseStatus.submitted, - "values": { - "text-question": { - "value": "text question response", - }, - }, - }, - { - "user_id": str(user.id), - "status": ResponseStatus.submitted, - "values": { - "text-question": { - "value": "text question response", - }, - }, - }, - ], - }, - { - "fields": { - "prompt": "Does exercise help reduce stress?", - "response": "Exercise can definitely help reduce stress.", - }, - "responses": [ - { - "user_id": str(owner.id), - "status": ResponseStatus.submitted, - "values": { - "text-question": { - "value": "text question response", - }, - }, - }, - ], - }, - { - "fields": { - "prompt": "Does exercise help reduce stress?", - "response": "Exercise can definitely help reduce stress.", - }, - "responses": [ - { - "user_id": str(owner.id), - "status": ResponseStatus.draft, - "values": { - "text-question": { - "value": "text question response", - }, - }, - }, - { - "user_id": str(user.id), - "status": ResponseStatus.draft, - "values": { - "text-question": { - "value": "text question response", - }, - }, - }, - ], - }, - { - "fields": { - "prompt": "Does exercise help reduce stress?", - "response": "Exercise can definitely help reduce stress.", - }, - }, - ], - }, - ) - - assert response.status_code == 201 - - response_items = response.json()["items"] - assert response_items[0]["status"] == RecordStatus.completed - assert response_items[1]["status"] == RecordStatus.pending - assert response_items[2]["status"] == RecordStatus.pending - assert response_items[3]["status"] == RecordStatus.pending - - assert (await Record.get(db, UUID(response_items[0]["id"]))).status == RecordStatus.completed - assert (await Record.get(db, UUID(response_items[1]["id"]))).status == RecordStatus.pending - assert (await Record.get(db, UUID(response_items[2]["id"]))).status == RecordStatus.pending - assert (await Record.get(db, UUID(response_items[3]["id"]))).status == RecordStatus.pending