Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IMPROVEMENT] Add Webhooks delete events with expanded schemas #5519

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,18 @@
)
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.webhooks.v1.enums import DatasetEvent, ResponseEvent, RecordEvent
from argilla_server.webhooks.v1.records import notify_record_event as notify_record_event_v1
from argilla_server.webhooks.v1.responses import notify_response_event as notify_response_event_v1
from argilla_server.webhooks.v1.datasets import notify_dataset_event as notify_dataset_event_v1
from argilla_server.webhooks.v1.records import (
build_record_event as build_record_event_v1,
notify_record_event as notify_record_event_v1,
)
from argilla_server.webhooks.v1.responses import (
build_response_event as build_response_event_v1,
notify_response_event as notify_response_event_v1,
)
from argilla_server.webhooks.v1.datasets import (
build_dataset_event as build_dataset_event_v1,
notify_dataset_event as notify_dataset_event_v1,
)
from argilla_server.contexts import accounts, distribution
from argilla_server.database import get_async_db
from argilla_server.enums import DatasetStatus, UserRole, RecordStatus
Expand Down Expand Up @@ -204,13 +213,14 @@ async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict


async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset:
deleted_dataset_event_v1 = await build_dataset_event_v1(db, DatasetEvent.deleted, dataset)

async with db.begin_nested():
dataset = await dataset.delete(db, autocommit=False)
await search_engine.delete_index(dataset)

await db.commit()

await notify_dataset_event_v1(db, DatasetEvent.deleted, dataset)
await deleted_dataset_event_v1.notify(db)

return dataset

