Skip to content

Commit

Permalink
use sqlalchemy for query_reviews_for_message() in TreeManager
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Jan 16, 2023
1 parent 6ccbd38 commit 2d4e39c
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
Expand Down Expand Up @@ -840,15 +840,13 @@ def query_num_active_trees(self) -> int:
return query.scalar()

def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
sql_qry = """
SELECT tl.*
FROM task t
INNER JOIN text_labels tl ON tl.id = t.id
WHERE t.done = TRUE
AND tl.message_id = :message_id
"""
r = self.db.execute(text(sql_qry), {"message_id": message_id})
return [TextLabels.from_orm(x) for x in r.all()]
qry = (
self.db.query(TextLabels)
.select_from(Task)
.join(TextLabels, Task.id == TextLabels.id)
.filter(Task.done, TextLabels.message_id == message_id)
)
return qry.all()

@managed_tx_method(CommitMode.FLUSH)
def _insert_tree_state(
Expand Down Expand Up @@ -911,7 +909,12 @@ def _insert_default_state(
# print("query_extendible_parents", tm.query_extendible_parents())
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))

print("next_task:", tm.next_task())
print(
"query_reviews_for_message",
tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
)

# print("next_task:", tm.next_task())

# print(
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
Expand Down

0 comments on commit 2d4e39c

Please sign in to comment.