From aa8c5b8f5c6efac3db24a798f6012c56ea94552f Mon Sep 17 00:00:00 2001
From: Anton Dubovik
Date: Thu, 16 Jan 2025 12:37:12 +0100
Subject: [PATCH] fix: add model field to chat completion responses (#174)

---
 aidial_adapter_vertexai/chat_completion.py     |  2 ++
 tests/conftest.py                              |  6 +++--
 .../integration_tests/test_chat_completion.py |  7 ++++++
 tests/utils/openai.py                          | 23 +++++++++++--------
 4 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/aidial_adapter_vertexai/chat_completion.py b/aidial_adapter_vertexai/chat_completion.py
index e5c75fb..0f2c7e1 100644
--- a/aidial_adapter_vertexai/chat_completion.py
+++ b/aidial_adapter_vertexai/chat_completion.py
@@ -51,6 +51,8 @@ async def _get_model(
 
     @dial_exception_decorator
     async def chat_completion(self, request: Request, response: Response):
+        response.set_model(request.deployment_id)
+
         model = await self._get_model(request)
         tools = ToolsConfig.from_request(request)
         static_tools = StaticToolsConfig.from_request(request)
diff --git a/tests/conftest.py b/tests/conftest.py
index b498b96..8a806b5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,6 +41,8 @@ async def test_http_client() -> AsyncGenerator[httpx.AsyncClient, None]:
     async with httpx.AsyncClient(
         transport=ASGITransport(app),  # type: ignore
         base_url="http://test-app.com",
+        params={"api-version": "dummy-version"},
+        headers={"api-key": "dummy-key"},
     ) as client:
         yield client
 
@@ -51,8 +53,8 @@ def _get_client(deployment_id: str | None = None) -> AsyncAzureOpenAI:
     return AsyncAzureOpenAI(
         azure_endpoint=str(test_http_client.base_url),
         azure_deployment=deployment_id,
-        api_version="",
-        api_key="dummy_key",
+        api_version="dummy-version",
+        api_key="dummy-key",
         max_retries=2,
         timeout=30,
         http_client=test_http_client,
diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py
index 5cc15d9..15eb464 100644
--- a/tests/integration_tests/test_chat_completion.py
+++ b/tests/integration_tests/test_chat_completion.py
@@ -207,6 +207,13 @@ def test_case(
         expected=for_all_choices(lambda s: "5" in s),
     )
 
+    test_case(
+        name="model field",
+        messages=[user("test")],
+        max_tokens=1,
+        expected=lambda s: s.response.model == deployment.value,
+    )
+
     test_case(
         name="hello",
         messages=[user('Reply with "Hello"')],
diff --git a/tests/utils/openai.py b/tests/utils/openai.py
index 2b8b587..9a60f7a 100644
--- a/tests/utils/openai.py
+++ b/tests/utils/openai.py
@@ -1,6 +1,6 @@
 import json
 import re
-from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar
+from typing import Any, Callable, List, Optional, TypeVar
 
 import httpx
 from aidial_sdk.chat_completion.request import Attachment, StaticTool
@@ -10,7 +10,10 @@
     TokenizeResponse,
     TokenizeSuccess,
 )
-from aidial_sdk.utils.streaming import merge_chunks
+from aidial_sdk.utils.merge_chunks import (
+    cleanup_indices,
+    merge_chat_completion_chunks,
+)
 from openai import AsyncAzureOpenAI, AsyncStream
 from openai._types import NOT_GIVEN
 from openai.types import CompletionUsage
@@ -215,7 +218,6 @@ async def tokenize(
     tokenize_response = await http_client.post(
         f"openai/deployments/{model_id}/tokenize",
         json=tokenize_request,
-        headers={"api-key": "dummy_key"},
     )
 
     tokenize_response.raise_for_status()
@@ -250,7 +252,7 @@ async def get_response() -> ChatCompletion:
         merged_tools += tools
 
         response = await client.chat.completions.create(
-            model="dummy_model",
+            model="dummy-model",
             messages=messages,
             stream=stream,
             stop=stop,
@@ -267,14 +269,17 @@ async def get_response() -> ChatCompletion:
         )
 
         if isinstance(response, AsyncStream):
+            chunks: List[dict] = []
+            async for chunk in response:
+                chunks.append(chunk.dict())
+
+            response_dict = merge_chat_completion_chunks(*chunks)
 
-            async def generator() -> AsyncGenerator[dict, None]:
-                async for chunk in response:
-                    yield chunk.dict()
+            for choice in response_dict["choices"]:
+                choice["message"] = cleanup_indices(choice["delta"])
+                del choice["delta"]
 
-            response_dict = await merge_chunks(generator())
             response_dict["object"] = "chat.completion"
-            response_dict["model"] = "dummy_model"
 
             return ChatCompletion.parse_obj(response_dict)
         else:
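
Editor's note: below is a minimal, self-contained sketch (not part of the patch) of the chunk-merging pattern the updated test helper relies on. It assumes the aidial_sdk.utils.merge_chunks helpers behave as they are used in the diff above; the chunk payloads, IDs, and the "gemini-pro" model name are hypothetical.

from aidial_sdk.utils.merge_chunks import (
    cleanup_indices,
    merge_chat_completion_chunks,
)

# Two hypothetical streaming chunks. The "model" field is now present on
# every chunk thanks to response.set_model(request.deployment_id).
chunks = [
    {
        "id": "chatcmpl-1",
        "object": "chat.completion.chunk",
        "model": "gemini-pro",
        "choices": [
            {"index": 0, "delta": {"role": "assistant", "content": "Hel"}}
        ],
    },
    {
        "id": "chatcmpl-1",
        "object": "chat.completion.chunk",
        "model": "gemini-pro",
        "choices": [
            {"index": 0, "delta": {"content": "lo"}, "finish_reason": "stop"}
        ],
    },
]

# Merge the stream into a single response, then promote each merged
# "delta" to a final "message", mirroring the test helper in the diff.
merged = merge_chat_completion_chunks(*chunks)
for choice in merged["choices"]:
    choice["message"] = cleanup_indices(choice["delta"])
    del choice["delta"]
merged["object"] = "chat.completion"

# Assuming string fields are concatenated across chunks and top-level
# fields such as "model" carry over into the merged result:
assert merged["model"] == "gemini-pro"
assert merged["choices"][0]["message"]["content"] == "Hello"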