Expand Down Expand Up @@ -815,15 +825,24 @@ async def preload_records_relationships_before_validate(db: AsyncSession, record
async def delete_records(
db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_ids: List[UUID]
) -> None:
params = [Record.id.in_(records_ids), Record.dataset_id == dataset.id]

records = (await db.execute(select(Record).filter(*params))).scalars().all()

deleted_record_events_v1 = []
for record in records:
deleted_record_events_v1.append(
await build_record_event_v1(db, RecordEvent.deleted, record),
)

async with db.begin_nested():
params = [Record.id.in_(records_ids), Record.dataset_id == dataset.id]
records = await Record.delete_many(db=db, params=params, autocommit=False)
await search_engine.delete_records(dataset=dataset, records=records)

await db.commit()

for record in records:
await notify_record_event_v1(db, RecordEvent.deleted, record)
for deleted_record_event_v1 in deleted_record_events_v1:
await deleted_record_event_v1.notify(db)


async def update_record(
Expand Down Expand Up @@ -860,13 +879,14 @@ async def update_record(


async def delete_record(db: AsyncSession, search_engine: "SearchEngine", record: Record) -> Record:
deleted_record_event_v1 = await build_record_event_v1(db, RecordEvent.deleted, record)

async with db.begin_nested():
record = await record.delete(db=db, autocommit=False)
await search_engine.delete_records(dataset=record.dataset, records=[record])

await db.commit()

await notify_record_event_v1(db, RecordEvent.deleted, record)
await deleted_record_event_v1.notify(db)

return record

Expand Down Expand Up @@ -897,8 +917,8 @@ async def create_response(
await search_engine.update_record_response(response)

await db.commit()
await distribution.update_record_status(search_engine, record.id)
await notify_response_event_v1(db, ResponseEvent.created, response)
await distribution.update_record_status(search_engine, record.id)

return response

Expand All @@ -922,8 +942,8 @@ async def update_response(
await search_engine.update_record_response(response)

await db.commit()
await distribution.update_record_status(search_engine, response.record_id)
await notify_response_event_v1(db, ResponseEvent.updated, response)
await distribution.update_record_status(search_engine, response.record_id)

return response

Expand Down Expand Up @@ -951,17 +971,20 @@ async def upsert_response(
await search_engine.update_record_response(response)

await db.commit()
await distribution.update_record_status(search_engine, record.id)

if response.inserted_at == response.updated_at:
await notify_response_event_v1(db, ResponseEvent.created, response)
else:
await notify_response_event_v1(db, ResponseEvent.updated, response)

await distribution.update_record_status(search_engine, record.id)

return response


async def delete_response(db: AsyncSession, search_engine: SearchEngine, response: Response) -> Response:
deleted_response_event_v1 = await build_response_event_v1(db, ResponseEvent.deleted, response)

async with db.begin_nested():
response = await response.delete(db, autocommit=False)

Expand All @@ -970,8 +993,8 @@ async def delete_response(db: AsyncSession, search_engine: SearchEngine, respons
await search_engine.delete_record_response(response)

await db.commit()
await deleted_response_event_v1.notify(db)
await distribution.update_record_status(search_engine, response.record_id)
await notify_response_event_v1(db, ResponseEvent.deleted, response)

return response

Expand Down
21 changes: 7 additions & 14 deletions argilla-server/src/argilla_server/webhooks/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.models import Dataset
from argilla_server.jobs.webhook_jobs import enqueue_notify_events
from argilla_server.webhooks.v1.event import Event
from argilla_server.webhooks.v1.schemas import DatasetEventSchema
from argilla_server.webhooks.v1.enums import DatasetEvent


async def notify_dataset_event(db: AsyncSession, dataset_event: DatasetEvent, dataset: Dataset) -> List[Job]:
if dataset_event == DatasetEvent.deleted:
return await _notify_dataset_deleted_event(db, dataset)
event = await build_dataset_event(db, dataset_event, dataset)

return await event.notify(db)


async def build_dataset_event(db: AsyncSession, dataset_event: DatasetEvent, dataset: Dataset) -> Event:
# NOTE: Force loading required association resources required by the event schema
(
await db.execute(
Expand All @@ -45,18 +48,8 @@ async def notify_dataset_event(db: AsyncSession, dataset_event: DatasetEvent, da
)
).scalar_one()

return await enqueue_notify_events(
db,
return Event(
event=dataset_event,
timestamp=datetime.utcnow(),
data=DatasetEventSchema.from_orm(dataset).dict(),
)


async def _notify_dataset_deleted_event(db: AsyncSession, dataset: Dataset) -> List[Job]:
return await enqueue_notify_events(
db,
event=DatasetEvent.deleted,
timestamp=datetime.utcnow(),
data={"id": dataset.id},
)
36 changes: 36 additions & 0 deletions argilla-server/src/argilla_server/webhooks/v1/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from datetime import datetime

from rq.job import Job
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.jobs.webhook_jobs import enqueue_notify_events


class Event:
def __init__(self, event: str, timestamp: datetime, data: dict):
self.event = event
self.timestamp = timestamp
self.data = data

async def notify(self, db: AsyncSession) -> List[Job]:
return await enqueue_notify_events(
db,
event=self.event,
timestamp=self.timestamp,
data=self.data,
)
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/webhooks/v1/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from datetime import datetime

from argilla_server.models import Webhook
from argilla_server.contexts import info
from argilla_server.models import Webhook
from argilla_server.webhooks.v1.commons import notify_event
from argilla_server.webhooks.v1.enums import WebhookEvent

Expand Down
26 changes: 10 additions & 16 deletions argilla-server/src/argilla_server/webhooks/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@

from rq.job import Job
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.models import Record, Dataset
from argilla_server.webhooks.v1.event import Event
from argilla_server.webhooks.v1.enums import RecordEvent
from argilla_server.webhooks.v1.schemas import RecordEventSchema
from argilla_server.jobs.webhook_jobs import enqueue_notify_events
from argilla_server.models import Record, Dataset


async def notify_record_event(db: AsyncSession, record_event: RecordEvent, record: Record) -> List[Job]:
if record_event == RecordEvent.deleted:
return await _notify_record_deleted_event(db, record)
event = await build_record_event(db, record_event, record)

return await event.notify(db)


async def build_record_event(db: AsyncSession, record_event: RecordEvent, record: Record) -> Event:
# NOTE: Force loading required association resources required by the event schema
(
await db.execute(
select(Dataset)
Expand All @@ -44,18 +48,8 @@ async def notify_record_event(db: AsyncSession, record_event: RecordEvent, recor
)
).scalar_one()

return await enqueue_notify_events(
db,
return Event(
event=record_event,
timestamp=datetime.utcnow(),
data=RecordEventSchema.from_orm(record).dict(),
)


async def _notify_record_deleted_event(db: AsyncSession, record: Record) -> List[Job]:
return await enqueue_notify_events(
db,
event=RecordEvent.deleted,
timestamp=datetime.utcnow(),
data={"id": record.id},
)
26 changes: 9 additions & 17 deletions argilla-server/src/argilla_server/webhooks/v1/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@
from typing import List
from datetime import datetime

from rq.job import Job
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from rq.job import Job
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.models import Response, Record, Dataset
from argilla_server.jobs.webhook_jobs import enqueue_notify_events
from argilla_server.webhooks.v1.schemas import ResponseEventSchema
from argilla_server.webhooks.v1.event import Event
from argilla_server.webhooks.v1.enums import ResponseEvent
from argilla_server.webhooks.v1.schemas import ResponseEventSchema


async def notify_response_event(db: AsyncSession, response_event: ResponseEvent, response: Response) -> List[Job]:
if response_event == ResponseEvent.deleted:
return await _notify_response_deleted_event(db, response)
event = await build_response_event(db, response_event, response)

return await event.notify(db)


async def build_response_event(db: AsyncSession, response_event: ResponseEvent, response: Response) -> Event:
# NOTE: Force loading required association resources required by the event schema
(
await db.execute(
Expand All @@ -51,18 +53,8 @@ async def notify_response_event(db: AsyncSession, response_event: ResponseEvent,
)
).scalar_one()

return await enqueue_notify_events(
db,
return Event(
event=response_event,
timestamp=datetime.utcnow(),
data=ResponseEventSchema.from_orm(response).dict(),
)


async def _notify_response_deleted_event(db: AsyncSession, response: Response) -> List[Job]:
return await enqueue_notify_events(
db,
event=ResponseEvent.deleted,
timestamp=datetime.utcnow(),
data={"id": response.id},
)
Loading