From 390920c93b1dc77e2d73e02c2111fbb342c22089 Mon Sep 17 00:00:00 2001 From: DavdGao Date: Mon, 5 Feb 2024 17:47:17 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90WEBUI=E3=80=91Add=20speak=20function?= =?UTF-8?q?=20within=20AgentBase.=20(#27)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agentscope/agents/agent.py | 7 ++ src/agentscope/agents/dialog_agent.py | 7 +- src/agentscope/agents/dict_dialog_agent.py | 4 +- src/agentscope/utils/logging_utils.py | 108 +++++++++------------ tests/logger_test.py | 71 ++++++++++++++ 5 files changed, 130 insertions(+), 67 deletions(-) create mode 100644 tests/logger_test.py diff --git a/src/agentscope/agents/agent.py b/src/agentscope/agents/agent.py index 18cca85b0..436f056f6 100644 --- a/src/agentscope/agents/agent.py +++ b/src/agentscope/agents/agent.py @@ -132,6 +132,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> dict: return res + def speak( + self, + content: Union[str, dict], + ) -> None: + """Speak out the content generated by the agent.""" + logger.chat(content) + def observe(self, x: Union[dict, Sequence[dict]]) -> None: """Observe the input, store it in memory without response to it. diff --git a/src/agentscope/agents/dialog_agent.py b/src/agentscope/agents/dialog_agent.py index 1f749c289..b57ca4ce2 100644 --- a/src/agentscope/agents/dialog_agent.py +++ b/src/agentscope/agents/dialog_agent.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- """A general dialog agent.""" from typing import Any, Optional, Union, Callable -from loguru import logger from ..message import Msg from .agent import AgentBase @@ -89,8 +88,10 @@ def reply(self, x: dict = None) -> dict: response = self.model(prompt) msg = Msg(self.name, response) - # logging and record the message in memory - logger.chat(msg) + # Print/speak the message in this agent's voice + self.speak(msg) + + # Record the message in memory self.memory.add(msg) return msg diff --git a/src/agentscope/agents/dict_dialog_agent.py b/src/agentscope/agents/dict_dialog_agent.py index a6e66d9fe..d302e7823 100644 --- a/src/agentscope/agents/dict_dialog_agent.py +++ b/src/agentscope/agents/dict_dialog_agent.py @@ -148,8 +148,8 @@ def reply(self, x: dict = None) -> dict: else: msg = Msg(self.name, response) - # logging the message - logger.chat(msg) + # Print/speak the message in this agent's voice + self.speak(msg) # record to memory self.memory.add(msg) diff --git a/src/agentscope/utils/logging_utils.py b/src/agentscope/utils/logging_utils.py index 47813b898..295ce1c0b 100644 --- a/src/agentscope/utils/logging_utils.py +++ b/src/agentscope/utils/logging_utils.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- """Logging utilities.""" +import json import os import sys from typing import Optional, Literal, Union, Any from loguru import logger -from agentscope.constants import MSG_TOKEN - LOG_LEVEL = Literal[ "TRACE", "DEBUG", @@ -18,6 +17,9 @@ "CRITICAL", ] +LEVEL_CHAT_LOG = "CHAT_LOG" +LEVEL_CHAT_SAVE = "CHAT_SAVE" + class _Stream: """Redirect stderr to logging""" @@ -85,72 +87,44 @@ def _chat(message: Union[str, dict], *args: Any, **kwargs: Any) -> None: "content" keys, and the message will be logged as ": ". """ + # Save message into file + logger.log(LEVEL_CHAT_SAVE, json.dumps(message), *args, **kwargs) + + # Print message in terminal with specific format if isinstance(message, dict): contain_name_or_role = "name" in message or "role" in message contain_content = "content" in message contain_url = "url" in message # print content if contain name or role and contain content - if contain_name_or_role and contain_content: + if contain_name_or_role: speaker = message.get("name", None) or message.get("role", None) - content = message["content"] (m1, m2) = _get_speaker_color(speaker) - logger.log( - "CHAT", - f"{m1}{speaker}{m2}: {content}".replace( - "{", - "{{", - ).replace("}", "}}"), - *args, - **kwargs, - ) + + print_str = [] + if contain_content: + print_str.append( + f"{m1}{speaker}{m2}: {message['content']}", + ) if contain_url: - # print url if contain name or role and contain url - url = message["url"] - (m1, m2) = _get_speaker_color(speaker) - # print url one by one if url is a list - if isinstance(url, list): - for each_url in url: - logger.log( - "CHAT", - f"{m1}{speaker}{m2}: {each_url}", - *args, - **kwargs, - ) - else: - logger.log( - "CHAT", - f"{m1}{speaker}{m2}: {url}", - *args, - **kwargs, - ) - - # print raw message if not contain name - if not contain_name_or_role or not contain_content: - logger.log("CHAT", str(message), *args, **kwargs) - else: - # print other types of message directly - logger.log("CHAT", message, *args, **kwargs) + print_str.append(f"{m1}{speaker}{m2}: {message['url']}") + if len(print_str) > 0: + print_str = ( + "\n".join(print_str).replace("{", "{{").replace("}", "}}") + ) + logger.log(LEVEL_CHAT_LOG, print_str, *args, **kwargs) + return -def _level_format(record: dict) -> str: - """Format the log record.""" - if record["level"].name == "CHAT": - return record["message"] + "\n" - else: - return ( - "{time:YYYY-MM-DD HH:mm:ss.SSS} | {" - "level: <8} | {name}:{" - "function}:{line} - {" - "message}\n" - ) + message = str(message).replace("{", "{{").replace("}", "}}") + logger.log(LEVEL_CHAT_LOG, message, *args, **kwargs) -def _level_format_with_special_tokens(record: dict) -> str: +def _level_format(record: dict) -> str: """Format the log record.""" - if record["level"].name == "CHAT": - return MSG_TOKEN + record["message"] + MSG_TOKEN + "\n" + if record["level"].name == LEVEL_CHAT_LOG: + return record["message"] + "\n" else: return ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | {" @@ -182,33 +156,43 @@ def setup_logger( sys.stderr = _Stream() # add chat function for logger - logger.level("CHAT", no=21, color="") + logger.level(LEVEL_CHAT_LOG, no=21) + logger.level(LEVEL_CHAT_SAVE, no=0) logger.chat = _chat # set logging level logger.remove() # standard output for all logging except chat - logger.add(sys.stdout, format=_level_format, enqueue=True, level=level) + logger.add( + sys.stdout, + filter=lambda record: record["level"].name != LEVEL_CHAT_SAVE, + format=_level_format, + enqueue=True, + level=level, + ) if path_log is not None: if not os.path.exists(path_log): os.makedirs(path_log) - path_log_file = os.path.join(path_log, "all.log") - path_log_file_only_chat = os.path.join( + path_log_file = os.path.join(path_log, "logging.log") + path_chat_file = os.path.join( path_log, - "chat.log", + "logging.chat", ) # save all logging into file logger.add( path_log_file, - format=_level_format_with_special_tokens, + filter=lambda record: record["level"].name != LEVEL_CHAT_SAVE, + format=_level_format, enqueue=True, level=level, ) + logger.add( - path_log_file_only_chat, - format=_level_format, + path_chat_file, + filter=lambda record: record["level"].name == LEVEL_CHAT_SAVE, + format="{message}", enqueue=True, - level="CHAT", + level=LEVEL_CHAT_SAVE, ) diff --git a/tests/logger_test.py b/tests/logger_test.py new file mode 100644 index 000000000..0db963753 --- /dev/null +++ b/tests/logger_test.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" Unit test for logger chat""" +import os +import shutil +import time +import unittest + +from loguru import logger + +from agentscope.utils import setup_logger + + +class LoggerTest(unittest.TestCase): + """ + Unit test for logger. + """ + + def setUp(self) -> None: + """Setup for unit test.""" + self.run_dir = "./logger_runs/" + + def test_logger_chat(self) -> None: + """Logger chat.""" + + setup_logger(self.run_dir, level="INFO") + + # str with "\n" + logger.chat("Test\nChat\n\nMessage\n\n") + + # dict with "\n" + logger.chat( + { + "name": "Alice", + "content": "Hi!\n", + "url": "https://xxx.png", + }, + ) + + # dict without content + logger.chat({"name": "Alice", "url": "https://xxx.png"}) + + # dict + logger.chat({"abc": 1}) + + # To avoid that logging is not finished before the file is read + time.sleep(3) + + with open( + os.path.join(self.run_dir, "logging.chat"), + "r", + encoding="utf-8", + ) as file: + lines = file.readlines() + + ground_truth = [ + '"Test\\nChat\\n\\nMessage\\n\\n"\n', + '{"name": "Alice", "content": "Hi!\\n", "url": "https://xxx.png' + '"}\n', + '{"name": "Alice", "url": "https://xxx.png"}\n', + '{"abc": 1}\n', + ] + + self.assertListEqual(lines, ground_truth) + + def tearDown(self) -> None: + """Tear down for LoggerTest.""" + shutil.rmtree(self.run_dir) + + +if __name__ == "__main__": + unittest.main()