Commit
Merge branch 'development' into dependabot/pip/langchain-core-0.1.35
adubovik authored Mar 28, 2024
2 parents 8806dd2 + a735a4e commit 62e6b81
Showing 10 changed files with 201 additions and 35 deletions.
3 changes: 2 additions & 1 deletion .flake8
@@ -2,7 +2,8 @@
# E501 line is too long
# W503 line break before binary operator
# E203 whitespace before ':'
ignore = E501, W503, E203
# E704 multiple statements on one line (def)
ignore = E501, W503, E203, E704
exclude =
.venv,
.nox,
25 changes: 24 additions & 1 deletion README.md
@@ -2,7 +2,30 @@

The project implements [AI DIAL API](https://epam-rail.com/dial_api) for language models and embeddings from [Vertex AI](https://console.cloud.google.com/vertex-ai).

Find the list of supported models in [the source code](./aidial_adapter_vertexai/deployments.py).
## Supported models

The following models support the `POST SERVER_URL/openai/deployments/MODEL_NAME/chat/completions` endpoint, along with optional support for the `/tokenize` and `/truncate_prompt` endpoints:

|Model|Modality|`/tokenize`|`/truncate_prompt`|
|---|---|---|---|
|chat-bison@001|text-to-text|✅|✅|
|chat-bison@002|text-to-text|✅|✅|
|chat-bison-32k@002|text-to-text|✅|✅|
|codechat-bison@001|text-to-text|✅|✅|
|codechat-bison@002|text-to-text|✅|✅|
|codechat-bison-32k@002|text-to-text|✅|✅|
|gemini-pro|text-to-text|✅|❌|
|gemini-pro-vision|text-to-text, image-to-text|✅|❌|
|imagegeneration@005|text-to-image|✅|✅|

The models that support `/truncate_prompt` also support the `max_prompt_tokens` request parameter.
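
For illustration, here is a minimal sketch of a chat completion request that sets `max_prompt_tokens`; the server URL, API key and deployment name are placeholders, not values defined by this repository:

```python
import requests

SERVER_URL = "http://localhost:5000"  # placeholder adapter URL
DEPLOYMENT = "chat-bison@002"  # any deployment that supports /truncate_prompt

response = requests.post(
    f"{SERVER_URL}/openai/deployments/{DEPLOYMENT}/chat/completions",
    headers={"Api-Key": "dummy", "Content-Type": "application/json"},
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        # The oldest messages are discarded until the prompt fits this budget;
        # their indices are returned as discarded_messages in the response.
        "max_prompt_tokens": 1000,
    },
)
print(response.json())
```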

The following models support the `SERVER_URL/openai/deployments/MODEL_NAME/embeddings` endpoint (an example request follows the table):

|Model|Modality|
|---|---|
|textembedding-gecko@001|text-to-embedding|
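
A minimal sketch of an embeddings call, assuming the OpenAI-compatible request and response shape of the DIAL API and the `Api-Key` header used in this repository's tests; the URL and key are placeholders:

```python
import requests

SERVER_URL = "http://localhost:5000"  # placeholder adapter URL

response = requests.post(
    f"{SERVER_URL}/openai/deployments/textembedding-gecko@001/embeddings",
    headers={"Api-Key": "dummy", "Content-Type": "application/json"},
    json={"input": "The quick brown fox"},
)
embedding = response.json()["data"][0]["embedding"]
print(len(embedding))
```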

## Developer environment

3 changes: 0 additions & 3 deletions aidial_adapter_vertexai/chat/bison/base.py
@@ -43,9 +43,6 @@ async def parse_prompt(self, messages: List[Message]) -> BisonPrompt:
async def truncate_prompt(
self, prompt: BisonPrompt, max_prompt_tokens: int
) -> Tuple[BisonPrompt, List[int]]:
if max_prompt_tokens is None:
return prompt, []

return await get_discarded_messages(self, prompt, max_prompt_tokens)

@override
22 changes: 10 additions & 12 deletions aidial_adapter_vertexai/chat/chat_completion_adapter.py
@@ -6,6 +6,7 @@
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.utils.not_implemented import not_implemented

P = TypeVar("P")

@@ -15,22 +16,19 @@ class ChatCompletionAdapter(ABC, Generic[P]):
async def parse_prompt(self, messages: List[Message]) -> P | UserError:
pass

@abstractmethod
async def truncate_prompt(
self, prompt: P, max_prompt_tokens: int
) -> Tuple[P, List[int]]:
pass

@abstractmethod
async def chat(
self, params: ModelParameters, consumer: Consumer, prompt: P
) -> None:
pass

@abstractmethod
async def count_prompt_tokens(self, prompt: P) -> int:
pass
@not_implemented
async def truncate_prompt(
self, prompt: P, max_prompt_tokens: int
) -> Tuple[P, List[int]]: ...

@abstractmethod
async def count_completion_tokens(self, string: str) -> int:
pass
@not_implemented
async def count_prompt_tokens(self, prompt: P) -> int: ...

@not_implemented
async def count_completion_tokens(self, string: str) -> int: ...
9 changes: 0 additions & 9 deletions aidial_adapter_vertexai/chat/gemini/adapter.py
@@ -5,7 +5,6 @@
Dict,
List,
Optional,
Tuple,
TypeVar,
assert_never,
)
@@ -88,14 +87,6 @@ async def parse_prompt(
else:
return GeminiPrompt.parse_non_vision(messages)

@override
async def truncate_prompt(
self, prompt: GeminiPrompt, max_prompt_tokens: int
) -> Tuple[GeminiPrompt, List[int]]:
raise NotImplementedError(
"Prompt truncation is not supported for Genimi model yet"
)

async def send_message_async(
self, params: ModelParameters, prompt: GeminiPrompt
) -> AsyncIterator[GenerationResponse]:
110 changes: 104 additions & 6 deletions aidial_adapter_vertexai/chat_completion.py
@@ -1,7 +1,27 @@
import asyncio
from typing import List
from typing import List, assert_never

from aidial_sdk import HTTPException as DialException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response, Status
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.deployment.tokenize import (
TokenizeError,
TokenizeInputRequest,
TokenizeInputString,
TokenizeOutput,
TokenizeRequest,
TokenizeResponse,
TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
TruncatePromptError,
TruncatePromptRequest,
TruncatePromptResponse,
TruncatePromptResult,
TruncatePromptSuccess,
)
from typing_extensions import override

from aidial_adapter_vertexai.adapters import get_chat_completion_model
from aidial_adapter_vertexai.chat.chat_completion_adapter import (
@@ -14,6 +34,7 @@
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.log_config import app_logger as log
from aidial_adapter_vertexai.utils.not_implemented import is_implemented


class VertexAIChatCompletion(ChatCompletion):
@@ -24,16 +45,19 @@ def __init__(self, region: str, project_id: str):
self.region = region
self.project_id = project_id

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
headers = request.headers
model: ChatCompletionAdapter = await get_chat_completion_model(
async def get_model(
self, request: FromRequestDeploymentMixin
) -> ChatCompletionAdapter:
return await get_chat_completion_model(
deployment=ChatCompletionDeployment(request.deployment_id),
project_id=self.project_id,
location=self.region,
headers=headers,
headers=request.headers,
)

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
model = await self.get_model(request)
prompt = await model.parse_prompt(request.messages)

if isinstance(prompt, UserError):
@@ -86,3 +110,77 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:

if params.max_prompt_tokens is not None:
response.set_discarded_messages(discarded_messages)

@override
async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
model = await self.get_model(request)

if not is_implemented(
model.count_completion_tokens
) or not is_implemented(model.count_prompt_tokens):
raise DialException(status_code=404, message="Not found")

outputs: List[TokenizeOutput] = []
for input in request.inputs:
match input:
case TokenizeInputRequest():
outputs.append(
await self.tokenize_request(model, input.value)
)
case TokenizeInputString():
outputs.append(
await self.tokenize_string(model, input.value)
)
case _:
assert_never(input.type)
return TokenizeResponse(outputs=outputs)

async def tokenize_string(
self, model: ChatCompletionAdapter, value: str
) -> TokenizeOutput:
try:
tokens = await model.count_completion_tokens(value)
return TokenizeSuccess(token_count=tokens)
except Exception as e:
return TokenizeError(error=str(e))

async def tokenize_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TokenizeOutput:
try:
prompt = await model.parse_prompt(request.messages)
if isinstance(prompt, UserError):
raise prompt

token_count = await model.count_prompt_tokens(prompt)
return TokenizeSuccess(token_count=token_count)
except Exception as e:
return TokenizeError(error=str(e))

@override
async def truncate_prompt(
self, request: TruncatePromptRequest
) -> TruncatePromptResponse:
model = await self.get_model(request)

if not is_implemented(model.truncate_prompt):
raise DialException(status_code=404, message="Not found")

outputs: List[TruncatePromptResult] = []
for input in request.inputs:
outputs.append(await self.truncate_prompt_request(model, input))
return TruncatePromptResponse(outputs=outputs)

async def truncate_prompt_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TruncatePromptResult:
try:
if request.max_prompt_tokens is None:
raise ValueError("max_prompt_tokens is required")

_prompt, discarded_messages = await model.truncate_prompt(
request.messages, request.max_prompt_tokens
)
return TruncatePromptSuccess(discarded_messages=discarded_messages)
except Exception as e:
return TruncatePromptError(error=str(e))
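
The snippet below is a hedged sketch of how a client might call the new `/tokenize` endpoint. The `inputs` wire format (a list of `string` or `request` items) is an assumption inferred from the `TokenizeInputString`/`TokenizeInputRequest` models imported above, and the URL, key and deployment are placeholders:

```python
import requests

SERVER_URL = "http://localhost:5000"  # placeholder adapter URL
DEPLOYMENT = "chat-bison@002"  # a deployment whose adapter implements token counting

response = requests.post(
    f"{SERVER_URL}/openai/deployments/{DEPLOYMENT}/tokenize",
    headers={"Api-Key": "dummy", "Content-Type": "application/json"},
    json={
        "inputs": [
            # Assumed shapes: a plain string is counted with count_completion_tokens,
            # a chat request is parsed and counted with count_prompt_tokens.
            {"type": "string", "value": "Hello, world!"},
            {
                "type": "request",
                "value": {"messages": [{"role": "user", "content": "Hello!"}]},
            },
        ]
    },
)
# One output per input; deployments whose adapter does not implement
# token counting answer 404 instead.
print(response.status_code, response.json())
```
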
7 changes: 7 additions & 0 deletions aidial_adapter_vertexai/utils/not_implemented.py
@@ -0,0 +1,7 @@
def not_implemented(func):
setattr(func, "_not_implemented", True)
return func


def is_implemented(method):
return not getattr(method, "_not_implemented", False)
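
A small usage sketch of these helpers; the `ExampleAdapter` class is illustrative only and not part of the codebase:

```python
from aidial_adapter_vertexai.utils.not_implemented import (
    is_implemented,
    not_implemented,
)


class ExampleAdapter:
    @not_implemented
    async def truncate_prompt(self, prompt, max_prompt_tokens): ...

    async def count_prompt_tokens(self, prompt) -> int:
        return 42


adapter = ExampleAdapter()
# Endpoints check this flag and answer 404 for methods marked as stubs.
assert not is_implemented(adapter.truncate_prompt)
assert is_implemented(adapter.count_prompt_tokens)
```
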
4 changes: 2 additions & 2 deletions client/chat/adapter.py
@@ -53,8 +53,8 @@ async def task(on_content):
prompt = await self.model.parse_prompt(self.history)
if isinstance(prompt, UserError):
raise prompt
else:
await self.model.chat(params, consumer, prompt)

await self.model.chat(params, consumer, prompt)

async def on_content(chunk: str):
return
2 changes: 1 addition & 1 deletion tests/unit_tests/conftest.py
@@ -5,5 +5,5 @@ def pytest_configure(config):
"""
Setting up fake environment variables for unit tests.
"""
os.environ["DEFAULT_REGION"] = "dummy_region"
os.environ["DEFAULT_REGION"] = "us-central1"
os.environ["GCP_PROJECT_ID"] = "dummy_project_id"
51 changes: 51 additions & 0 deletions tests/unit_tests/test_endpoints.py
@@ -0,0 +1,51 @@
from typing import List, Tuple

import pytest
import requests

from aidial_adapter_vertexai.deployments import ChatCompletionDeployment
from tests.conftest import TEST_SERVER_URL

test_cases: List[Tuple[ChatCompletionDeployment, bool, bool]] = [
(ChatCompletionDeployment.CHAT_BISON_1, True, True),
(ChatCompletionDeployment.CHAT_BISON_2, True, True),
(ChatCompletionDeployment.CHAT_BISON_2_32K, True, True),
(ChatCompletionDeployment.CODECHAT_BISON_1, True, True),
(ChatCompletionDeployment.CODECHAT_BISON_2, True, True),
(ChatCompletionDeployment.CODECHAT_BISON_2_32K, True, True),
(ChatCompletionDeployment.GEMINI_PRO_1, True, False),
(ChatCompletionDeployment.GEMINI_PRO_VISION_1, True, False),
(ChatCompletionDeployment.IMAGEN_005, True, True),
]


def feature_test_helper(
url: str, is_supported: bool, headers: dict, payload: dict
) -> None:
response = requests.post(url, json=payload, headers=headers)
assert (
response.status_code != 404
) == is_supported, (
f"is_supported={is_supported}, code={response.status_code}, url={url}"
)


@pytest.mark.parametrize(
"deployment, tokenize_supported, truncate_supported", test_cases
)
def test_model_features(
server,
deployment: ChatCompletionDeployment,
tokenize_supported: bool,
truncate_supported: bool,
):
payload = {"inputs": []}
headers = {"Content-Type": "application/json", "Api-Key": "dummy"}

BASE_URL = f"{TEST_SERVER_URL}/openai/deployments/{deployment.value}"

tokenize_url = f"{BASE_URL}/tokenize"
feature_test_helper(tokenize_url, tokenize_supported, headers, payload)

truncate_url = f"{BASE_URL}/truncate_prompt"
feature_test_helper(truncate_url, truncate_supported, headers, payload)
