Skip to content

Commit

Permalink
fix export & add emojis (#1004)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf authored Jan 29, 2023
1 parent 29b540a commit eda275b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
21 changes: 15 additions & 6 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,34 +318,43 @@ def main():

parser.add_argument(
"--print-openapi-schema",
default=False,
help="Dumps the openapi schema to stdout",
action=argparse.BooleanOptionalAction,
action="store_true",
)
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
parser.add_argument("--port", help="The port to run the server", default=8080)
parser.add_argument(
"--export", help="Export all trees which are ready for exporting.", action=argparse.BooleanOptionalAction
"--export",
default=False,
help="Export all trees which are ready for exporting.",
action="store_true",
)
parser.add_argument(
"--export-file",
type=str,
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
)
parser.add_argument(
"--retry-scoring",
default=False,
help="Retry scoring failed message trees",
action=argparse.BooleanOptionalAction,
action="store_true",
)

args = parser.parse_args()

if args.print_openapi_schema:
print(get_openapi_schema())
elif args.export:

if args.export:
use_compression: bool = ".gz" in args.export_file
export_ready_trees(file=args.export_file, use_compression=use_compression)
elif args.retry_scoring:

if args.retry_scoring:
retry_scoring_failed_message_trees()
else:

if not (args.export or args.print_openapi_schema or args.retry_scoring):
uvicorn.run(app, host=args.host, port=args.port)


Expand Down
33 changes: 16 additions & 17 deletions backend/oasst_backend/utils/tree_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ class ExportMessageNode(BaseModel):
rank: int | None
synthetic: bool | None
model_name: str | None
emojis: dict[str, int] | None
replies: list[ExportMessageNode] | None

@classmethod
def prep_message_export(cls, message: Message) -> ExportMessageNode:
return cls(
@staticmethod
def prep_message_export(message: Message) -> ExportMessageNode:
return ExportMessageNode(
message_id=str(message.id),
parent_id=str(message.parent_id) if message.parent_id else None,
text=str(message.payload.payload.text),
Expand All @@ -33,6 +34,7 @@ def prep_message_export(cls, message: Message) -> ExportMessageNode:
review_count=message.review_count,
synthetic=message.synthetic,
model_name=message.model_name,
emojis=message.emojis,
rank=message.rank,
)

Expand All @@ -43,23 +45,20 @@ class ExportMessageTree(BaseModel):


def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
export_tree = ExportMessageTree(message_tree_id=str(message_tree_id))
export_tree_data = [ExportMessageNode.prep_message_export(m) for m in messages]
export_messages = [ExportMessageNode.prep_message_export(m) for m in messages]

message_parents = defaultdict(list)
for message in export_tree_data:
message_parents[message.parent_id].append(message)
messages_by_parent = defaultdict(list)
for message in export_messages:
messages_by_parent[message.parent_id].append(message)

def build_tree(tree: dict, parent: Optional[str], messages: list[Message]):
children = message_parents[parent]
tree.replies = children
def assign_replies(node: ExportMessageNode) -> ExportMessageNode:
node.replies = messages_by_parent[node.message_id]
for child in node.replies:
assign_replies(child)
return node

for idx, child in enumerate(tree.replies):
build_tree(tree.replies[idx], child.message_id, messages)

build_tree(export_tree, None, export_tree_data)

return export_tree
prompt = assign_replies(messages_by_parent[None][0])
return ExportMessageTree(message_tree_id=str(message_tree_id), prompt=prompt)


def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
Expand Down

0 comments on commit eda275b

Please sign in to comment.