Skip to content

Commit

Permalink
add avatar (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijianma authored Jan 19, 2024
1 parent 1b1bef9 commit 043e1a8
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 22 deletions.
13 changes: 7 additions & 6 deletions examples/game/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
CheckpointArgs,
enable_web_ui,
send_chat_msg,
send_player_msg,
send_player_input,
get_chat_msg,
get_suggests,
ResetException,
)

import gradio as gr
from gradio_groupchat import GroupChat
import modelscope_gradio_components as mgr

enable_web_ui()

Expand Down Expand Up @@ -105,7 +106,7 @@ def init_game():
is_init = True

def check_for_new_session(uid):
print(uid)
# print(uid)
if uid not in glb_signed_user:
glb_signed_user.append(uid)
game_thread = threading.Thread(target=start_game, args=(uid,))
Expand Down Expand Up @@ -136,7 +137,7 @@ def start_game(uid):
}

user_chat_bot_cover = gr.HTML(format_cover_html(welcome))
chatbot = GroupChat(label="Dialog", show_label=False, height=600, visible=False)
chatbot = mgr.Chatbot(label="Dialog", show_label=False, height=600, visible=False)

with gr.Row():
with gr.Column():
Expand Down Expand Up @@ -185,7 +186,7 @@ def start_game(uid):

def send_message(msg, uid):
send_player_input(msg, uid=uid)
send_chat_msg(msg, "你", uid=uid)
send_player_msg(msg, "你", uid=uid)
return ""

