diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 34672e2ed6..225b0146d1 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -9,7 +9,7 @@ from loguru import logger from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list from oasst_backend.config import TreeManagerConfiguration, settings -from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state +from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state from oasst_backend.prompt_repository import PromptRepository from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI @@ -840,15 +840,13 @@ def query_num_active_trees(self) -> int: return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: - sql_qry = """ -SELECT tl.* -FROM task t - INNER JOIN text_labels tl ON tl.id = t.id -WHERE t.done = TRUE - AND tl.message_id = :message_id -""" - r = self.db.execute(text(sql_qry), {"message_id": message_id}) - return [TextLabels.from_orm(x) for x in r.all()] + qry = ( + self.db.query(TextLabels) + .select_from(Task) + .join(TextLabels, Task.id == TextLabels.id) + .filter(Task.done, TextLabels.message_id == message_id) + ) + return qry.all() @managed_tx_method(CommitMode.FLUSH) def _insert_tree_state( @@ -911,7 +909,12 @@ def _insert_default_state( # print("query_extendible_parents", tm.query_extendible_parents()) # print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292"))) - print("next_task:", tm.next_task()) + print( + "query_reviews_for_message", + tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")), + ) + + # print("next_task:", tm.next_task()) # print( # "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))