Skip to content

Commit

Permalink
rename alias to model_type
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Feb 5, 2024
1 parent d4f7f11 commit 5334330
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ For post requests APIs, the config contains the following fields.

```
{
"model_id": "{your_model_id}", # To identify the model instance
"model_id": "{model id}", # To identify the model instance
"model_type": "post_api",
"api_url": "https://xxx", # The target url
"headers": { # Required headers
Expand Down
5 changes: 3 additions & 2 deletions src/agentscope/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"ModelWrapperBase",
"ModelResponse",
"PostAPIModelWrapperBase",
"PostAPIChatWrapper",
"OpenAIWrapper",
"OpenAIChatWrapper",
"OpenAIDALLEWrapper",
Expand All @@ -44,8 +45,8 @@ def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]:
Returns:
`Type[ModelWrapperBase]`: The corresponding model wrapper class.
"""
if model_type in ModelWrapperBase.alias_registry:
return ModelWrapperBase.alias_registry[ # type: ignore [return-value]
if model_type in ModelWrapperBase.type_registry:
return ModelWrapperBase.type_registry[ # type: ignore [return-value]
model_type
]
elif model_type in ModelWrapperBase.registry:
Expand Down
2 changes: 1 addition & 1 deletion src/agentscope/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
Args:
model_id (`str`): The id of the generated model wrapper.
model_type (`str`, optional): The class name (or its alias) of
model_type (`str`, optional): The class name (or its model type) of
the generated model wrapper. Defaults to None.
Raises:
Expand Down
6 changes: 3 additions & 3 deletions src/agentscope/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
if not hasattr(cls, "registry"):
cls.registry = {}
cls.alias_registry = {}
cls.type_registry = {}
else:
cls.registry[name] = cls
if hasattr(cls, "alias"):
cls.alias_registry[cls.alias] = cls
if hasattr(cls, "model_type"):
cls.type_registry[cls.model_type] = cls
super().__init__(name, bases, attrs)


Expand Down
6 changes: 3 additions & 3 deletions src/agentscope/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _metric(self, metric_name: str) -> str:
class OpenAIChatWrapper(OpenAIWrapper):
"""The model wrapper for OpenAI's chat API."""

alias: str = "openai"
model_type: str = "openai"

def _register_default_metrics(self) -> None:
# Set monitor accordingly
Expand Down Expand Up @@ -236,7 +236,7 @@ def __call__(
class OpenAIDALLEWrapper(OpenAIWrapper):
"""The model wrapper for OpenAI's DALL·E API."""

alias: str = "openai_dall_e"
model_type: str = "openai_dall_e"

_resolutions: list = [
"1792*1024",
Expand Down Expand Up @@ -334,7 +334,7 @@ def __call__(
class OpenAIEmbeddingWrapper(OpenAIWrapper):
"""The model wrapper for OpenAI embedding API."""

alias: str = "openai_embedding"
model_type: str = "openai_embedding"

def _register_default_metrics(self) -> None:
# Set monitor accordingly
Expand Down
6 changes: 3 additions & 3 deletions src/agentscope/models/post_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class PostAPIModelWrapperBase(ModelWrapperBase):
"""The base model wrapper for the model deployed on the POST API."""

alias: str = "post_api"
model_type: str = "post_api"

def __init__(
self,
Expand Down Expand Up @@ -174,7 +174,7 @@ class PostAPIChatWrapper(PostAPIModelWrapperBase):
"""A post api model wrapper compatilble with openai chat, e.g., vLLM,
FastChat."""

alias: str = "post_api_chat"
model_type: str = "post_api_chat"

def _parse_response(self, response: dict) -> ModelResponse:
return ModelResponse(
Expand All @@ -187,7 +187,7 @@ def _parse_response(self, response: dict) -> ModelResponse:
class PostAPIDALLEWrapper(PostAPIModelWrapperBase):
"""A post api model wrapper compatible with openai dalle"""

alias: str = "post_api_dalle"
model_type: str = "post_api_dalle"

def _parse_response(self, response: dict) -> ModelResponse:
urls = [img["url"] for img in response["data"]["response"]["data"]]
Expand Down
2 changes: 1 addition & 1 deletion tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_model_registry(self) -> None:
_get_model_wrapper(model_type="TestModelWrapperSimple"),
TestModelWrapperSimple,
)
# get model wrapper class by alias
# get model wrapper class by model type
self.assertEqual(
_get_model_wrapper(model_type="openai"),
OpenAIChatWrapper,
Expand Down

0 comments on commit 5334330

Please sign in to comment.