feat: supported tokenize and truncate_prompt endpoints #69

Merged
3 changes: 2 additions & 1 deletion .flake8
@@ -2,7 +2,8 @@
# E501 line is too long
# W503 line break before binary operator
# E203 whitespace before ':'
ignore = E501, W503, E203
# E704 multiple statements on one line (def)
ignore = E501, W503, E203, E704
exclude =
.venv,
.nox,
25 changes: 24 additions & 1 deletion README.md
@@ -2,7 +2,30 @@

The project implements [AI DIAL API](https://epam-rail.com/dial_api) for language models and embeddings from [Vertex AI](https://console.cloud.google.com/vertex-ai).

Find the list of supported models in [the source code](./aidial_adapter_vertexai/deployments.py).
## Supported models

The following models support the `POST SERVER_URL/openai/deployments/MODEL_NAME/chat/completions` endpoint, along with optional support for the `/tokenize` and `/truncate_prompt` endpoints:

|Model|Modality|`/tokenize`|`/truncate_prompt`|
|---|---|---|---|
|chat-bison@001|text-to-text|✅|✅|
|chat-bison@002|text-to-text|✅|✅|
|chat-bison-32k@002|text-to-text|✅|✅|
|codechat-bison@001|text-to-text|✅|✅|
|codechat-bison@002|text-to-text|✅|✅|
|codechat-bison-32k@002|text-to-text|✅|✅|
|gemini-pro|text-to-text|✅|❌|
|gemini-pro-vision|text-to-text, image-to-text|✅|❌|
|imagegeneration@005|text-to-image|✅|✅|

The models that support `/truncate_prompt` also support the `max_prompt_tokens` request parameter.
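
For illustration only, here is a minimal sketch of calling both endpoints with the `requests` library. The server URL, API key, and the exact shape of the request bodies are assumptions based on the DIAL API, not verbatim examples from this repository:

```python
import requests

DIAL_URL = "http://localhost:5000"  # placeholder adapter URL
HEADERS = {"Api-Key": "dummy", "Content-Type": "application/json"}
DEPLOYMENT = "chat-bison@001"

# /tokenize: count tokens for a raw string and for a whole chat request.
tokenize_payload = {
    "inputs": [
        {"type": "string", "value": "Hello, world!"},
        {
            "type": "request",
            "value": {"messages": [{"role": "user", "content": "Hello, world!"}]},
        },
    ]
}
resp = requests.post(
    f"{DIAL_URL}/openai/deployments/{DEPLOYMENT}/tokenize",
    json=tokenize_payload,
    headers=HEADERS,
)
print(resp.json())

# /truncate_prompt: ask which messages would be discarded to fit the token budget.
truncate_payload = {
    "inputs": [
        {
            "messages": [{"role": "user", "content": "Hello, world!"}],
            "max_prompt_tokens": 16,
        }
    ]
}
resp = requests.post(
    f"{DIAL_URL}/openai/deployments/{DEPLOYMENT}/truncate_prompt",
    json=truncate_payload,
    headers=HEADERS,
)
print(resp.json())
```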

The following models support the `POST SERVER_URL/openai/deployments/MODEL_NAME/embeddings` endpoint:

|Model|Modality|
|---|---|
|textembedding-gecko@001|text-to-embedding|
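
As a further illustration, a minimal embeddings call, assuming an OpenAI-compatible request and response body (`input` field, `data[0].embedding` in the result) and the same placeholder URL and headers as in the sketch above:

```python
resp = requests.post(
    f"{DIAL_URL}/openai/deployments/textembedding-gecko@001/embeddings",
    json={"input": "Hello, world!"},
    headers=HEADERS,
)
print(resp.json()["data"][0]["embedding"][:5])  # first few embedding values
```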

## Developer environment

3 changes: 0 additions & 3 deletions aidial_adapter_vertexai/chat/bison/base.py
@@ -43,9 +43,6 @@ async def parse_prompt(self, messages: List[Message]) -> BisonPrompt:
async def truncate_prompt(
self, prompt: BisonPrompt, max_prompt_tokens: int
) -> Tuple[BisonPrompt, List[int]]:
if max_prompt_tokens is None:
return prompt, []

return await get_discarded_messages(self, prompt, max_prompt_tokens)

@override
22 changes: 10 additions & 12 deletions aidial_adapter_vertexai/chat/chat_completion_adapter.py
@@ -6,6 +6,7 @@
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.utils.not_implemented import not_implemented

P = TypeVar("P")

@@ -15,22 +16,19 @@ class ChatCompletionAdapter(ABC, Generic[P]):
async def parse_prompt(self, messages: List[Message]) -> P | UserError:
pass

@abstractmethod
async def truncate_prompt(
self, prompt: P, max_prompt_tokens: int
) -> Tuple[P, List[int]]:
pass

@abstractmethod
async def chat(
self, params: ModelParameters, consumer: Consumer, prompt: P
) -> None:
pass

@abstractmethod
async def count_prompt_tokens(self, prompt: P) -> int:
pass
@not_implemented
async def truncate_prompt(
self, prompt: P, max_prompt_tokens: int
) -> Tuple[P, List[int]]: ...

@abstractmethod
async def count_completion_tokens(self, string: str) -> int:
pass
@not_implemented
async def count_prompt_tokens(self, prompt: P) -> int: ...

@not_implemented
async def count_completion_tokens(self, string: str) -> int: ...
9 changes: 0 additions & 9 deletions aidial_adapter_vertexai/chat/gemini/adapter.py
@@ -5,7 +5,6 @@
Dict,
List,
Optional,
Tuple,
TypeVar,
assert_never,
)
@@ -88,14 +87,6 @@ async def parse_prompt(
else:
return GeminiPrompt.parse_non_vision(messages)

@override
async def truncate_prompt(
self, prompt: GeminiPrompt, max_prompt_tokens: int
) -> Tuple[GeminiPrompt, List[int]]:
raise NotImplementedError(
"Prompt truncation is not supported for Genimi model yet"
)

async def send_message_async(
self, params: ModelParameters, prompt: GeminiPrompt
) -> AsyncIterator[GenerationResponse]:
110 changes: 104 additions & 6 deletions aidial_adapter_vertexai/chat_completion.py
@@ -1,7 +1,27 @@
import asyncio
from typing import List
from typing import List, assert_never

from aidial_sdk import HTTPException as DialException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response, Status
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.deployment.tokenize import (
TokenizeError,
TokenizeInputRequest,
TokenizeInputString,
TokenizeOutput,
TokenizeRequest,
TokenizeResponse,
TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
TruncatePromptError,
TruncatePromptRequest,
TruncatePromptResponse,
TruncatePromptResult,
TruncatePromptSuccess,
)
from typing_extensions import override

from aidial_adapter_vertexai.adapters import get_chat_completion_model
from aidial_adapter_vertexai.chat.chat_completion_adapter import (
@@ -14,6 +34,7 @@
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.log_config import app_logger as log
from aidial_adapter_vertexai.utils.not_implemented import is_implemented


class VertexAIChatCompletion(ChatCompletion):
@@ -24,16 +45,19 @@ def __init__(self, region: str, project_id: str):
self.region = region
self.project_id = project_id

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
headers = request.headers
model: ChatCompletionAdapter = await get_chat_completion_model(
async def get_model(
self, request: FromRequestDeploymentMixin
) -> ChatCompletionAdapter:
return await get_chat_completion_model(
deployment=ChatCompletionDeployment(request.deployment_id),
project_id=self.project_id,
location=self.region,
headers=headers,
headers=request.headers,
)

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
model = await self.get_model(request)
prompt = await model.parse_prompt(request.messages)

if isinstance(prompt, UserError):
@@ -86,3 +110,77 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:

if params.max_prompt_tokens is not None:
response.set_discarded_messages(discarded_messages)

@override
async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
model = await self.get_model(request)

if not is_implemented(
model.count_completion_tokens
) or not is_implemented(model.count_prompt_tokens):
raise DialException(status_code=404, message="Not found")

outputs: List[TokenizeOutput] = []
for input in request.inputs:
match input:
case TokenizeInputRequest():
outputs.append(
await self.tokenize_request(model, input.value)
)
case TokenizeInputString():
outputs.append(
await self.tokenize_string(model, input.value)
)
case _:
assert_never(input.type)
return TokenizeResponse(outputs=outputs)

async def tokenize_string(
self, model: ChatCompletionAdapter, value: str
) -> TokenizeOutput:
try:
tokens = await model.count_completion_tokens(value)
return TokenizeSuccess(token_count=tokens)
except Exception as e:
return TokenizeError(error=str(e))

async def tokenize_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TokenizeOutput:
try:
prompt = await model.parse_prompt(request.messages)
if isinstance(prompt, UserError):
raise prompt

token_count = await model.count_prompt_tokens(prompt)
return TokenizeSuccess(token_count=token_count)
except Exception as e:
return TokenizeError(error=str(e))

@override
async def truncate_prompt(
self, request: TruncatePromptRequest
) -> TruncatePromptResponse:
model = await self.get_model(request)

if not is_implemented(model.truncate_prompt):
raise DialException(status_code=404, message="Not found")

outputs: List[TruncatePromptResult] = []
for input in request.inputs:
outputs.append(await self.truncate_prompt_request(model, input))
return TruncatePromptResponse(outputs=outputs)

async def truncate_prompt_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TruncatePromptResult:
try:
if request.max_prompt_tokens is None:
raise ValueError("max_prompt_tokens is required")

_prompt, discarded_messages = await model.truncate_prompt(
request.messages, request.max_prompt_tokens
)
return TruncatePromptSuccess(discarded_messages=discarded_messages)
except Exception as e:
return TruncatePromptError(error=str(e))
7 changes: 7 additions & 0 deletions aidial_adapter_vertexai/utils/not_implemented.py
@@ -0,0 +1,7 @@
def not_implemented(func):
setattr(func, "_not_implemented", True)
return func


def is_implemented(method):
return not getattr(method, "_not_implemented", False)
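
A minimal sketch of how these helpers are meant to interact (the `DummyAdapter` class below is purely illustrative and not part of the codebase): methods marked with `@not_implemented` are detected via `is_implemented`, which is what lets the `tokenize` and `truncate_prompt` handlers above return 404 for models that do not support them.

```python
class DummyAdapter:
    @not_implemented
    async def truncate_prompt(self, prompt, max_prompt_tokens): ...

    async def count_prompt_tokens(self, prompt) -> int:
        return 0


adapter = DummyAdapter()
assert not is_implemented(adapter.truncate_prompt)  # marked -> endpoint reports 404
assert is_implemented(adapter.count_prompt_tokens)  # regular method -> supported
```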
4 changes: 2 additions & 2 deletions client/chat/adapter.py
@@ -53,8 +53,8 @@ async def task(on_content):
prompt = await self.model.parse_prompt(self.history)
if isinstance(prompt, UserError):
raise prompt
else:
await self.model.chat(params, consumer, prompt)

await self.model.chat(params, consumer, prompt)

async def on_content(chunk: str):
return
2 changes: 1 addition & 1 deletion tests/unit_tests/conftest.py
@@ -5,5 +5,5 @@ def pytest_configure(config):
"""
Setting up fake environment variables for unit tests.
"""
os.environ["DEFAULT_REGION"] = "dummy_region"
os.environ["DEFAULT_REGION"] = "us-central1"
os.environ["GCP_PROJECT_ID"] = "dummy_project_id"
51 changes: 51 additions & 0 deletions tests/unit_tests/test_endpoints.py
@@ -0,0 +1,51 @@
from typing import List, Tuple

import pytest
import requests

from aidial_adapter_vertexai.deployments import ChatCompletionDeployment
from tests.conftest import TEST_SERVER_URL

test_cases: List[Tuple[ChatCompletionDeployment, bool, bool]] = [
(ChatCompletionDeployment.CHAT_BISON_1, True, True),
(ChatCompletionDeployment.CHAT_BISON_2, True, True),
(ChatCompletionDeployment.CHAT_BISON_2_32K, True, True),
(ChatCompletionDeployment.CODECHAT_BISON_1, True, True),
(ChatCompletionDeployment.CODECHAT_BISON_2, True, True),
(ChatCompletionDeployment.CODECHAT_BISON_2_32K, True, True),
(ChatCompletionDeployment.GEMINI_PRO_1, True, False),
(ChatCompletionDeployment.GEMINI_PRO_VISION_1, True, False),
(ChatCompletionDeployment.IMAGEN_005, True, True),
]


def feature_test_helper(
url: str, is_supported: bool, headers: dict, payload: dict
) -> None:
response = requests.post(url, json=payload, headers=headers)
assert (
response.status_code != 404
) == is_supported, (
f"is_supported={is_supported}, code={response.status_code}, url={url}"
)


@pytest.mark.parametrize(
"deployment, tokenize_supported, truncate_supported", test_cases
)
def test_model_features(
server,
deployment: ChatCompletionDeployment,
tokenize_supported: bool,
truncate_supported: bool,
):
payload = {"inputs": []}
headers = {"Content-Type": "application/json", "Api-Key": "dummy"}

BASE_URL = f"{TEST_SERVER_URL}/openai/deployments/{deployment.value}"

tokenize_url = f"{BASE_URL}/tokenize"
feature_test_helper(tokenize_url, tokenize_supported, headers, payload)

truncate_url = f"{BASE_URL}/truncate_prompt"
feature_test_helper(truncate_url, truncate_supported, headers, payload)