diff --git a/.github/workflows/sphinx_docs.yml b/.github/workflows/sphinx_docs.yml index 90a519c53..86750643b 100644 --- a/.github/workflows/sphinx_docs.yml +++ b/.github/workflows/sphinx_docs.yml @@ -11,33 +11,36 @@ on: jobs: pages: - runs-on: ubuntu-latest timeout-minutes: 20 + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ['3.9'] env: OS: ${{ matrix.os }} PYTHON: '3.9' steps: - - name: Checkout repository - uses: actions/checkout@master + - uses: actions/checkout@master - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@master with: - python_version: ${{ matrix.python-version }} - - id: deployment + python-version: ${{ matrix.python-version }} + - name: Install Dependencies + run: | + pip install -q -e .[full] + - id: build name: Build Documentation - uses: sphinx-notes/pages@v3 - with: - documentation_path: ./docs/sphinx_doc/source - python_version: ${{ matrix.python-version }} - publish: false - requirements_path: ./docs/sphinx_doc/requirements.txt + run: | + cd docs/sphinx_doc + make clean html - name: Upload Documentation uses: actions/upload-artifact@v3 with: name: SphinxDoc - path: ${{ steps.deployment.outputs.artifact }} + path: 'docs/sphinx_doc/build' - uses: peaceiris/actions-gh-pages@v3 - if: github.event_name == 'push' && github_ref == 'refs/heads/main' + if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} with: github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ${{ steps.deployment.outputs.artifact }} \ No newline at end of file + publish_dir: 'docs/sphinx_doc/build/html' \ No newline at end of file diff --git a/docs/sphinx_doc/source/agentscope.agents.rst b/docs/sphinx_doc/source/agentscope.agents.rst index d8c576d1a..0b6520701 100644 --- a/docs/sphinx_doc/source/agentscope.agents.rst +++ b/docs/sphinx_doc/source/agentscope.agents.rst @@ -1,6 +1,14 @@ Agents package ========================== +operator module +------------------------------- + +.. automodule:: agentscope.agents.operator + :members: + :undoc-members: + :show-inheritance: + agent module ------------------------------- @@ -13,6 +21,38 @@ rpc_agent module ------------------------------- .. automodule:: agentscope.agents.rpc_agent + :members: + :undoc-members: + :show-inheritance: + +user_agent module +------------------------------- + +.. automodule:: agentscope.agents.user_agent + :members: + :undoc-members: + :show-inheritance: + +dialog_agent module +------------------------------- + +.. automodule:: agentscope.agents.dialog_agent + :members: + :undoc-members: + :show-inheritance: + +dict_dialog_agent module +------------------------------- + +.. automodule:: agentscope.agents.dict_dialog_agent + :members: + :undoc-members: + :show-inheritance: + +rpc_dialog_agent module +------------------------------- + +.. automodule:: agentscope.agents.dict_dialog_agent :members: :undoc-members: :show-inheritance: \ No newline at end of file diff --git a/src/agentscope/_init.py b/src/agentscope/_init.py index e96f523a6..6f75004b5 100644 --- a/src/agentscope/_init.py +++ b/src/agentscope/_init.py @@ -10,10 +10,12 @@ from ._runtime import Runtime from .file_manager import file_manager from .utils.logging_utils import LOG_LEVEL, setup_logger +from .utils.monitor import MonitorFactory from .models import read_model_configs +from .constants import _DEFAULT_DIR +from .constants import _DEFAULT_LOG_LEVEL + -_DEFAULT_DIR = "./runs" -_DEFAULT_LOG_LEVEL = "INFO" _INIT_SETTINGS = {} @@ -85,6 +87,9 @@ def init( dir_log = str(file_manager.dir_log) if save_log else None setup_logger(dir_log, logger_level) + # Set monitor + _ = MonitorFactory.get_monitor(db_path=file_manager.path_db) + # Load config and init agent by configs if agent_configs is not None: if isinstance(agent_configs, str): diff --git a/src/agentscope/configs/model_config.py b/src/agentscope/configs/model_config.py index 774be726a..452a4d2c7 100644 --- a/src/agentscope/configs/model_config.py +++ b/src/agentscope/configs/model_config.py @@ -3,6 +3,7 @@ from typing import Any from ..constants import _DEFAULT_MAX_RETRIES from ..constants import _DEFAULT_MESSAGES_KEY +from ..constants import _DEFAULT_API_BUDGET class CfgBase(dict): @@ -57,6 +58,9 @@ class OpenAICfg(CfgBase): """The arguments used in openai api generation, e.g. `temperature`, `seed`.""" + budget: float = _DEFAULT_API_BUDGET + """The total budget using this model. Set to `None` means no limit.""" + class PostApiCfg(CfgBase): """The config for Post API. The final request post will be @@ -113,3 +117,6 @@ class PostApiCfg(CfgBase): """The key of the prompt messages in `requests.post()`, e.g. `request.post(json={${messages_key}: messages, **json_args})`. For huggingface and modelscope inference API, the key is `inputs`""" + + budget: float = _DEFAULT_API_BUDGET + """The total budget using this model. Set to `None` means no limit.""" diff --git a/src/agentscope/constants.py b/src/agentscope/constants.py index 997d2998f..5cc9fea61 100644 --- a/src/agentscope/constants.py +++ b/src/agentscope/constants.py @@ -16,13 +16,17 @@ _DEFAULT_SUBDIR_FILE = "file" _DEFAULT_SUBDIR_INVOKE = "invoke" _DEFAULT_IMAGE_NAME = "image_{}_{}.png" +_DEFAULT_SQLITE_DB_PATH = "agentscope.db" # for model wrapper _DEFAULT_MAX_RETRIES = 3 _DEFAULT_MESSAGES_KEY = "inputs" _DEFAULT_RETRY_INTERVAL = 1 +_DEFAULT_API_BUDGET = None # for execute python _DEFAULT_PYPI_MIRROR = "http://mirrors.aliyun.com/pypi/simple/" _DEFAULT_TRUSTED_HOST = "mirrors.aliyun.com" +# for monitor +_DEFAULT_MONITOR_TABLE_NAME = "monitor_metrics" # for summarization _DEFAULT_SUMMARIZATION_PROMPT = """ TEXT: {} diff --git a/src/agentscope/file_manager.py b/src/agentscope/file_manager.py index db73cda40..ad397b8ab 100644 --- a/src/agentscope/file_manager.py +++ b/src/agentscope/file_manager.py @@ -14,6 +14,7 @@ _DEFAULT_SUBDIR_CODE, _DEFAULT_SUBDIR_FILE, _DEFAULT_SUBDIR_INVOKE, + _DEFAULT_SQLITE_DB_PATH, _DEFAULT_IMAGE_NAME, ) @@ -48,6 +49,10 @@ def _get_and_create_subdir(self, subdir: str) -> str: os.makedirs(path) return path + def _get_file_path(self, file_name: str) -> str: + """Get the path of the file.""" + return os.path.join(self.dir, Runtime.runtime_id, file_name) + @property def dir_log(self) -> str: """The directory for saving logs.""" @@ -69,11 +74,17 @@ def dir_invoke(self) -> str: """The directory for saving api invocations.""" return self._get_and_create_subdir(_DEFAULT_SUBDIR_INVOKE) + @property + def path_db(self) -> str: + """The path to the sqlite db file.""" + return self._get_file_path(_DEFAULT_SQLITE_DB_PATH) + def init(self, save_dir: str, save_api_invoke: bool = False) -> None: """Set the directory for saving files.""" self.dir = save_dir - if not os.path.exists(save_dir): - os.makedirs(save_dir) + runtime_dir = os.path.join(save_dir, Runtime.runtime_id) + if not os.path.exists(runtime_dir): + os.makedirs(runtime_dir) self.save_api_invoke = save_api_invoke diff --git a/src/agentscope/models/openai_model.py b/src/agentscope/models/openai_model.py index abe1f6bd8..fc8eff640 100644 --- a/src/agentscope/models/openai_model.py +++ b/src/agentscope/models/openai_model.py @@ -12,7 +12,8 @@ except ImportError: openai = None -from ..utils import MonitorFactory +from ..utils.monitor import MonitorFactory +from ..utils.monitor import get_full_name from ..utils import QuotaExceededError from ..utils.token_utils import get_openai_max_length @@ -28,6 +29,7 @@ def __init__( organization: str = None, client_args: dict = None, generate_args: dict = None, + budget: float = None, ) -> None: """Initialize the openai client. @@ -49,6 +51,9 @@ def __init__( generate_args (`dict`, default `None`): The extra keyword arguments used in openai api generation, e.g. `temperature`, `seed`. + budget (`float`, default `None`): + The total budget using this model. Set to `None` means no + limit. """ super().__init__(name) @@ -77,8 +82,18 @@ def __init__( # Set monitor accordingly self.monitor = None + self.budget = budget + self._register_budget() self._register_default_metrics() + def _register_budget(self) -> None: + self.monitor = MonitorFactory.get_monitor() + self.monitor.register_budget( + model_name=self.model_name, + value=self.budget, + prefix=self.model_name, + ) + def _register_default_metrics(self) -> None: """Register metrics to the monitor.""" raise NotImplementedError( @@ -95,7 +110,7 @@ def _metric(self, metric_name: str) -> str: Returns: `str`: Metric name of this wrapper. """ - return f"{self.__class__.__name__}.{self.model_name}.{metric_name}" + return get_full_name(name=metric_name, prefix=self.model_name) class OpenAIChatWrapper(OpenAIWrapper): @@ -193,7 +208,10 @@ def __call__( # step5: update monitor accordingly try: - self.monitor.update(**response.usage.model_dump()) + self.monitor.update( + response.usage.model_dump(), + prefix=self.model_name, + ) except QuotaExceededError as e: # TODO: optimize quota exceeded error handling process logger.error(e.message) diff --git a/src/agentscope/utils/monitor.py b/src/agentscope/utils/monitor.py index 4acdf82c0..48900648f 100644 --- a/src/agentscope/utils/monitor.py +++ b/src/agentscope/utils/monitor.py @@ -2,12 +2,18 @@ """ Monitor for agentscope """ import re -import copy +import sqlite3 from abc import ABC from abc import abstractmethod -from typing import Optional, Any +from contextlib import contextmanager +from typing import Optional, Generator from loguru import logger +from agentscope.constants import ( + _DEFAULT_MONITOR_TABLE_NAME, + _DEFAULT_SQLITE_DB_PATH, +) + class MonitorBase(ABC): r"""Base interface of Monitor""" @@ -60,10 +66,10 @@ def add(self, metric_name: str, value: float) -> bool: `bool`: whether the operation success. """ - def update(self, **kwargs: Any) -> None: + def update(self, values: dict, prefix: Optional[str] = None) -> None: """Update multiple metrics at once.""" - for k, v in kwargs.items(): - self.add(k, v) + for k, v in values: + self.add(get_full_name(prefix=prefix, name=k), v) @abstractmethod def clear(self, metric_name: str) -> bool: @@ -182,64 +188,160 @@ def get_metrics(self, filter_regex: Optional[str] = None) -> dict: } """ + @abstractmethod + def register_budget( + self, + model_name: str, + value: float, + prefix: Optional[str] = "local", + ) -> bool: + """Register model call budget to the monitor, the monitor will raise + QuotaExceededError, when budget is exceeded. -class QuotaExceededError(Exception): - """An Exception used to indicate that a certain metric exceeds quota""" + Args: + model_name (`str`): model that requires budget. + value (`float`): the budget value. + prefix (`Optional[str]`, default `None`): used to distinguish + multiple budget registrations. For multiple registrations with + the same `prefix`, only the first time will take effect. + + Returns: + `bool`: whether the operation success. + """ - def __init__(self, metric_name: str, quota: float) -> None: - self.message = f"Metric [{metric_name}] exceed quota [{quota}]" - super().__init__(self.message) +def get_full_name(name: str, prefix: Optional[str] = None) -> str: + """Get the full name of a metric. -def return_false_if_not_exists( # type: ignore [no-untyped-def] - func, -): - """A decorator used to check whether the attribute exists. - It will return False directly without executing the function, - if the metric does not exist. + Args: + metric_name (`str`): name of a metric. + prefix (` Optional[str]`, default `None`): metric prefix. + + Returns: + `str`: the full name of the metric """ + if prefix is None: + return name + else: + return f"{prefix}.{name}" + + +class QuotaExceededError(Exception): + """An Exception used to indicate that a certain metric exceeds quota""" + + def __init__( + self, + name: str, + ) -> None: + """Init a QuotaExceedError instance. + + Args: + name (`str`): name of the metric which exceeds quota. + """ + self.message = f"Metric [{name}] exceeds quota." + self.name = name + super().__init__(self.message) - def inner( - monitor: MonitorBase, - metric_name: str, - *args: tuple, - **kwargs: dict, - ) -> bool: - if not monitor.exists(metric_name): - logger.warning(f"Metric [{metric_name}] not exists.") - return False - return func(monitor, metric_name, *args, **kwargs) - return inner +@contextmanager +def sqlite_transaction(db_path: str) -> Generator: + """Get a sqlite transaction cursor. + Args: + db_path (`str`): path to the sqlite db file -def return_none_if_not_exists( # type: ignore [no-untyped-def] - func, -): - """A decorator used to check whether the attribute exists. - It will return None directly without executing the function, - if the metric does not exist. + Yields: + `Generator`: a cursor with transaction """ + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + try: + conn.execute("BEGIN") + yield cursor + conn.commit() + except Exception as e: + conn.rollback() + raise e + finally: + cursor.close() + conn.close() + + +@contextmanager +def sqlite_cursor(db_path: str) -> Generator: + """Get a sqlite cursor. + + Args: + db_path (`str`): path to the sqlite db file + + Yields: + `Generator`: a cursor + """ + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + try: + yield cursor + finally: + cursor.close() + conn.close() - def inner( # type: ignore [no-untyped-def] - monitor: MonitorBase, - metric_name: str, - *args: tuple, - **kwargs: dict, - ): - if not monitor.exists(metric_name): - logger.warning(f"Metric [{metric_name}] not exists.") - return None - return func(monitor, metric_name, *args, **kwargs) - return inner +class SqliteMonitor(MonitorBase): + """A monitor based on sqlite""" + def __init__( + self, + db_path: str, + table_name: str = _DEFAULT_MONITOR_TABLE_NAME, + drop_exists: bool = False, + ) -> None: + """Initialize a SqliteMonitor. -class DictMonitor(MonitorBase): - """MonitorBase implementation based on dictionary.""" + Args: + db_path (`str`): path to the sqlite db file. + table_name (`str`, optional): the table name used by the monitor. + Defaults to _DEFAULT_MONITOR_TABLE_NAME. + drop_exists (bool, optional): whether to delete the original table + when the table already exists. Defaults to False. + """ + super().__init__() + self.db_path = db_path + self.table_name = table_name + self._create_monitor_table(drop_exists) + logger.info( + f"SqliteMonitor initialization completed at [{self.db_path}]", + ) - def __init__(self) -> None: - self.metrics = {} + def _create_monitor_table(self, drop_exists: bool = False) -> None: + """Internal method to create a table in sqlite3.""" + with sqlite_transaction(self.db_path) as cursor: + if drop_exists: + cursor.execute(f"DROP TABLE IF EXISTS {self.table_name};") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + value REAL NOT NULL, + quota REAL, + unit TEXT + );""", + ) + cursor.execute( + f""" + CREATE TRIGGER IF NOT EXISTS {self.table_name}_quota_exceeded + BEFORE UPDATE ON {self.table_name} + FOR EACH ROW + WHEN OLD.quota is not NULL AND NEW.value > OLD.quota + BEGIN + SELECT RAISE(FAIL, 'QuotaExceeded'); + END; + """, + ) + logger.info(f"Init [{self.table_name}] as the monitor table") + logger.info( + f"Init [{self.table_name}_quota_exceeded] as the monitor trigger", + ) def register( self, @@ -247,85 +349,270 @@ def register( metric_unit: Optional[str] = None, quota: Optional[float] = None, ) -> bool: - if metric_name in self.metrics: - logger.warning(f"Metric [{metric_name}] is already registered.") - return False - self.metrics[metric_name] = { - "value": 0.0, - "unit": metric_unit, - "quota": quota, - } - logger.info( - f"Register metric [{metric_name}] to Monitor with unit " - f"[{metric_unit}] and quota [{quota}]", - ) - return True + with sqlite_transaction(self.db_path) as cursor: + if self._exists(cursor, metric_name): + return False + cursor.execute( + f""" + INSERT INTO {self.table_name} (name, value, quota, unit) + VALUES (?, ?, ?, ?) + """, + (metric_name, 0.0, quota, metric_unit), + ) + logger.info( + f"Register metric [{metric_name}] to SqliteMonitor with unit " + f"[{metric_unit}] and quota [{quota}]", + ) + return True - @return_false_if_not_exists - def add(self, metric_name: str, value: float) -> bool: - self.metrics[metric_name]["value"] += value - if ( - self.metrics[metric_name]["quota"] is not None - and self.metrics[metric_name]["value"] - > self.metrics[metric_name]["quota"] - ): - logger.warning(f"Metric [{metric_name}] quota exceeded.") - raise QuotaExceededError( - metric_name=metric_name, - quota=self.metrics[metric_name]["quota"], + def _add( + self, + cursor: sqlite3.Cursor, + metric_name: str, + value: float, + ) -> None: + try: + cursor.execute( + f""" + UPDATE {self.table_name} + SET value = value + ? + WHERE name = ? + """, + (value, metric_name), ) - return True + except sqlite3.IntegrityError as e: + raise QuotaExceededError(metric_name) from e - def exists(self, metric_name: str) -> bool: - return metric_name in self.metrics + def add(self, metric_name: str, value: float) -> bool: + with sqlite_transaction(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return False + self._add(cursor, metric_name, value) + return True - @return_false_if_not_exists def clear(self, metric_name: str) -> bool: - self.metrics[metric_name]["value"] = 0.0 - return True + with sqlite_transaction(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return False + cursor.execute( + f""" + UPDATE {self.table_name} + SET value = value + ? + WHERE name = ? + """, + (0.0, metric_name), + ) + return True - @return_false_if_not_exists def remove(self, metric_name: str) -> bool: - self.metrics.pop(metric_name) - logger.info(f"Remove metric [{metric_name}] from monitor.") + with sqlite_transaction(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return False + cursor.execute( + f""" + DELETE FROM {self.table_name} + WHERE name = ?""", + (metric_name,), + ) return True - @return_none_if_not_exists - def get_value(self, metric_name: str) -> Optional[float]: - if metric_name not in self.metrics: - return None - return self.metrics[metric_name]["value"] + def _get_metric(self, cursor: sqlite3.Cursor, metric_name: str) -> dict: + cursor.execute( + f""" + SELECT value, quota, unit FROM {self.table_name} + WHERE name = ?""", + (metric_name,), + ) + row = cursor.fetchone() + if row: + value, quota, unit = row + return { + "value": value, + "quota": quota, + "unit": unit, + } + else: + raise RuntimeError(f"Fail to get metric {metric_name}") - @return_none_if_not_exists - def get_unit(self, metric_name: str) -> Optional[str]: - if metric_name not in self.metrics: - return None - return self.metrics[metric_name]["unit"] + def get_value(self, metric_name: str) -> Optional[float]: + with sqlite_cursor(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return None + metric = self._get_metric(cursor, metric_name) + return metric["value"] - @return_none_if_not_exists def get_quota(self, metric_name: str) -> Optional[float]: - return self.metrics[metric_name]["quota"] + with sqlite_cursor(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return None + metric = self._get_metric(cursor, metric_name) + return metric["quota"] - @return_false_if_not_exists def set_quota(self, metric_name: str, quota: float) -> bool: - self.metrics[metric_name]["quota"] = quota - return True + with sqlite_transaction(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return False + cursor.execute( + f""" + UPDATE {self.table_name} + SET quota = ? + WHERE name = ? + """, + (quota, metric_name), + ) + return True + + def get_unit(self, metric_name: str) -> Optional[str]: + with sqlite_cursor(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return None + metric = self._get_metric(cursor, metric_name) + return metric["unit"] - @return_none_if_not_exists def get_metric(self, metric_name: str) -> Optional[dict]: - return copy.deepcopy(self.metrics[metric_name]) + with sqlite_cursor(self.db_path) as cursor: + if not self._exists(cursor, metric_name): + return None + return self._get_metric(cursor, metric_name) def get_metrics(self, filter_regex: Optional[str] = None) -> dict: + with sqlite_cursor(self.db_path) as cursor: + cursor.execute(f"SELECT * FROM {self.table_name}") + rows = cursor.fetchall() + metrics = { + row[1]: { + "value": row[2], + "quota": row[3], + "unit": row[4], + } + for row in rows + } if filter_regex is None: - return copy.deepcopy(self.metrics) + return metrics else: pattern = re.compile(filter_regex) return { - key: copy.deepcopy(value) - for key, value in self.metrics.items() + key: value + for key, value in metrics.items() if pattern.search(key) } + def _exists(self, cursor: sqlite3.Cursor, name: str) -> bool: + cursor.execute( + f""" + SELECT 1 FROM {self.table_name} + WHERE name = ? LIMIT 1 + """, + (name,), + ) + return cursor.fetchone() is not None + + def exists(self, metric_name: str) -> bool: + with sqlite_cursor(self.db_path) as cursor: + return self._exists(cursor, metric_name) + + def update(self, values: dict, prefix: Optional[str] = None) -> None: + with sqlite_transaction(self.db_path) as cursor: + for metric_name, value in values.items(): + self._add( + cursor, + get_full_name( + name=metric_name, + prefix=prefix, + ), + value, + ) + + def _create_update_cost_trigger( + self, + token_metric: str, + cost_metric: str, + unit_price: float, + ) -> None: + with sqlite_transaction(self.db_path) as cursor: + cursor.execute( + f""" + CREATE TRIGGER IF NOT EXISTS + "{self.table_name}_{token_metric}_{cost_metric}_price" + AFTER UPDATE OF value ON "{self.table_name}" + FOR EACH ROW + WHEN NEW.name = "{token_metric}" + BEGIN + UPDATE {self.table_name} + SET value = value + (NEW.value - OLD.value) * {unit_price} + WHERE name = "{cost_metric}"; + END; + """, + ) + + def register_budget( + self, + model_name: str, + value: float, + prefix: Optional[str] = None, + ) -> bool: + logger.info(f"set budget {value} to {model_name}") + pricing = _get_pricing() + if model_name in pricing: + budget_metric_name = get_full_name( + name="cost", + prefix=prefix, + ) + ok = self.register( + metric_name=budget_metric_name, + metric_unit="dollor", + quota=value, + ) + if not ok: + return False + for metric_name, unit_price in pricing[model_name].items(): + token_metric_name = get_full_name( + name=metric_name, + prefix=prefix, + ) + self.register( + metric_name=token_metric_name, + metric_unit="token", + ) + self._create_update_cost_trigger( + token_metric_name, + budget_metric_name, + unit_price, + ) + return True + else: + logger.warning( + f"Calculate budgets for model [{model_name}] is not supported", + ) + return False + + +def _get_pricing() -> dict: + """Get pricing as a dict + + Returns: + `dict`: the dict with pricing information. + """ + # TODO: get pricing from files + return { + "gpt-4-turbo": { + "prompt_tokens": 0.00001, + "completion_tokens": 0.00003, + }, + "gpt-4": { + "prompt_tokens": 0.00003, + "completion_tokens": 0.00006, + }, + "gpt-4-32k": { + "prompt_tokens": 0.00006, + "completion_tokens": 0.00012, + }, + "gpt-3.5-turbo": { + "prompt_tokens": 0.000001, + "completion_tokens": 0.000002, + }, + } + class MonitorFactory: """Factory of Monitor. @@ -334,24 +621,31 @@ class MonitorFactory: from agentscope.utils import MonitorFactory monitor = MonitorFactory.get_monitor() - """ _instance = None @classmethod - def get_monitor(cls, impl_type: Optional[str] = None) -> MonitorBase: + def get_monitor( + cls, + impl_type: Optional[str] = None, + db_path: str = _DEFAULT_SQLITE_DB_PATH, + ) -> MonitorBase: """Get the monitor instance. + Args: + impl_type (`Optional[str]`, optional): the type of monitor, + currently supports `sqlite` only. + db_path (`Optional[str]`, optional): path to the sqlite db file. + Returns: `MonitorBase`: the monitor instance. """ if cls._instance is None: - # todo: init a specific monitor implementation by input args - if impl_type is None or impl_type.lower() == "dict": - cls._instance = DictMonitor() + if impl_type is None or impl_type.lower() == "sqlite": + cls._instance = SqliteMonitor(db_path=db_path) else: raise NotImplementedError( "Monitor with type [{type}] is not implemented.", ) - return cls._instance + return cls._instance # type: ignore [return-value] diff --git a/tests/monitor_test.py b/tests/monitor_test.py index ffbae1e81..26ac418c8 100644 --- a/tests/monitor_test.py +++ b/tests/monitor_test.py @@ -4,15 +4,20 @@ """ import unittest - +import uuid +import os from agentscope.utils import MonitorBase, QuotaExceededError, MonitorFactory -from agentscope.utils.monitor import DictMonitor +from agentscope.utils.monitor import SqliteMonitor class MonitorFactoryTest(unittest.TestCase): "Test class for MonitorFactory" + def setUp(self) -> None: + self.db_path = f"test-{uuid.uuid4()}.db" + _ = MonitorFactory.get_monitor(db_path=self.db_path) + def test_get_monitor(self) -> None: """Test get monitor method of MonitorFactory.""" monitor1 = MonitorFactory.get_monitor() @@ -25,6 +30,10 @@ def test_get_monitor(self) -> None: self.assertTrue(monitor2.remove("token_num")) self.assertFalse(monitor1.exists("token_num")) + def tearDown(self) -> None: + MonitorFactory._instance = None # pylint: disable=W0212 + os.remove(self.db_path) + class MonitorTestBase(unittest.TestCase): """An abstract test class for MonitorBase interface""" @@ -91,7 +100,7 @@ def test_add_clear_set_quota(self) -> None: self.assertTrue(self.monitor.set_quota("token_num", 200)) # add success and check new value self.assertTrue(self.monitor.add("token_num", 10)) - self.assertEqual(self.monitor.get_value("token_num"), 111) + self.assertEqual(self.monitor.get_value("token_num"), 20) # clear an existing metric self.assertTrue(self.monitor.clear("token_num")) # clear an not existing metric @@ -161,8 +170,64 @@ def test_get(self) -> None: ) -class DictMonitorTest(MonitorTestBase): - """Test class for DictMonitor""" +class SqliteMonitorTest(MonitorTestBase): + """Test class for SqliteMonitor""" def get_monitor_instance(self) -> MonitorBase: - return DictMonitor() + self.db_path = f"test-{uuid.uuid4()}.db" + return SqliteMonitor(self.db_path) + + def tearDown(self) -> None: + os.remove(self.db_path) + + def test_register_budget(self) -> None: + """Test register_budget method of monitor""" + self.assertTrue( + self.monitor.register_budget( + model_name="gpt-4", + value=5, + prefix="agent_A.gpt-4", + ), + ) + # register an existing model with different prefix is ok + self.assertTrue( + self.monitor.register_budget( + model_name="gpt-4", + value=15, + prefix="agent_B.gpt-4", + ), + ) + gpt_4_3d = { + "prompt_tokens": 50000, + "completion_tokens": 25000, + "total_tokens": 750000, + } + # agentA uses 3 dollors + self.monitor.update(gpt_4_3d, prefix="agent_A.gpt-4") + # agentA uses another 3 dollors and exceeds quota + self.assertRaises( + QuotaExceededError, + self.monitor.update, + gpt_4_3d, + "agent_A.gpt-4", + ) + self.assertLess( + self.monitor.get_value( # type: ignore [arg-type] + "agent_A.gpt-4.cost", + ), + 5, + ) + # register an existing model with existing prefix is wrong + self.assertFalse( + self.monitor.register_budget( + model_name="gpt-4", + value=5, + prefix="agent_A.gpt-4", + ), + ) + self.assertEqual( + self.monitor.get_value( # type: ignore [arg-type] + "agent_A.gpt-4.cost", + ), + 3, + )