diff --git a/backend/main.py b/backend/main.py index 8e30b78e14..dacd8097c1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -18,7 +18,8 @@ from oasst_backend.config import settings from oasst_backend.database import engine from oasst_backend.models import message_tree_state -from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository +from oasst_backend.prompt_repository import PromptRepository, UserRepository +from oasst_backend.task_repository import TaskRepository, delete_expired_tasks from oasst_backend.tree_manager import TreeManager from oasst_backend.user_repository import User from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame @@ -318,6 +319,13 @@ def update_user_streak(session: Session) -> None: return +@app.on_event("startup") +@repeat_every(seconds=60 * 60) # 1 hour +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def cronjob_delete_expired_tasks(session: Session) -> None: + delete_expired_tasks(session) + + app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index f1685221ca..1d2607642f 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -25,12 +25,29 @@ class TreeManagerConfiguration(BaseModel): goal_tree_size: int = 12 """Total number of messages to gather per tree.""" + random_goal_tree_size: bool = False + """If set to true goal tree sizes will be generated randomly within range [min_goal_tree_size, goal_tree_size].""" + + min_goal_tree_size: int = 5 + """Minimum tree size for random goal sizes.""" + num_reviews_initial_prompt: int = 3 """Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state.""" num_reviews_reply: int = 3 """Number of peer review checks to collect per reply (other than initial_prompt).""" + auto_mod_enabled: bool = True + """Flag to enable/disable auto moderation.""" + + auto_mod_max_skip_reply: int = 25 + """Automatically set tree state to `halted_by_moderator` when more than the specified number + of users skip replying to a message. (auto moderation)""" + + auto_mod_red_flags: int = 3 + """Delete messages that receive more than this number of red flags if it is a reply or + set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)""" + p_full_labeling_review_prompt: float = 1.0 """Probability of full text-labeling (instead of mandatory only) for initial prompts.""" @@ -222,6 +239,8 @@ def validate_user_stats_intervals(cls, v: int): RATE_LIMIT_TASK_API_TIMES: int = 10_000 RATE_LIMIT_TASK_API_MINUTES: int = 1 + TASK_VALIDITY_MINUTES: int = 60 * 24 * 2 # tasks expire after 2 days + class Config: env_file = ".env" env_file_encoding = "utf-8" diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index dacd5f9b0a..cb1dd2e291 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -155,8 +155,6 @@ def insert_message( review_result=review_result, ) self.db.add(message) - - # self.db.refresh(message) return message def _validate_task( @@ -288,6 +286,10 @@ def store_text_reply( task.done = True self.db.add(task) self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text)) + logger.debug( + f"Inserted message id={user_message.id}, tree={user_message.message_tree_id}, user_id={user_message.user_id}, " + f"text[:100]='{user_message.text[:100]}', role='{user_message.role}', lang='{user_message.lang}'" + ) return user_message @managed_tx_method(CommitMode.FLUSH) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index 5fe84b24d5..7748840a43 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -1,16 +1,18 @@ -from datetime import timedelta +from datetime import datetime, timedelta from typing import Optional from uuid import UUID import oasst_backend.models.db_payload as db_payload from loguru import logger +from oasst_backend.config import settings from oasst_backend.models import ApiClient, Task from oasst_backend.models.payload_column_type import PayloadContainer from oasst_backend.user_repository import UserRepository from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session, func, or_ +from oasst_shared.utils import utcnow +from sqlmodel import Session, delete, func, or_ from starlette.status import HTTP_404_NOT_FOUND @@ -24,6 +26,13 @@ def validate_frontend_message_id(message_id: str) -> None: raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) +def delete_expired_tasks(session: Session) -> int: + stm = delete(Task).where(Task.expiry_date < utcnow()) + result = session.exec(stm) + logger.info(f"Deleted {result.rowcount} expired tasks.") + return result.rowcount + + class TaskRepository: def __init__( self, @@ -118,12 +127,18 @@ def store_task( case _: raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) + if not collective and settings.TASK_VALIDITY_MINUTES > 0: + expiry_date = utcnow() + timedelta(minutes=settings.TASK_VALIDITY_MINUTES) + else: + expiry_date = None + task_model = self.insert_task( payload=payload, id=task.id, message_tree_id=message_tree_id, parent_message_id=parent_message_id, collective=collective, + expiry_date=expiry_date, ) assert task_model.id == task.id return task_model @@ -175,6 +190,7 @@ def insert_task( message_tree_id: UUID = None, parent_message_id: UUID = None, collective: bool = False, + expiry_date: datetime = None, ) -> Task: c = PayloadContainer(payload=payload) task = Task( @@ -186,6 +202,7 @@ def insert_task( message_tree_id=message_tree_id, parent_message_id=parent_message_id, collective=collective, + expiry_date=expiry_date, ) logger.debug(f"inserting {task=}") self.db.add(task) @@ -218,3 +235,6 @@ def fetch_recent_reply_tasks( if limit: qry = qry.limit(limit) return qry.all() + + def delete_expired_tasks(self) -> int: + return delete_expired_tasks(self.db) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index f24292a498..7d5a37d9ea 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -9,6 +9,7 @@ import numpy as np import pydantic +import sqlalchemy as sa from fastapi.encoders import jsonable_encoder from loguru import logger from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list @@ -31,6 +32,7 @@ from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.utils import utcnow +from sqlalchemy.sql.functions import coalesce from sqlmodel import Session, and_, func, not_, or_, text, update @@ -269,6 +271,31 @@ def _prompt_lottery(self, lang: str) -> int: self._enter_state(mts, message_tree_state.State.GROWING) self.db.flush() + def _auto_moderation(self, lang: str) -> None: + if not self.cfg.auto_mod_enabled: + return + + bad_messages = self.query_moderation_bad_messages(lang=lang) + for m in bad_messages: + num_red_flag = m.emojis.get(protocol_schema.EmojiCode.red_flag) + + if num_red_flag is not None and num_red_flag >= self.cfg.auto_mod_red_flags: + if m.parent_id is None: + logger.warning( + f"[AUTO MOD] Halting tree {m.message_tree_id}, inital prompt got too many red flags ({m.emojis})." + ) + self.enter_low_grade_state(m.message_tree_id) + else: + logger.warning(f"[AUTO MOD] Deleting message {m.id=}, it received too many red flags ({m.emojis}).") + self.pr.mark_messages_deleted(m.id, recursive=True) + + num_skip_reply = m.emojis.get(protocol_schema.EmojiCode.skip_reply) + if num_skip_reply is not None and num_skip_reply >= self.cfg.auto_mod_max_skip_reply: + logger.warning( + f"[AUTO MOD] Halting tree {m.message_tree_id} due to high skip-reply count of message {m.id=} ({m.emojis})." + ) + self.halt_tree(m.id, halt=True) + def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]: self.pr.ensure_user_is_enabled() @@ -276,6 +303,7 @@ def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskReq lang = "en" logger.warning("Task availability request without lang tag received, assuming lang='en'.") + self._auto_moderation(lang=lang) num_missing_prompts = self._prompt_lottery(lang=lang) extendible_parents, _ = self.query_extendible_parents(lang=lang) prompts_need_review = self.query_prompts_need_review(lang=lang) @@ -313,6 +341,7 @@ def next_task( lang = "en" logger.warning("Task request without lang tag received, assuming 'en'.") + self._auto_moderation(lang=lang) num_missing_prompts = self._prompt_lottery(lang=lang) prompts_need_review = self.query_prompts_need_review(lang=lang) @@ -1254,6 +1283,37 @@ def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: ) return qry.all() + def query_moderation_bad_messages(self, lang: str) -> list[Message]: + qry = ( + self.db.query(Message) + .select_from(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .filter( + MessageTreeState.active, + or_( + MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW, + MessageTreeState.state == message_tree_state.State.GROWING, + ), + or_( + Message.parent_id.is_(None), + Message.review_result, + and_(Message.parent_id.is_not(None), Message.review_count < self.cfg.num_reviews_reply), + ), + not_(Message.deleted), + or_( + coalesce(Message.emojis[protocol_schema.EmojiCode.red_flag].cast(sa.Integer), 0) + >= self.cfg.auto_mod_red_flags, + coalesce(Message.emojis[protocol_schema.EmojiCode.skip_reply].cast(sa.Integer), 0) + >= self.cfg.auto_mod_max_skip_reply, + ), + ) + ) + + if lang is not None: + qry = qry.filter(Message.lang == lang) + + return qry.all() + @managed_tx_method(CommitMode.FLUSH) def _insert_tree_state( self, @@ -1281,10 +1341,17 @@ def _insert_default_state( self, root_message_id: UUID, state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW, + *, + goal_tree_size: int = None, ) -> MessageTreeState: + if goal_tree_size is None: + if self.cfg.random_goal_tree_size and self.cfg.min_goal_tree_size < self.cfg.goal_tree_size: + goal_tree_size = random.randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size) + else: + goal_tree_size = self.cfg.goal_tree_size return self._insert_tree_state( root_message_id=root_message_id, - goal_tree_size=self.cfg.goal_tree_size, + goal_tree_size=goal_tree_size, max_depth=self.cfg.max_tree_depth, max_children_count=self.cfg.max_children_count, state=state, @@ -1379,9 +1446,32 @@ def _purge_message_internal(self, message_id: UUID) -> None: DELETE FROM task t WHERE t.parent_message_id = :message_id; DELETE FROM message WHERE id = :message_id; """ + parent_id = self.pr.fetch_message(message_id=message_id).parent_id r = self.db.execute(text(sql_purge_message), {"message_id": message_id}) logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.") + sql_update_ranking_counts = """ +WITH r AS ( + -- find ranking results and count per child + SELECT c.id, + count(*) FILTER ( + WHERE mr.payload#>'{payload, ranked_message_ids}' ? CAST(c.id AS varchar) + ) AS ranking_count + FROM message c + LEFT JOIN message_reaction mr ON mr.payload_type = 'RankingReactionPayload' + AND mr.message_id = c.parent_id + WHERE c.parent_id = :parent_id + GROUP BY c.id +) +UPDATE message m SET ranking_count = r.ranking_count +FROM r WHERE m.id = r.id AND m.ranking_count != r.ranking_count; +""" + + if parent_id is not None: + # update ranking counts of remaining children + r = self.db.execute(text(sql_update_ranking_counts), {"parent_id": parent_id}) + logger.debug(f"ranking_count updated for {r.rowcount} rows.") + def purge_message_tree(self, message_tree_id: UUID) -> None: sql_purge_message_tree = """ DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;