Skip to content

Commit

Permalink
modelwrapper supports automatic registration and uses modelresponse t…
Browse files Browse the repository at this point in the history
…o decouple agent and modelwrapper
  • Loading branch information
pan-x-c committed Feb 2, 2024
1 parent 70f6d8c commit 3b5aff1
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 46 deletions.
51 changes: 35 additions & 16 deletions src/agentscope/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
""" Import modules in models package."""
import json
from typing import Union, Sequence
from typing import Union, Type

from loguru import logger

from .model import ModelWrapperBase
from .post_model import PostApiModelWrapper
from .model import ModelWrapperBase, ModelResponse
from .post_model import PostAPIModelWrapperBase
from .openai_model import (
OpenAIWrapper,
OpenAIChatWrapper,
Expand All @@ -17,7 +17,8 @@

__all__ = [
"ModelWrapperBase",
"PostApiModelWrapper",
"ModelResponse",
"PostAPIModelWrapperBase",
"OpenAIWrapper",
"OpenAIChatWrapper",
"OpenAIDALLEWrapper",
Expand All @@ -31,6 +32,35 @@


_MODEL_CONFIGS = []
_MODEL_MAP: dict[str, Type[ModelWrapperBase]] = {
"openai": OpenAIChatWrapper,
"openai_dall_e": OpenAIDALLEWrapper,
"openai_embedding": OpenAIEmbeddingWrapper,
"post_api": PostAPIModelWrapperBase,
}


def get_model(model_type: str) -> Type[ModelWrapperBase]:
    """Get the model wrapper class for the given type name.

    Lookup order: the built-in ``_MODEL_MAP`` first, then the automatic
    registry populated by the model wrapper metaclass. Unknown types fall
    back to ``PostAPIModelWrapperBase`` with a warning.

    Args:
        model_type (`str`): The model type name.

    Returns:
        `Type[ModelWrapperBase]`: The corresponding model wrapper class.
    """
    if model_type in _MODEL_MAP:
        return _MODEL_MAP[model_type]
    elif model_type in ModelWrapperBase.registry:
        return ModelWrapperBase.registry[  # type: ignore [return-value]
            model_type
        ]
    else:
        # Fix: the old message named the removed class "PostApiModelWrapper"
        # (renamed to PostAPIModelWrapperBase) and the two concatenated
        # literals had no separating space.
        logger.warning(
            f"Unsupported model_type [{model_type}], "
            "use PostAPIModelWrapperBase instead.",
        )
        return PostAPIModelWrapperBase


def load_model_by_name(model_name: str) -> ModelWrapperBase:
Expand All @@ -54,18 +84,7 @@ def load_model_by_name(model_name: str) -> ModelWrapperBase:
)

model_type = config.pop("type")
if model_type == "openai":
return OpenAIChatWrapper(**config)
elif model_type == "openai_dall_e":
return OpenAIDALLEWrapper(**config)
elif model_type == "openai_embedding":
return OpenAIEmbeddingWrapper(**config)
elif model_type == "post_api":
return PostApiModelWrapper(**config)
else:
raise ValueError(
f"Cannot find [{config['type']}] in loaded configurations.",
)
return get_model(model_type=model_type)(**config)


def clear_model_configs() -> None:
Expand Down
51 changes: 49 additions & 2 deletions src/agentscope/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
import time
from abc import ABCMeta
from functools import wraps
from typing import Union, Any, Callable
from typing import Sequence, Any, Callable

from loguru import logger

Expand Down Expand Up @@ -146,6 +146,53 @@ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
attrs["__call__"] = _response_parse_decorator(attrs["__call__"])
return super().__new__(mcs, name, bases, attrs)

def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
    """Auto-register every subclass in the shared ``registry`` dict.

    The first class created with this metaclass only initializes the
    empty registry; each class created afterwards is recorded in it,
    keyed by its class name.
    """
    super().__init__(name, bases, attrs)
    if hasattr(cls, "registry"):
        cls.registry[name] = cls
    else:
        cls.registry = {}


class ModelResponse:
    """Unified container for data returned by a model.

    This class aligns the return formats of different models and acts as
    a bridge between models and agents: a model wrapper fills in whichever
    fields apply, and callers read them back through read-only properties.
    """

    def __init__(
        self,
        text: str = None,
        embedding: Sequence = None,
        image_urls: Sequence[str] = None,
        raw: dict = None,
    ) -> None:
        """Initialize the response with whichever fields apply.

        Args:
            text (`str`, defaults to `None`): Generated text, if any.
            embedding (`Sequence`, defaults to `None`): Embedding data.
            image_urls (`Sequence[str]`, defaults to `None`): URLs of
                generated images.
            raw (`dict`, defaults to `None`): The raw response payload.
        """
        # All fields live in one private mapping; the properties below
        # expose them read-only.
        self._payload = {
            "text": text,
            "embedding": embedding,
            "image_urls": image_urls,
            "raw": raw,
        }

    @property
    def text(self) -> str:
        """Text field."""
        return self._payload["text"]

    @property
    def embedding(self) -> Sequence:
        """Embedding field."""
        return self._payload["embedding"]

    @property
    def image_urls(self) -> Sequence[str]:
        """Image URLs field."""
        return self._payload["image_urls"]

    @property
    def raw(self) -> dict:
        """Raw dictionary field."""
        return self._payload["raw"]


class ModelWrapperBase(metaclass=_ModelWrapperMeta):
"""The base class for model wrapper."""
Expand All @@ -166,7 +213,7 @@ def __init__(
"""
self.name = name

def __call__(self, *args: Any, **kwargs: Any) -> Union[str, dict, list]:
def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse:
"""Processing input with the model."""
raise NotImplementedError(
f"Model Wrapper [{type(self).__name__}]"
Expand Down
30 changes: 16 additions & 14 deletions src/agentscope/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from loguru import logger

from .model import ModelWrapperBase
from .model import ModelWrapperBase, ModelResponse
from ..file_manager import file_manager

try:
Expand Down Expand Up @@ -138,7 +138,7 @@ def __call__(
messages: list,
return_raw: bool = False,
**kwargs: Any,
) -> Union[str, dict]:
) -> ModelResponse:
"""Processes a list of messages to construct a payload for the OpenAI
API call. It then makes a request to the OpenAI API and returns the
response. This method also updates monitoring metrics based on the
Expand Down Expand Up @@ -218,9 +218,9 @@ def __call__(

# step6: return raw response if needed
if return_raw:
return response.model_dump()
return ModelResponse(raw=response.model_dump())
else:
return response.choices[0].message.content
return ModelResponse(text=response.choices[0].message.content)


class OpenAIDALLEWrapper(OpenAIWrapper):
Expand Down Expand Up @@ -250,7 +250,7 @@ def __call__(
return_raw: bool = False,
save_local: bool = False,
**kwargs: Any,
) -> Union[dict, list[str]]:
) -> ModelResponse:
"""
Args:
prompt (`str`):
Expand Down Expand Up @@ -313,18 +313,16 @@ def __call__(

# step4: return raw response if needed
if return_raw:
return response
return ModelResponse(raw=response.model_dump())
else:
images = response.model_dump()["data"]
# Get image urls as a list
urls = [_["url"] for _ in images]

if save_local:
# Return local url if save_local is True
local_urls = [file_manager.save_image(_) for _ in urls]
return local_urls
else:
return urls
urls = [file_manager.save_image(_) for _ in urls]
return ModelResponse(image_urls=urls)


class OpenAIEmbeddingWrapper(OpenAIWrapper):
Expand All @@ -348,7 +346,7 @@ def __call__(
texts: Union[list[str], str],
return_raw: bool = False,
**kwargs: Any,
) -> Union[list, dict]:
) -> ModelResponse:
"""Embed the messages with OpenAI embedding API.
Args:
Expand Down Expand Up @@ -402,9 +400,13 @@ def __call__(
# step4: return raw response if needed
response_json = response.model_dump()
if return_raw:
return response_json
return ModelResponse(raw=response_json)
else:
if len(response_json["data"]) == 0:
return response_json["data"]["embedding"][0]
return ModelResponse(
embedding=response_json["data"]["embedding"][0],
)
else:
return [_["embedding"] for _ in response_json["data"]]
return ModelResponse(
embedding=[_["embedding"] for _ in response_json["data"]],
)
53 changes: 46 additions & 7 deletions src/agentscope/models/post_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import requests
from loguru import logger

from .model import ModelWrapperBase
from .model import ModelWrapperBase, ModelResponse
from ..constants import _DEFAULT_MAX_RETRIES
from ..constants import _DEFAULT_MESSAGES_KEY
from ..constants import _DEFAULT_RETRY_INTERVAL


class PostApiModelWrapper(ModelWrapperBase):
"""The model wrapper for the model deployed on the POST API."""
class PostAPIModelWrapperBase(ModelWrapperBase):
"""The base model wrapper for the model deployed on the POST API."""

def __init__(
self,
Expand Down Expand Up @@ -82,7 +82,11 @@ def __init__(
self.messages_key = messages_key
self.retry_interval = retry_interval

def __call__(self, input_: str, **kwargs: Any) -> dict:
def _parse_response(self, response: dict) -> ModelResponse:
    """Parse the response json data into ModelResponse

    The caller passes the already-decoded JSON body (it invokes
    ``response.json()`` before calling this hook). The base
    implementation keeps the whole body in the ``raw`` field;
    subclasses override this method to extract API-specific content
    (e.g. chat message text) into the other fields.
    """
    return ModelResponse(raw=response)

def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
"""Calling the model with requests.post.
Args:
Expand Down Expand Up @@ -143,12 +147,47 @@ def __call__(self, input_: str, **kwargs: Any) -> dict:

# step4: parse the response
if response.status_code == requests.codes.ok:
return response.json()["data"]["response"]["choices"][0][
"message"
]["content"]
return self._parse_response(response.json())
else:
logger.error(json.dumps(request_kwargs, indent=4))
raise RuntimeError(
f"Failed to call the model with "
f"requests.codes == {response.status_code}",
)


class PostAPIChatWrapper(PostAPIModelWrapperBase):
    """A post api model wrapper compatible with openai chat, which
    extracts the generated message text from an OpenAI-style response
    body."""

    def __init__(
        self,
        name: str,
        api_url: str,
        headers: dict = None,
        max_length: int = 2048,
        timeout: int = 30,
        json_args: dict = None,
        post_args: dict = None,
        max_retries: int = _DEFAULT_MAX_RETRIES,
        messages_key: str = _DEFAULT_MESSAGES_KEY,
        retry_interval: int = _DEFAULT_RETRY_INTERVAL,
    ) -> None:
        """Forward all arguments unchanged to ``PostAPIModelWrapperBase``;
        see the base class for parameter documentation."""
        super().__init__(
            name=name,
            api_url=api_url,
            headers=headers,
            max_length=max_length,
            timeout=timeout,
            json_args=json_args,
            post_args=post_args,
            max_retries=max_retries,
            messages_key=messages_key,
            retry_interval=retry_interval,
        )

    def _parse_response(self, response: dict) -> ModelResponse:
        """Extract the generated text from an OpenAI-style chat
        completion body.

        Args:
            response (`dict`): The decoded JSON body of the HTTP
                response (the caller already applied ``.json()``).

        Returns:
            `ModelResponse`: Response with the ``text`` field set.
        """
        # Bug fix: `response` is already a dict -- the base class calls
        # `self._parse_response(response.json())` -- so invoking
        # `response.json()` here raised AttributeError. Index the dict
        # directly instead.
        return ModelResponse(
            text=response["data"]["response"]["choices"][0]["message"][
                "content"
            ],
        )
8 changes: 4 additions & 4 deletions tests/prompt_engine_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""Unit test for prompt engine."""
import unittest
from typing import Union, Any
from typing import Any

from agentscope.models import read_model_configs
from agentscope.models import load_model_by_name
from agentscope.models import OpenAIWrapper
from agentscope.models import ModelResponse, OpenAIWrapper
from agentscope.prompt import PromptEngine


Expand Down Expand Up @@ -62,8 +62,8 @@ def __call__(
self,
*args: Any,
**kwargs: Any,
) -> Union[str, dict, list]:
return ""
) -> ModelResponse:
return ModelResponse(text="")

def _register_default_metrics(self) -> None:
pass
Expand Down
6 changes: 3 additions & 3 deletions tests/retrieval_from_list_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from agentscope.service.service_status import ServiceExecStatus
from agentscope.message import MessageBase, Msg, Tht
from agentscope.memory.temporary_memory import TemporaryMemory
from agentscope.models import OpenAIEmbeddingWrapper
from agentscope.models import OpenAIEmbeddingWrapper, ModelResponse


class TestRetrieval(unittest.TestCase):
Expand All @@ -25,9 +25,9 @@ class DummyModel(OpenAIEmbeddingWrapper):
def __init__(self) -> None:
pass

def __call__(self, *args: Any, **kwargs: Any) -> dict:
def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse:
print(*args, **kwargs)
return {}
return ModelResponse(raw={})

dummy_model = DummyModel()

Expand Down

0 comments on commit 3b5aff1

Please sign in to comment.