Skip to content

Commit

Permalink
fix: add model field to chat completion responses (#174)
Browse files: browse the repository at this point in the history
  • Loading branch information
adubovik authored Jan 16, 2025
1 parent: c0be57a; commit: aa8c5b8
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
2 changes: 2 additions & 0 deletions aidial_adapter_vertexai/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ async def _get_model(

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
response.set_model(request.deployment_id)

model = await self._get_model(request)
tools = ToolsConfig.from_request(request)
static_tools = StaticToolsConfig.from_request(request)
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ async def test_http_client() -> AsyncGenerator[httpx.AsyncClient, None]:
async with httpx.AsyncClient(
transport=ASGITransport(app), # type: ignore
base_url="http://test-app.com",
params={"api-version": "dummy-version"},
headers={"api-key": "dummy-key"},
) as client:
yield client

Expand All @@ -51,8 +53,8 @@ def _get_client(deployment_id: str | None = None) -> AsyncAzureOpenAI:
return AsyncAzureOpenAI(
azure_endpoint=str(test_http_client.base_url),
azure_deployment=deployment_id,
api_version="",
api_key="dummy_key",
api_version="dummy-version",
api_key="dummy-key",
max_retries=2,
timeout=30,
http_client=test_http_client,
Expand Down
7 changes: 7 additions & 0 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ def test_case(
expected=for_all_choices(lambda s: "5" in s),
)

test_case(
name="model field",
messages=[user("test")],
max_tokens=1,
expected=lambda s: s.response.model == deployment.value,
)

test_case(
name="hello",
messages=[user('Reply with "Hello"')],
Expand Down
23 changes: 14 additions & 9 deletions tests/utils/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import re
from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar
from typing import Any, Callable, List, Optional, TypeVar

import httpx
from aidial_sdk.chat_completion.request import Attachment, StaticTool
Expand All @@ -10,7 +10,10 @@
TokenizeResponse,
TokenizeSuccess,
)
from aidial_sdk.utils.streaming import merge_chunks
from aidial_sdk.utils.merge_chunks import (
cleanup_indices,
merge_chat_completion_chunks,
)
from openai import AsyncAzureOpenAI, AsyncStream
from openai._types import NOT_GIVEN
from openai.types import CompletionUsage
Expand Down Expand Up @@ -215,7 +218,6 @@ async def tokenize(
tokenize_response = await http_client.post(
f"openai/deployments/{model_id}/tokenize",
json=tokenize_request,
headers={"api-key": "dummy_key"},
)

tokenize_response.raise_for_status()
Expand Down Expand Up @@ -250,7 +252,7 @@ async def get_response() -> ChatCompletion:
merged_tools += tools

response = await client.chat.completions.create(
model="dummy_model",
model="dummy-model",
messages=messages,
stream=stream,
stop=stop,
Expand All @@ -267,14 +269,17 @@ async def get_response() -> ChatCompletion:
)

if isinstance(response, AsyncStream):
chunks: List[dict] = []
async for chunk in response:
chunks.append(chunk.dict())

response_dict = merge_chat_completion_chunks(*chunks)

async def generator() -> AsyncGenerator[dict, None]:
async for chunk in response:
yield chunk.dict()
for choice in response_dict["choices"]:
choice["message"] = cleanup_indices(choice["delta"])
del choice["delta"]

response_dict = await merge_chunks(generator())
response_dict["object"] = "chat.completion"
response_dict["model"] = "dummy_model"

return ChatCompletion.parse_obj(response_dict)
else:
Expand Down

0 comments on commit aa8c5b8

Please sign in to comment.