Skip to content

Commit

Permalink
tests: Fixing tests after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Oct 2, 2024
1 parent dc52185 commit 8a90dcc
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from uuid import UUID
from httpx import AsyncClient

import pytest
from fastapi.encoders import jsonable_encoder
from httpx import AsyncClient
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.enums import (
DatasetStatus,
QuestionType,
ResponseStatus,
SuggestionType,
RecordStatus,
DatasetDistributionStrategy,
)
from argilla_server.jobs.queues import HIGH_QUEUE
from argilla_server.models.database import Record, Response, Suggestion, User
from argilla_server.webhooks.v1.enums import RecordEvent
from argilla_server.webhooks.v1.records import build_record_event
from argilla_server.models.database import Record, Response, Suggestion, User
from argilla_server.enums import DatasetStatus, QuestionType, ResponseStatus, SuggestionType

from tests.factories import (
DatasetFactory,
LabelSelectionQuestionFactory,
Expand All @@ -39,6 +45,7 @@
ChatFieldFactory,
CustomFieldFactory,
WebhookFactory,
AnnotatorFactory,
)


Expand Down Expand Up @@ -590,15 +597,6 @@ async def test_create_dataset_records_bulk_with_chat_field_without_content_key(
}
assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 0

async def test_create_dataset_records_bulk_enqueue_webhook_record_created_events(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)
await TextFieldFactory.create(name="prompt", dataset=dataset)
await TextQuestionFactory.create(name="text-question", dataset=dataset)

webhook = await WebhookFactory.create(events=[RecordEvent.created])

async def test_create_dataset_records_bulk_with_custom_field_values(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
Expand Down Expand Up @@ -631,23 +629,6 @@ async def test_create_dataset_records_bulk_with_custom_field_values(
},
)

assert response.status_code == 201

records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all()

event_a = await build_record_event(db, RecordEvent.created, records[0])
event_b = await build_record_event(db, RecordEvent.created, records[1])

assert HIGH_QUEUE.count == 2

assert HIGH_QUEUE.jobs[0].args[0] == webhook.id
assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created
assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data)

assert HIGH_QUEUE.jobs[1].args[0] == webhook.id
assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created
assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data)

assert response.status_code == 201, response.json()
records = (await db.execute(select(Record))).scalars().all()
assert len(records) == 3
Expand Down Expand Up @@ -679,3 +660,163 @@ async def test_create_dataset_records_bulk_with_wrong_custom_field_value(

assert response.status_code == 422
assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 0

async def test_create_dataset_records_bulk_updates_records_status(
self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict
):
dataset = await DatasetFactory.create(
status=DatasetStatus.ready,
distribution={
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 2,
},
)

user = await AnnotatorFactory.create(workspaces=[dataset.workspace])

await TextFieldFactory.create(name="prompt", dataset=dataset)
await TextFieldFactory.create(name="response", dataset=dataset)

await TextQuestionFactory.create(name="text-question", dataset=dataset)

response = await async_client.post(
self.url(dataset.id),
headers=owner_auth_header,
json={
"items": [
{
"fields": {
"prompt": "Does exercise help reduce stress?",
"response": "Exercise can definitely help reduce stress.",
},
"responses": [
{
"user_id": str(owner.id),
"status": ResponseStatus.submitted,
"values": {
"text-question": {
"value": "text question response",
},
},
},
{
"user_id": str(user.id),
"status": ResponseStatus.submitted,
"values": {
"text-question": {
"value": "text question response",
},
},
},
],
},
{
"fields": {
"prompt": "Does exercise help reduce stress?",
"response": "Exercise can definitely help reduce stress.",
},
"responses": [
{
"user_id": str(owner.id),
"status": ResponseStatus.submitted,
"values": {
"text-question": {
"value": "text question response",
},
},
},
],
},
{
"fields": {
"prompt": "Does exercise help reduce stress?",
"response": "Exercise can definitely help reduce stress.",
},
"responses": [
{
"user_id": str(owner.id),
"status": ResponseStatus.draft,
"values": {
"text-question": {
"value": "text question response",
},
},
},
{
"user_id": str(user.id),
"status": ResponseStatus.draft,
"values": {
"text-question": {
"value": "text question response",
},
},
},
],
},
{
"fields": {
"prompt": "Does exercise help reduce stress?",
"response": "Exercise can definitely help reduce stress.",
},
},
],
},
)

assert response.status_code == 201

response_items = response.json()["items"]
assert response_items[0]["status"] == RecordStatus.completed
assert response_items[1]["status"] == RecordStatus.pending
assert response_items[2]["status"] == RecordStatus.pending
assert response_items[3]["status"] == RecordStatus.pending

assert (await Record.get(db, UUID(response_items[0]["id"]))).status == RecordStatus.completed
assert (await Record.get(db, UUID(response_items[1]["id"]))).status == RecordStatus.pending
assert (await Record.get(db, UUID(response_items[2]["id"]))).status == RecordStatus.pending
assert (await Record.get(db, UUID(response_items[3]["id"]))).status == RecordStatus.pending

async def test_create_dataset_records_bulk_enqueue_webhook_record_created_events(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)
await TextFieldFactory.create(name="prompt", dataset=dataset)
await TextQuestionFactory.create(name="text-question", dataset=dataset)

webhook = await WebhookFactory.create(events=[RecordEvent.created])

response = await async_client.post(
self.url(dataset.id),
headers=owner_auth_header,
json={
"items": [
{
"fields": {
"prompt": "You should exercise more.",
},
},
{
"fields": {
"prompt": "Do you like to exercise?",
},
},
],
},
)

assert response.status_code == 201, response.json()

records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all()

event_a = await build_record_event(db, RecordEvent.created, records[0])
event_b = await build_record_event(db, RecordEvent.created, records[1])

assert HIGH_QUEUE.count == 2

assert HIGH_QUEUE.jobs[0].args[0] == webhook.id
assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created
assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data)

assert HIGH_QUEUE.jobs[1].args[0] == webhook.id
assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created
assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data)

This file was deleted.

0 comments on commit 8a90dcc

Please sign in to comment.