From a735a4e596a27a3afa889494c87e9d08c27f2e87 Mon Sep 17 00:00:00 2001
From: Anton Dubovik
Date: Thu, 28 Mar 2024 16:06:00 +0100
Subject: [PATCH] feat: supported tokenize and truncate_prompt endpoints (#69)

---
 .flake8                                     |   3 +-
 README.md                                   |  25 +++
 aidial_adapter_vertexai/chat/bison/base.py  |   3 -
 .../chat/chat_completion_adapter.py         |  22 ++--
 .../chat/gemini/adapter.py                  |   9 --
 aidial_adapter_vertexai/chat_completion.py  | 110 +++++++++++++++++-
 .../utils/not_implemented.py                |   7 ++
 client/chat/adapter.py                      |   4 +-
 tests/unit_tests/conftest.py                |   2 +-
 tests/unit_tests/test_endpoints.py          |  51 ++++++++
 10 files changed, 201 insertions(+), 35 deletions(-)
 create mode 100644 aidial_adapter_vertexai/utils/not_implemented.py
 create mode 100644 tests/unit_tests/test_endpoints.py

diff --git a/.flake8 b/.flake8
index 35a678a..908dd96 100644
--- a/.flake8
+++ b/.flake8
@@ -2,7 +2,8 @@
 # E501 line is too long
 # W503 line break before binary operator
 # E203 whitespace before ':'
-ignore = E501, W503, E203
+# E704 multiple statements on one line (def)
+ignore = E501, W503, E203, E704
 exclude =
     .venv,
     .nox,
diff --git a/README.md b/README.md
index 44ba606..18f4c1e 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,30 @@
 
 The project implements [AI DIAL API](https://epam-rail.com/dial_api) for language models and embeddings from [Vertex AI](https://console.cloud.google.com/vertex-ai).
 
-Find the list of supported models in [the source code](./aidial_adapter_vertexai/deployments.py).
+## Supported models
+
+The following models support the `POST SERVER_URL/openai/deployments/MODEL_NAME/chat/completions` endpoint, with optional support for the `/tokenize` and `/truncate_prompt` endpoints:
+
+|Model|Modality|`/tokenize`|`/truncate_prompt`|
+|---|---|---|---|
+|chat-bison@001|text-to-text|✅|✅|
+|chat-bison@002|text-to-text|✅|✅|
+|chat-bison-32k@002|text-to-text|✅|✅|
+|codechat-bison@001|text-to-text|✅|✅|
+|codechat-bison@002|text-to-text|✅|✅|
+|codechat-bison-32k@002|text-to-text|✅|✅|
+|gemini-pro|text-to-text|✅|❌|
+|gemini-pro-vision|text-to-text, image-to-text|✅|❌|
+|imagegeneration@005|text-to-image|✅|✅|
+
+The models that support `/truncate_prompt` also support the `max_prompt_tokens` request parameter.
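+
+For illustration, a chat completion request that relies on prompt truncation might look like the sketch below (replace `SERVER_URL` and the API key with real values; the token budget here is arbitrary):
+
+```python
+import requests
+
+response = requests.post(
+    "SERVER_URL/openai/deployments/chat-bison@001/chat/completions",
+    headers={"Api-Key": "DIAL_API_KEY", "Content-Type": "application/json"},
+    json={
+        "messages": [{"role": "user", "content": "Hello!"}],
+        # Messages that do not fit into the budget are discarded server-side.
+        "max_prompt_tokens": 16,
+    },
+)
+```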
+
+The following models support the `POST SERVER_URL/openai/deployments/MODEL_NAME/embeddings` endpoint:
+
+|Model|Modality|
+|---|---|
+|textembedding-gecko@001|text-to-embedding|
 
 ## Developer environment
 
diff --git a/aidial_adapter_vertexai/chat/bison/base.py b/aidial_adapter_vertexai/chat/bison/base.py
index 4cad67c..d1d92ee 100644
--- a/aidial_adapter_vertexai/chat/bison/base.py
+++ b/aidial_adapter_vertexai/chat/bison/base.py
@@ -43,9 +43,6 @@ async def parse_prompt(self, messages: List[Message]) -> BisonPrompt:
     async def truncate_prompt(
         self, prompt: BisonPrompt, max_prompt_tokens: int
     ) -> Tuple[BisonPrompt, List[int]]:
-        if max_prompt_tokens is None:
-            return prompt, []
-
         return await get_discarded_messages(self, prompt, max_prompt_tokens)
 
     @override
diff --git a/aidial_adapter_vertexai/chat/chat_completion_adapter.py b/aidial_adapter_vertexai/chat/chat_completion_adapter.py
index a9a2727..1bb48ff 100644
--- a/aidial_adapter_vertexai/chat/chat_completion_adapter.py
+++ b/aidial_adapter_vertexai/chat/chat_completion_adapter.py
@@ -6,6 +6,7 @@
 from aidial_adapter_vertexai.chat.consumer import Consumer
 from aidial_adapter_vertexai.chat.errors import UserError
 from aidial_adapter_vertexai.dial_api.request import ModelParameters
+from aidial_adapter_vertexai.utils.not_implemented import not_implemented
 
 P = TypeVar("P")
 
@@ -15,22 +16,19 @@ class ChatCompletionAdapter(ABC, Generic[P]):
     async def parse_prompt(self, messages: List[Message]) -> P | UserError:
         pass
 
-    @abstractmethod
-    async def truncate_prompt(
-        self, prompt: P, max_prompt_tokens: int
-    ) -> Tuple[P, List[int]]:
-        pass
-
     @abstractmethod
     async def chat(
         self, params: ModelParameters, consumer: Consumer, prompt: P
     ) -> None:
         pass
 
-    @abstractmethod
-    async def count_prompt_tokens(self, prompt: P) -> int:
-        pass
+    @not_implemented
+    async def truncate_prompt(
+        self, prompt: P, max_prompt_tokens: int
+    ) -> Tuple[P, List[int]]: ...
 
-    @abstractmethod
-    async def count_completion_tokens(self, string: str) -> int:
-        pass
+    @not_implemented
+    async def count_prompt_tokens(self, prompt: P) -> int: ...
+
+    @not_implemented
+    async def count_completion_tokens(self, string: str) -> int: ...
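Note: the `@not_implemented` stubs above only mark the base-class defaults. An adapter subclass that overrides such a method replaces the marked function, so `is_implemented` (added later in this patch) reports it as supported. A minimal sketch of the mechanism, using a hypothetical `MyAdapter` class:

```python
from aidial_adapter_vertexai.utils.not_implemented import (
    is_implemented,
    not_implemented,
)


class Base:
    @not_implemented
    async def count_prompt_tokens(self, prompt): ...


class MyAdapter(Base):
    # Overriding the method drops the "_not_implemented" marker.
    async def count_prompt_tokens(self, prompt):
        return 0


assert not is_implemented(Base().count_prompt_tokens)
assert is_implemented(MyAdapter().count_prompt_tokens)
```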
diff --git a/aidial_adapter_vertexai/chat/gemini/adapter.py b/aidial_adapter_vertexai/chat/gemini/adapter.py
index f075c5b..58830fe 100644
--- a/aidial_adapter_vertexai/chat/gemini/adapter.py
+++ b/aidial_adapter_vertexai/chat/gemini/adapter.py
@@ -5,7 +5,6 @@
     Dict,
     List,
     Optional,
-    Tuple,
     TypeVar,
     assert_never,
 )
@@ -88,14 +87,6 @@ async def parse_prompt(
         else:
             return GeminiPrompt.parse_non_vision(messages)
 
-    @override
-    async def truncate_prompt(
-        self, prompt: GeminiPrompt, max_prompt_tokens: int
-    ) -> Tuple[GeminiPrompt, List[int]]:
-        raise NotImplementedError(
-            "Prompt truncation is not supported for Genimi model yet"
-        )
-
     async def send_message_async(
         self, params: ModelParameters, prompt: GeminiPrompt
     ) -> AsyncIterator[GenerationResponse]:
diff --git a/aidial_adapter_vertexai/chat_completion.py b/aidial_adapter_vertexai/chat_completion.py
index f7eed58..1b1902b 100644
--- a/aidial_adapter_vertexai/chat_completion.py
+++ b/aidial_adapter_vertexai/chat_completion.py
@@ -1,7 +1,27 @@
 import asyncio
-from typing import List
+from typing import List, assert_never
 
+from aidial_sdk import HTTPException as DialException
 from aidial_sdk.chat_completion import ChatCompletion, Request, Response, Status
+from aidial_sdk.chat_completion.request import ChatCompletionRequest
+from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
+from aidial_sdk.deployment.tokenize import (
+    TokenizeError,
+    TokenizeInputRequest,
+    TokenizeInputString,
+    TokenizeOutput,
+    TokenizeRequest,
+    TokenizeResponse,
+    TokenizeSuccess,
+)
+from aidial_sdk.deployment.truncate_prompt import (
+    TruncatePromptError,
+    TruncatePromptRequest,
+    TruncatePromptResponse,
+    TruncatePromptResult,
+    TruncatePromptSuccess,
+)
+from typing_extensions import override
 
 from aidial_adapter_vertexai.adapters import get_chat_completion_model
 from aidial_adapter_vertexai.chat.chat_completion_adapter import (
@@ -14,6 +34,7 @@
 from aidial_adapter_vertexai.dial_api.request import ModelParameters
 from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
 from aidial_adapter_vertexai.utils.log_config import app_logger as log
+from aidial_adapter_vertexai.utils.not_implemented import is_implemented
 
 
 class VertexAIChatCompletion(ChatCompletion):
@@ -24,16 +45,19 @@ def __init__(self, region: str, project_id: str):
         self.region = region
         self.project_id = project_id
 
-    @dial_exception_decorator
-    async def chat_completion(self, request: Request, response: Response):
-        headers = request.headers
-        model: ChatCompletionAdapter = await get_chat_completion_model(
+    async def get_model(
+        self, request: FromRequestDeploymentMixin
+    ) -> ChatCompletionAdapter:
+        return await get_chat_completion_model(
             deployment=ChatCompletionDeployment(request.deployment_id),
             project_id=self.project_id,
             location=self.region,
-            headers=headers,
+            headers=request.headers,
         )
 
+    @dial_exception_decorator
+    async def chat_completion(self, request: Request, response: Response):
+        model = await self.get_model(request)
         prompt = await model.parse_prompt(request.messages)
 
         if isinstance(prompt, UserError):
@@ -86,3 +110,77 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
 
         if params.max_prompt_tokens is not None:
             response.set_discarded_messages(discarded_messages)
+
+    @override
+    async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
+        model = await self.get_model(request)
+
+        if not is_implemented(
+            model.count_completion_tokens
+        ) or not is_implemented(model.count_prompt_tokens):
+            raise DialException(status_code=404, message="Not found")
+
+        outputs: List[TokenizeOutput] = []
+        for input in request.inputs:
+            match input:
+                case TokenizeInputRequest():
+                    outputs.append(
+                        await self.tokenize_request(model, input.value)
+                    )
+                case TokenizeInputString():
+                    outputs.append(
+                        await self.tokenize_string(model, input.value)
+                    )
+                case _:
+                    assert_never(input.type)
+        return TokenizeResponse(outputs=outputs)
+
+    async def tokenize_string(
+        self, model: ChatCompletionAdapter, value: str
+    ) -> TokenizeOutput:
+        try:
+            tokens = await model.count_completion_tokens(value)
+            return TokenizeSuccess(token_count=tokens)
+        except Exception as e:
+            return TokenizeError(error=str(e))
+
+    async def tokenize_request(
+        self, model: ChatCompletionAdapter, request: ChatCompletionRequest
+    ) -> TokenizeOutput:
+        try:
+            prompt = await model.parse_prompt(request.messages)
+            if isinstance(prompt, UserError):
+                raise prompt
+
+            token_count = await model.count_prompt_tokens(prompt)
+            return TokenizeSuccess(token_count=token_count)
+        except Exception as e:
+            return TokenizeError(error=str(e))
+
+    @override
+    async def truncate_prompt(
+        self, request: TruncatePromptRequest
+    ) -> TruncatePromptResponse:
+        model = await self.get_model(request)
+
+        if not is_implemented(model.truncate_prompt):
+            raise DialException(status_code=404, message="Not found")
+
+        outputs: List[TruncatePromptResult] = []
+        for input in request.inputs:
+            outputs.append(await self.truncate_prompt_request(model, input))
+        return TruncatePromptResponse(outputs=outputs)
+
+    async def truncate_prompt_request(
+        self, model: ChatCompletionAdapter, request: ChatCompletionRequest
+    ) -> TruncatePromptResult:
+        try:
+            if request.max_prompt_tokens is None:
+                raise ValueError("max_prompt_tokens is required")
+
+            prompt = await model.parse_prompt(request.messages)
+            if isinstance(prompt, UserError):
+                raise prompt
+
+            _, discarded_messages = await model.truncate_prompt(
+                prompt, request.max_prompt_tokens
+            )
+            return TruncatePromptSuccess(discarded_messages=discarded_messages)
+        except Exception as e:
+            return TruncatePromptError(error=str(e))
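Note: the `tokenize` handler above accepts two kinds of inputs — a full chat completion request (counted via `count_prompt_tokens`) and a plain string (counted via `count_completion_tokens`). A request body exercising both paths might look roughly like the sketch below; the exact field names are defined by the DIAL SDK's `TokenizeInput*` models, so the `type`/`value` keys are an assumption:

```python
# Illustrative sketch only: the "type"/"value" field names are assumed from the
# TokenizeInputRequest/TokenizeInputString models, not taken from this patch.
tokenize_payload = {
    "inputs": [
        {"type": "string", "value": "Hello, world!"},
        {
            "type": "request",
            "value": {"messages": [{"role": "user", "content": "Hello, world!"}]},
        },
    ]
}
```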
""" - os.environ["DEFAULT_REGION"] = "dummy_region" + os.environ["DEFAULT_REGION"] = "us-central1" os.environ["GCP_PROJECT_ID"] = "dummy_project_id" diff --git a/tests/unit_tests/test_endpoints.py b/tests/unit_tests/test_endpoints.py new file mode 100644 index 0000000..b0ae89a --- /dev/null +++ b/tests/unit_tests/test_endpoints.py @@ -0,0 +1,51 @@ +from typing import List, Tuple + +import pytest +import requests + +from aidial_adapter_vertexai.deployments import ChatCompletionDeployment +from tests.conftest import TEST_SERVER_URL + +test_cases: List[Tuple[ChatCompletionDeployment, bool, bool]] = [ + (ChatCompletionDeployment.CHAT_BISON_1, True, True), + (ChatCompletionDeployment.CHAT_BISON_2, True, True), + (ChatCompletionDeployment.CHAT_BISON_2_32K, True, True), + (ChatCompletionDeployment.CODECHAT_BISON_1, True, True), + (ChatCompletionDeployment.CODECHAT_BISON_2, True, True), + (ChatCompletionDeployment.CODECHAT_BISON_2_32K, True, True), + (ChatCompletionDeployment.GEMINI_PRO_1, True, False), + (ChatCompletionDeployment.GEMINI_PRO_VISION_1, True, False), + (ChatCompletionDeployment.IMAGEN_005, True, True), +] + + +def feature_test_helper( + url: str, is_supported: bool, headers: dict, payload: dict +) -> None: + response = requests.post(url, json=payload, headers=headers) + assert ( + response.status_code != 404 + ) == is_supported, ( + f"is_supported={is_supported}, code={response.status_code}, url={url}" + ) + + +@pytest.mark.parametrize( + "deployment, tokenize_supported, truncate_supported", test_cases +) +def test_model_features( + server, + deployment: ChatCompletionDeployment, + tokenize_supported: bool, + truncate_supported: bool, +): + payload = {"inputs": []} + headers = {"Content-Type": "application/json", "Api-Key": "dummy"} + + BASE_URL = f"{TEST_SERVER_URL}/openai/deployments/{deployment.value}" + + tokenize_url = f"{BASE_URL}/tokenize" + feature_test_helper(tokenize_url, tokenize_supported, headers, payload) + + truncate_url = f"{BASE_URL}/truncate_prompt" + feature_test_helper(truncate_url, truncate_supported, headers, payload)