diff --git a/examples/game/game_app.py b/examples/game/game_app.py index fc0fe2ebc..65a9d4974 100644 --- a/examples/game/game_app.py +++ b/examples/game/game_app.py @@ -11,6 +11,7 @@ from multiprocessing import Event import traceback from urllib import parse +from tempfile import TemporaryDirectory import agentscope import shutil from config_utils import load_configs @@ -31,7 +32,7 @@ send_riddle_input, get_quest_msg, ) -from oss_utils import upload_config_to_oss +from oss_utils import upload_config_to_oss, replace_model_in_yaml from create_config_tab import ( create_config_tab, create_config_accord, @@ -397,13 +398,30 @@ def build_game_zip(uid): uid = check_uuid(uid) directory_path = f'/tmp/as_game/config/{uid}' - file_path = f"/tmp/as_game/config/{uid}.zip" - if not os.path.exists(directory_path): - os.makedirs(directory_path) - - shutil.make_archive(file_path[:-4], 'zip', directory_path) - gr.Info("🎉打包成功!") + # 创建临时目录 + with TemporaryDirectory() as temp_directory: + # 遍历目录中的所有.yaml文件 + for root, dirs, files in os.walk(directory_path): + for file_name in files: + if file_name.endswith('.yaml'): + original_file_path = os.path.join(root, file_name) + # 在临时目录中创建修改后的文件 + replace_model_in_yaml(original_file_path, temp_directory) + + # 拷贝非YAML文件到临时目录 + for root, dirs, files in os.walk(directory_path): + for file_name in files: + if not file_name.endswith('.yaml'): + original_file_path = os.path.join(root, file_name) + temp_file_path = os.path.join(temp_directory, file_name) + shutil.copy2(original_file_path, temp_file_path) + + # 创建zip文件 + shutil.make_archive(f'/tmp/as_game/config/{uid}', 'zip', + temp_directory) + + print("🎉打包成功!") def update_publish_button(uid): diff --git a/examples/game/oss_utils.py b/examples/game/oss_utils.py index 8cc043039..e3535d6b8 100644 --- a/examples/game/oss_utils.py +++ b/examples/game/oss_utils.py @@ -1,5 +1,6 @@ import os import oss2 +import yaml def upload_to_oss(bucket, local_file_path, oss_file_path): @@ -35,5 +36,27 @@ def get_oss_config(): return access_key_id, access_key_secret, endpoint, bucket_name +def replace_model_in_yaml(original_file_path, temp_directory, + old_str="post_api", new_str="tongyi_model"): + with open(original_file_path, 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + + # 修改数据 + if isinstance(data, list): + for item in data: + if 'model' in item and item['model'] == old_str: + item['model'] = new_str + elif isinstance(data, dict): + if 'model' in data and data['model'] == old_str: + data['model'] = new_str + + # 写入临时文件 + temp_file_path = os.path.join(temp_directory, + os.path.basename(original_file_path)) + with open(temp_file_path, 'w', encoding='utf-8') as temp_file: + yaml.safe_dump(data, temp_file, default_flow_style=False, + allow_unicode=True) + + if __name__ == '__main__': upload_config_to_oss("local_user")