diff --git a/README.md b/README.md index 035e13b8b..61a2f5233 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ For post requests APIs, the config contains the following fields. ``` { - "model_id": "{your_model_id}", # To identify the model instance + "model_id": "{model id}", # To identify the model instance "model_type": "post_api", "api_url": "https://xxx", # The target url "headers": { # Required headers diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py index 72b8e21d9..36687e828 100644 --- a/src/agentscope/models/__init__.py +++ b/src/agentscope/models/__init__.py @@ -23,6 +23,7 @@ "ModelWrapperBase", "ModelResponse", "PostAPIModelWrapperBase", + "PostAPIChatWrapper", "OpenAIWrapper", "OpenAIChatWrapper", "OpenAIDALLEWrapper", @@ -44,8 +45,8 @@ def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]: Returns: `Type[ModelWrapperBase]`: The corresponding model wrapper class. """ - if model_type in ModelWrapperBase.alias_registry: - return ModelWrapperBase.alias_registry[ # type: ignore [return-value] + if model_type in ModelWrapperBase.type_registry: + return ModelWrapperBase.type_registry[ # type: ignore [return-value] model_type ] elif model_type in ModelWrapperBase.registry: diff --git a/src/agentscope/models/config.py b/src/agentscope/models/config.py index 4b76a4459..93ef06cc9 100644 --- a/src/agentscope/models/config.py +++ b/src/agentscope/models/config.py @@ -22,7 +22,7 @@ def __init__( Args: model_id (`str`): The id of the generated model wrapper. - model_type (`str`, optional): The class name (or its alias) of + model_type (`str`, optional): The class name (or its model type) of the generated model wrapper. Defaults to None. Raises: diff --git a/src/agentscope/models/model.py b/src/agentscope/models/model.py index 43d6b5ebd..ef6b182a3 100644 --- a/src/agentscope/models/model.py +++ b/src/agentscope/models/model.py @@ -190,11 +190,11 @@ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any: def __init__(cls, name: Any, bases: Any, attrs: Any) -> None: if not hasattr(cls, "registry"): cls.registry = {} - cls.alias_registry = {} + cls.type_registry = {} else: cls.registry[name] = cls - if hasattr(cls, "alias"): - cls.alias_registry[cls.alias] = cls + if hasattr(cls, "model_type"): + cls.type_registry[cls.model_type] = cls super().__init__(name, bases, attrs) diff --git a/src/agentscope/models/openai_model.py b/src/agentscope/models/openai_model.py index d72db2145..48dc974d4 100644 --- a/src/agentscope/models/openai_model.py +++ b/src/agentscope/models/openai_model.py @@ -126,7 +126,7 @@ def _metric(self, metric_name: str) -> str: class OpenAIChatWrapper(OpenAIWrapper): """The model wrapper for OpenAI's chat API.""" - alias: str = "openai" + model_type: str = "openai" def _register_default_metrics(self) -> None: # Set monitor accordingly @@ -236,7 +236,7 @@ def __call__( class OpenAIDALLEWrapper(OpenAIWrapper): """The model wrapper for OpenAI's DALLĀ·E API.""" - alias: str = "openai_dall_e" + model_type: str = "openai_dall_e" _resolutions: list = [ "1792*1024", @@ -334,7 +334,7 @@ def __call__( class OpenAIEmbeddingWrapper(OpenAIWrapper): """The model wrapper for OpenAI embedding API.""" - alias: str = "openai_embedding" + model_type: str = "openai_embedding" def _register_default_metrics(self) -> None: # Set monitor accordingly diff --git a/src/agentscope/models/post_model.py b/src/agentscope/models/post_model.py index 4865bc2dd..2a62d4174 100644 --- a/src/agentscope/models/post_model.py +++ b/src/agentscope/models/post_model.py @@ -16,7 +16,7 @@ class PostAPIModelWrapperBase(ModelWrapperBase): """The base model wrapper for the model deployed on the POST API.""" - alias: str = "post_api" + model_type: str = "post_api" def __init__( self, @@ -174,7 +174,7 @@ class PostAPIChatWrapper(PostAPIModelWrapperBase): """A post api model wrapper compatilble with openai chat, e.g., vLLM, FastChat.""" - alias: str = "post_api_chat" + model_type: str = "post_api_chat" def _parse_response(self, response: dict) -> ModelResponse: return ModelResponse( @@ -187,7 +187,7 @@ def _parse_response(self, response: dict) -> ModelResponse: class PostAPIDALLEWrapper(PostAPIModelWrapperBase): """A post api model wrapper compatible with openai dalle""" - alias: str = "post_api_dalle" + model_type: str = "post_api_dalle" def _parse_response(self, response: dict) -> ModelResponse: urls = [img["url"] for img in response["data"]["response"]["data"]] diff --git a/tests/model_test.py b/tests/model_test.py index 753806db1..039736069 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -35,7 +35,7 @@ def test_model_registry(self) -> None: _get_model_wrapper(model_type="TestModelWrapperSimple"), TestModelWrapperSimple, ) - # get model wrapper class by alias + # get model wrapper class by model type self.assertEqual( _get_model_wrapper(model_type="openai"), OpenAIChatWrapper,