feat: added region to integration tests
adubovik committed Jan 23, 2025
1 parent d240f94 commit 38faa35
Showing 4 changed files with 72 additions and 43 deletions.
21 changes: 9 additions & 12 deletions aidial_adapter_vertexai/dial_api/exceptions.py
@@ -60,20 +60,17 @@ def to_dial_exception(e: Exception) -> DialException:
         )

     if isinstance(e, anthropic.APIStatusError):
-        code = e.status_code
         try:
-            response = e.response.json()["error"]
-            return DialException(
-                status_code=code,
-                type=response["type"],
-                message=response["message"],
-            )
+            message = e.body["error"]["message"]  # type: ignore
         except Exception:
-            return DialException(
-                status_code=code,
-                type=_get_exception_type(code),
-                message=e.message,
-            )
+            message = e.message
+
+        code = e.status_code
+        return DialException(
+            status_code=code,
+            type=_get_exception_type(code),
+            message=message,
+        )

     if isinstance(e, ValidationError):
         return e.to_dial_exception()
8 changes: 6 additions & 2 deletions tests/conftest.py
@@ -1,5 +1,5 @@
 import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Mapping

 import httpx
 import pytest
@@ -49,7 +49,10 @@ async def test_http_client() -> AsyncGenerator[httpx.AsyncClient, None]:

 @pytest.fixture
 def get_openai_client(test_http_client: httpx.AsyncClient):
-    def _get_client(deployment_id: str | None = None) -> AsyncAzureOpenAI:
+    def _get_client(
+        deployment_id: str | None = None,
+        extra_headers: Mapping[str, str] | None = None,
+    ) -> AsyncAzureOpenAI:
         return AsyncAzureOpenAI(
             azure_endpoint=str(test_http_client.base_url),
             azure_deployment=deployment_id,
@@ -58,6 +61,7 @@ def _get_client(deployment_id: str | None = None) -> AsyncAzureOpenAI:
             max_retries=2,
             timeout=30,
             http_client=test_http_client,
+            default_headers=extra_headers,
         )

     yield _get_client
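
For illustration only, a minimal self-contained sketch of the kind of client the extended fixture now builds; the endpoint, API key, API version, and deployment id below are placeholder values, not taken from this diff.

from openai import AsyncAzureOpenAI

client = AsyncAzureOpenAI(
    azure_endpoint="http://localhost:5001",  # placeholder adapter endpoint
    azure_deployment="gemini-1.5-flash-002",  # hypothetical deployment id
    api_key="dummy",
    api_version="2024-02-01",
    max_retries=2,
    timeout=30,
    # the new extra_headers argument ends up here:
    default_headers={"x-upstream-extra-data": '{"region": "us-central1"}'},
)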
80 changes: 53 additions & 27 deletions tests/integration_tests/test_chat_completion.py
@@ -1,7 +1,8 @@
 import asyncio
+import json
 import re
 from dataclasses import dataclass
-from typing import Callable, List
+from typing import Callable, List, Mapping

 import pytest
 from aidial_sdk.chat_completion.request import StaticFunction
@@ -55,6 +56,7 @@ class TestCase:
     __test__ = False

     name: str
+    region: str | None
     deployment: ChatCompletionDeployment
     streaming: bool

@@ -71,32 +73,45 @@ class TestCase:
     static_tools: StaticToolsConfig | None

     def get_id(self):
-        max_tokens_str = f"maxt={self.max_tokens}" if self.max_tokens else ""
-        stop_sequence_str = f"stop={self.stop}" if self.stop else ""
-        n_str = f"n={self.n}" if self.n else ""
+        max_tokens_str = f"maxt:{self.max_tokens}" if self.max_tokens else None
+        stop_sequence_str = f"stop:{self.stop}" if self.stop else None
+        n_str = f"n:{self.n}" if self.n else None
         return sanitize_test_name(
-            f"{self.deployment.value} {self.streaming} {max_tokens_str} "
-            f"{stop_sequence_str} {n_str} {self.name}"
+            "/".join(
+                str(part)
+                for part in [
+                    self.deployment.value,
+                    self.streaming,
+                    max_tokens_str,
+                    stop_sequence_str,
+                    n_str,
+                    self.name,
+                ]
+                if part is not None
+            )
         )


-deployments = [
-    ChatCompletionDeployment.CHAT_BISON_1,
-    ChatCompletionDeployment.CHAT_BISON_2_32K,
-    ChatCompletionDeployment.CODECHAT_BISON_1,
-    ChatCompletionDeployment.GEMINI_PRO_1,
-    ChatCompletionDeployment.GEMINI_FLASH_1_5_V2,
-    ChatCompletionDeployment.GEMINI_PRO_VISION_1,
-    ChatCompletionDeployment.GEMINI_PRO_1_5_V2,
-    ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP,
-    ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206,
-    ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219,
-    ChatCompletionDeployment.CLAUDE_3_5_SONNET_V2,
-    ChatCompletionDeployment.CLAUDE_3_5_HAIKU,
-    ChatCompletionDeployment.CLAUDE_3_OPUS,
-    ChatCompletionDeployment.CLAUDE_3_5_SONNET,
-    ChatCompletionDeployment.CLAUDE_3_HAIKU,
-]
+_CENTRAL = "us-central1"
+_EAST = "us-east5"
+
+chat_deployments: Mapping[ChatCompletionDeployment, str] = {
+    ChatCompletionDeployment.CHAT_BISON_1: _CENTRAL,
+    ChatCompletionDeployment.CHAT_BISON_2_32K: _CENTRAL,
+    ChatCompletionDeployment.CODECHAT_BISON_1: _CENTRAL,
+    ChatCompletionDeployment.GEMINI_PRO_1: _CENTRAL,
+    ChatCompletionDeployment.GEMINI_FLASH_1_5_V2: _CENTRAL,
+    ChatCompletionDeployment.GEMINI_PRO_VISION_1: _CENTRAL,
+    ChatCompletionDeployment.GEMINI_PRO_1_5_V2: _CENTRAL,
+    ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP: _CENTRAL,
+    ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206: _CENTRAL,
+    ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219: _CENTRAL,
+    ChatCompletionDeployment.CLAUDE_3_5_SONNET_V2: _EAST,
+    ChatCompletionDeployment.CLAUDE_3_5_HAIKU: _EAST,
+    ChatCompletionDeployment.CLAUDE_3_OPUS: _EAST,
+    ChatCompletionDeployment.CLAUDE_3_5_SONNET: _EAST,
+    ChatCompletionDeployment.CLAUDE_3_HAIKU: _EAST,
+}


 def is_codechat(deployment: ChatCompletionDeployment) -> bool:
@@ -195,7 +210,7 @@ def is_gemini_2(deployment: ChatCompletionDeployment) -> bool:


 def get_test_cases(
-    deployment: ChatCompletionDeployment, streaming: bool
+    deployment: ChatCompletionDeployment, region: str, streaming: bool
 ) -> List[TestCase]:
     test_cases: List[TestCase] = []

@@ -215,6 +230,7 @@ def test_case(
         test_cases.append(
             TestCase(
                 name,
+                region,
                 deployment,
                 streaming,
                 messages,
@@ -596,18 +612,28 @@ def _check(id: str) -> bool:
     return test_cases


+def get_extra_headers(region: str | None) -> Mapping[str, str]:
+    return (
+        {"x-upstream-extra-data": json.dumps({"region": region})}
+        if region is not None
+        else {}
+    )
+
+
 @pytest.mark.parametrize(
     "test",
     [
         test
-        for deployment in deployments
+        for deployment, region in chat_deployments.items()
         for streaming in [False, True]
-        for test in get_test_cases(deployment, streaming)
+        for test in get_test_cases(deployment, region, streaming)
     ],
     ids=lambda test: test.get_id(),
 )
 async def test_chat_completion_openai(get_openai_client, test: TestCase):
-    client = get_openai_client(test.deployment.value)
+    client = get_openai_client(
+        test.deployment.value, get_extra_headers(test.region)
+    )

     async def run_chat_completion() -> ChatCompletionResult:
         retries = 7
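
A short standalone sketch of the header this produces for a deployment routed to us-east5, mirroring get_extra_headers above; the region value is just an example.

import json

region = "us-east5"
headers = {"x-upstream-extra-data": json.dumps({"region": region})}
print(headers)
# {'x-upstream-extra-data': '{"region": "us-east5"}'}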
6 changes: 4 additions & 2 deletions tests/utils/openai.py
@@ -152,8 +152,10 @@ def function_to_tool(function: FunctionDefinition) -> ChatCompletionToolParam:


 def sanitize_test_name(name: str) -> str:
-    name2 = "".join(c if c.isalnum() else "_" for c in name.lower())
-    return re.sub("_+", "_", name2)
+    name = "".join(
+        c if (c.isalnum() or c in "/:") else "_" for c in name.lower()
+    )
+    return re.sub("_+", "_", name)


 class ChatCompletionResult(BaseModel):
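
A worked example of the updated sanitizer (the function body is copied from the diff; the sample id is hypothetical): "/" and ":" now survive, so parametrized test ids read like slash-separated paths.

import re


def sanitize_test_name(name: str) -> str:
    name = "".join(
        c if (c.isalnum() or c in "/:") else "_" for c in name.lower()
    )
    return re.sub("_+", "_", name)


print(sanitize_test_name("claude-3-haiku@20240307/True/maxt:10/hello world"))
# claude_3_haiku_20240307/true/maxt:10/hello_world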