return_welcome_button = gr.Button(
Expand Down Expand Up @@ -219,7 +220,7 @@ def update_suggest(uid):
def game_ui():
visible = True
invisible = False
return {chatbot:GroupChat(visible=visible),
return {chatbot: mgr.Chatbot(visible=visible),
user_chat_input: gr.Text(visible=visible),
send_button: gr.Button(visible=visible),
new_button: gr.Button(visible=invisible),
Expand All @@ -235,7 +236,7 @@ def game_ui():
def welcome_ui():
visible = True
invisible = False
return {chatbot:GroupChat(visible=invisible),
return {chatbot: mgr.Chatbot(visible=invisible),
user_chat_input: gr.Text(visible=invisible),
send_button: gr.Button(visible=invisible),
new_button: gr.Button(visible=visible),
Expand Down
Binary file added examples/game/assets/avatar_abing.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/game/assets/avatar_fan.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/game/assets/avatar_laoxu.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/game/assets/avatar_wang.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/game/assets/bot.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/game/assets/user.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions examples/game/config/customer_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"name": "王老板"
"model": "tongyi_model"
"use_memory": true
"avatar": "assets/avatar_wang.png"
"character_setting":
"food_preference": >
1.王老板是本地人。
Expand Down Expand Up @@ -29,6 +30,7 @@
"name": "阿炳"
"model": "tongyi_model"
"use_memory": true
"avatar": "assets/avatar_abing.png"
"character_setting":
"food_preference": >
1.阿炳也是本地人
Expand Down Expand Up @@ -56,6 +58,7 @@
"name": "老许"
"model": "tongyi_model"
"use_memory": true
"avatar": "assets/avatar_laoxu.png"
"character_setting":
"food_preference": >
老许最喜欢的是羊肉,比如烤羊肉串和羊肉烧卖;
Expand All @@ -82,6 +85,7 @@
"name": "范老师"
"model": "tongyi_model"
"use_memory": true
"avatar": "assets/avatar_fan.png"
"character_setting":
"food_preference": >
1. 范老师是山东人
Expand Down
4 changes: 2 additions & 2 deletions examples/game/customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from loguru import logger

from enums import CustomerConv, CustomerPlot
from utils import send_chat_msg
from utils import send_chat_msg, get_a_random_avatar
from agentscope.agents import StateAgent, DialogAgent
from agentscope.message import Msg


HISTORY_WINDOW = 10
# TODO: for debug, set the score bars to be lower
MIN_BAR_RECEIVED_CONST = 4
Expand All @@ -23,6 +22,7 @@ def __init__(self, game_config: dict, **kwargs: Any):
self.game_config = game_config
self.max_itr_preorder = 5
self.preorder_itr_count = 0
self.avatar = self.config.get('avatar', get_a_random_avatar())
self.background = self.config["character_setting"]["background"]
self.friendship = int(self.config.get("friendship", 60))

Expand Down
8 changes: 5 additions & 3 deletions examples/game/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def invited_group_chat(
continue
for c in invited_customer:
msg = c(msg)
send_pretty_msg(msg, uid=uid)
send_pretty_msg(msg, uid=uid,avatar=c.avatar)
end_query_answer(uid=uid)

invited_names.sort()
Expand Down Expand Up @@ -146,7 +146,8 @@ def one_on_one_loop(customers, player, uid):
uid=uid,
)
break
send_pretty_msg(msg, uid=uid)

send_pretty_msg(msg, uid=uid, avatar= customer.avatar)
send_chat_msg(
"【系统】请输入“做菜”启动做菜程序,它会按所选定食材产生菜品。 \n"
"【系统】对话轮次过多会使得顾客综合满意度下降。 \n"
Expand Down Expand Up @@ -182,7 +183,8 @@ def one_on_one_loop(customers, player, uid):
while True:
msg = customer(msg)
# print(f"{customer_reply.name}(顾客):" + customer_reply.content)
send_pretty_msg(msg, uid=uid)

send_pretty_msg(msg, uid=uid,avatar=customer.avatar)
send_chat_msg("【系统】若不输入任何内容直接按回车键,顾客将离开餐馆。", uid=uid)
msg = player(msg)
if len(msg["content"]) == 0:
Expand Down
37 changes: 26 additions & 11 deletions examples/game/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from datetime import datetime
from colorist import BgBrightColor
import inquirer
import random
from multiprocessing import Queue
from collections import defaultdict
from dataclasses import dataclass
from agentscope.message import Msg
from enums import StagePerNight
from pathlib import Path
from queue import Empty

USE_WEB_UI = False
Expand Down Expand Up @@ -52,6 +54,12 @@ def load_game_checkpoint(checkpoint_path: str) -> GameCheckpoint:
def speak_print(m: Msg):
print(f"{BgBrightColor.BLUE}{m.name}{BgBrightColor.OFF}: {m.content}")

def get_avatar_files(assets_path='assets'):
files = Path(assets_path).glob('*avatar*')
return [str(file) for file in files]

def get_a_random_avatar():
return random.choices(get_avatar_files())

def check_active_plot(
plots: list[dict],
Expand Down Expand Up @@ -143,17 +151,29 @@ def init_uid_queues():
"glb_queue_chat_suggests": Queue(),
}


glb_uid_dict = defaultdict(init_uid_queues)


def send_chat_msg(msg, role="系统", uid=None):
def send_chat_msg(msg, role="系统", uid=None, flushing=False, avatar='./assets/bot.jpg'):
print(msg)
if get_use_web_ui():
global glb_uid_dict
glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"]
glb_queue_chat_msg.put([role, msg])
glb_queue_chat_msg.put([None,
{"text": msg,
"flushing": flushing,
"avatar": avatar
}])

def send_player_msg(msg, role="你", uid= None, flushing=False, avatar='./assets/user.jpg'):
print(msg)
if get_use_web_ui():
global glb_uid_dict
glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"]
glb_queue_chat_msg.put([
{"text": msg,
"flushing": flushing,
"avatar": avatar
},None])

def get_chat_msg(uid=None):
global glb_uid_dict
Expand All @@ -164,21 +184,17 @@ def get_chat_msg(uid=None):
return line
return None


def send_player_input(msg, role="餐厅老板", uid=None):
if get_use_web_ui():
global glb_uid_dict
glb_queue_chat_input = glb_uid_dict[uid]["glb_queue_chat_input"]
glb_queue_chat_input.put([role, msg])


def send_pretty_msg(msg, uid=None):
def send_pretty_msg(msg, uid=None,flushing=True, avatar='./assets/bot.jpg'):
speak_print(msg)
if get_use_web_ui():
global glb_uid_dict
glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"]
glb_queue_chat_msg.put([msg.name, msg.content])

send_chat_msg(msg.content, uid = uid, role=msg.name, flushing=flushing, avatar=avatar)

def get_player_input(name=None, uid=None):
global glb_uid_dict
Expand Down Expand Up @@ -250,7 +266,6 @@ def query_answer(questions: List, key="ans", uid=None):
suggests_msg = (
suggests.message + "\n" + format_choices(suggests.choices)
)
print("suggests=", suggests)
samples = [[choice] for choice in suggests.choices]
msg = suggests.message
send_chat_msg(suggests_msg, uid=uid)
Expand Down

0 comments on commit 043e1a8

Please sign in to comment.