Skip to content

Commit

Permalink
improve format according to pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
DavdGao committed Feb 5, 2024
1 parent 7a0642f commit 7824172
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
5 changes: 3 additions & 2 deletions src/agentscope/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> dict:
return res

def speak(
self,
content: Union[str, dict]) -> None:
self,
content: Union[str, dict],
) -> None:
"""Speak out the content generated by the agent."""
logger.chat(content)

Expand Down
1 change: 0 additions & 1 deletion src/agentscope/agents/dialog_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
"""A general dialog agent."""
from typing import Any, Optional, Union, Callable
from loguru import logger

from ..message import Msg
from .agent import AgentBase
Expand Down
13 changes: 9 additions & 4 deletions src/agentscope/utils/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,17 @@ def _chat(message: Union[str, dict], *args: Any, **kwargs: Any) -> None:

print_str = []
if contain_content:
print_str.append(f"{m1}<b>{speaker}</b>{m2}: {message['content']}")
print_str.append(
f"{m1}<b>{speaker}</b>{m2}: {message['content']}",
)

if contain_url:
print_str.append(f"{m1}<b>{speaker}</b>{m2}: {message['url']}")

if len(print_str) > 0:
print_str = "\n".join(print_str).replace("{", "{{").replace("}", "}}")
print_str = (
"\n".join(print_str).replace("{", "{{").replace("}", "}}")
)
logger.log(LEVEL_CHAT_LOG, print_str, *args, **kwargs)
return

Expand Down Expand Up @@ -164,7 +168,8 @@ def setup_logger(
filter=lambda record: record["level"].name != LEVEL_CHAT_SAVE,
format=_level_format,
enqueue=True,
level=level)
level=level,
)

if path_log is not None:
if not os.path.exists(path_log):
Expand All @@ -189,5 +194,5 @@ def setup_logger(
filter=lambda record: record["level"].name == LEVEL_CHAT_SAVE,
format="{message}",
enqueue=True,
level=LEVEL_CHAT_SAVE
level=LEVEL_CHAT_SAVE,
)
16 changes: 11 additions & 5 deletions tests/logger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@ def test_logger_chat(self) -> None:
logger.chat("Test\nChat\n\nMessage\n\n")

# dict with "\n"
logger.chat({"name": "Alice", "content": "Hi!\n", "url":
"https://xxx.png"})
logger.chat(
{
"name": "Alice",
"content": "Hi!\n",
"url": "https://xxx.png",
},
)

# dict without content
logger.chat({"name": "Alice", "url": "https://xxx.png"})
Expand All @@ -40,14 +45,15 @@ def test_logger_chat(self) -> None:

ground_truth = [
'"Test\\nChat\\n\\nMessage\\n\\n"\n',
'{"name": "Alice", "content": "Hi!\\n", "url": "https://xxx.png"}\n',
'{"name": "Alice", "content": "Hi!\\n", "url": "https://xxx.png'
'"}\n',
'{"name": "Alice", "url": "https://xxx.png"}\n',
'{"abc": 1}\n'
'{"abc": 1}\n',
]

self.assertListEqual(lines, ground_truth)

def tearDown(self):
def tearDown(self) -> None:
"""Tear down for LoggerTest."""
shutil.rmtree("./runs/")

Expand Down

0 comments on commit 7824172

Please sign in to comment.