diff --git a/aidial_adapter_vertexai/dial_api/exceptions.py b/aidial_adapter_vertexai/dial_api/exceptions.py index ca94a08..799ecc0 100644 --- a/aidial_adapter_vertexai/dial_api/exceptions.py +++ b/aidial_adapter_vertexai/dial_api/exceptions.py @@ -60,20 +60,17 @@ def to_dial_exception(e: Exception) -> DialException: ) if isinstance(e, anthropic.APIStatusError): - code = e.status_code try: - response = e.response.json()["error"] - return DialException( - status_code=code, - type=response["type"], - message=response["message"], - ) + message = e.body["error"]["message"] # type: ignore except Exception: - return DialException( - status_code=code, - type=_get_exception_type(code), - message=e.message, - ) + message = e.message + + code = e.status_code + return DialException( + status_code=code, + type=_get_exception_type(code), + message=message, + ) if isinstance(e, ValidationError): return e.to_dial_exception() diff --git a/tests/conftest.py b/tests/conftest.py index 8a806b5..62ee38c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import os -from typing import AsyncGenerator +from typing import AsyncGenerator, Mapping import httpx import pytest @@ -49,7 +49,10 @@ async def test_http_client() -> AsyncGenerator[httpx.AsyncClient, None]: @pytest.fixture def get_openai_client(test_http_client: httpx.AsyncClient): - def _get_client(deployment_id: str | None = None) -> AsyncAzureOpenAI: + def _get_client( + deployment_id: str | None = None, + extra_headers: Mapping[str, str] | None = None, + ) -> AsyncAzureOpenAI: return AsyncAzureOpenAI( azure_endpoint=str(test_http_client.base_url), azure_deployment=deployment_id, @@ -58,6 +61,7 @@ def _get_client(deployment_id: str | None = None) -> AsyncAzureOpenAI: max_retries=2, timeout=30, http_client=test_http_client, + default_headers=extra_headers, ) yield _get_client diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py index 9eaf47d..4362dfc 100644 --- a/tests/integration_tests/test_chat_completion.py +++ b/tests/integration_tests/test_chat_completion.py @@ -1,7 +1,8 @@ import asyncio +import json import re from dataclasses import dataclass -from typing import Callable, List +from typing import Callable, List, Mapping import pytest from aidial_sdk.chat_completion.request import StaticFunction @@ -55,6 +56,7 @@ class TestCase: __test__ = False name: str + region: str | None deployment: ChatCompletionDeployment streaming: bool @@ -71,32 +73,45 @@ class TestCase: static_tools: StaticToolsConfig | None def get_id(self): - max_tokens_str = f"maxt={self.max_tokens}" if self.max_tokens else "" - stop_sequence_str = f"stop={self.stop}" if self.stop else "" - n_str = f"n={self.n}" if self.n else "" + max_tokens_str = f"maxt:{self.max_tokens}" if self.max_tokens else None + stop_sequence_str = f"stop:{self.stop}" if self.stop else None + n_str = f"n:{self.n}" if self.n else None return sanitize_test_name( - f"{self.deployment.value} {self.streaming} {max_tokens_str} " - f"{stop_sequence_str} {n_str} {self.name}" + "/".join( + str(part) + for part in [ + self.deployment.value, + self.streaming, + max_tokens_str, + stop_sequence_str, + n_str, + self.name, + ] + if part is not None + ) ) -deployments = [ - ChatCompletionDeployment.CHAT_BISON_1, - ChatCompletionDeployment.CHAT_BISON_2_32K, - ChatCompletionDeployment.CODECHAT_BISON_1, - ChatCompletionDeployment.GEMINI_PRO_1, - ChatCompletionDeployment.GEMINI_FLASH_1_5_V2, - ChatCompletionDeployment.GEMINI_PRO_VISION_1, - ChatCompletionDeployment.GEMINI_PRO_1_5_V2, - ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP, - ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206, - ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219, - ChatCompletionDeployment.CLAUDE_3_5_SONNET_V2, - ChatCompletionDeployment.CLAUDE_3_5_HAIKU, - ChatCompletionDeployment.CLAUDE_3_OPUS, - ChatCompletionDeployment.CLAUDE_3_5_SONNET, - ChatCompletionDeployment.CLAUDE_3_HAIKU, -] +_CENTRAL = "us-central1" +_EAST = "us-east5" + +chat_deployments: Mapping[ChatCompletionDeployment, str] = { + ChatCompletionDeployment.CHAT_BISON_1: _CENTRAL, + ChatCompletionDeployment.CHAT_BISON_2_32K: _CENTRAL, + ChatCompletionDeployment.CODECHAT_BISON_1: _CENTRAL, + ChatCompletionDeployment.GEMINI_PRO_1: _CENTRAL, + ChatCompletionDeployment.GEMINI_FLASH_1_5_V2: _CENTRAL, + ChatCompletionDeployment.GEMINI_PRO_VISION_1: _CENTRAL, + ChatCompletionDeployment.GEMINI_PRO_1_5_V2: _CENTRAL, + ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP: _CENTRAL, + ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206: _CENTRAL, + ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219: _CENTRAL, + ChatCompletionDeployment.CLAUDE_3_5_SONNET_V2: _EAST, + ChatCompletionDeployment.CLAUDE_3_5_HAIKU: _EAST, + ChatCompletionDeployment.CLAUDE_3_OPUS: _EAST, + ChatCompletionDeployment.CLAUDE_3_5_SONNET: _EAST, + ChatCompletionDeployment.CLAUDE_3_HAIKU: _EAST, +} def is_codechat(deployment: ChatCompletionDeployment) -> bool: @@ -195,7 +210,7 @@ def is_gemini_2(deployment: ChatCompletionDeployment) -> bool: def get_test_cases( - deployment: ChatCompletionDeployment, streaming: bool + deployment: ChatCompletionDeployment, region: str, streaming: bool ) -> List[TestCase]: test_cases: List[TestCase] = [] @@ -215,6 +230,7 @@ def test_case( test_cases.append( TestCase( name, + region, deployment, streaming, messages, @@ -596,18 +612,28 @@ def _check(id: str) -> bool: return test_cases +def get_extra_headers(region: str | None) -> Mapping[str, str]: + return ( + {"x-upstream-extra-data": json.dumps({"region": region})} + if region is not None + else {} + ) + + @pytest.mark.parametrize( "test", [ test - for deployment in deployments + for deployment, region in chat_deployments.items() for streaming in [False, True] - for test in get_test_cases(deployment, streaming) + for test in get_test_cases(deployment, region, streaming) ], ids=lambda test: test.get_id(), ) async def test_chat_completion_openai(get_openai_client, test: TestCase): - client = get_openai_client(test.deployment.value) + client = get_openai_client( + test.deployment.value, get_extra_headers(test.region) + ) async def run_chat_completion() -> ChatCompletionResult: retries = 7 diff --git a/tests/utils/openai.py b/tests/utils/openai.py index 6e8f069..37d232e 100644 --- a/tests/utils/openai.py +++ b/tests/utils/openai.py @@ -152,8 +152,10 @@ def function_to_tool(function: FunctionDefinition) -> ChatCompletionToolParam: def sanitize_test_name(name: str) -> str: - name2 = "".join(c if c.isalnum() else "_" for c in name.lower()) - return re.sub("_+", "_", name2) + name = "".join( + c if (c.isalnum() or c in "/:") else "_" for c in name.lower() + ) + return re.sub("_+", "_", name) class ChatCompletionResult(BaseModel):