Skip to content

Commit

Permalink
feat: added integration tests for Claude tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jan 23, 2025
1 parent 580ccab commit 2009c5e
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 66 deletions.
24 changes: 6 additions & 18 deletions aidial_adapter_vertexai/chat/claude/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, assert_never

from aidial_sdk.chat_completion import Message
from aidial_sdk.exceptions import InternalServerError
from anthropic import AsyncAnthropicVertex, MessageStopEvent
from anthropic.lib.streaming import (
ContentBlockStopEvent,
Expand All @@ -12,9 +13,10 @@
ContentBlockDeltaEvent,
ContentBlockStartEvent,
MessageDeltaEvent,
MessageStartEvent,
TextBlock,
ToolUseBlock,
)
from anthropic.types import MessageParam as ClaudeMessage
from anthropic.types import MessageStartEvent, TextBlock, ToolUseBlock
from typing_extensions import override

from aidial_adapter_vertexai.app_config import get_anthropic_client
Expand All @@ -28,10 +30,7 @@
create_chat_params,
none_to_not_given,
)
from aidial_adapter_vertexai.chat.claude.prompt.base import (
ClaudeConversation,
ClaudePrompt,
)
from aidial_adapter_vertexai.chat.claude.prompt.base import ClaudePrompt
from aidial_adapter_vertexai.chat.claude.prompt.claude_3 import (
parse_claude_3_prompt,
)
Expand All @@ -49,7 +48,6 @@
from aidial_adapter_vertexai.dial_api.storage import FileStorage
from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.json import json_dumps_short
from aidial_adapter_vertexai.utils.list_projection import ListProjection
from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log


Expand Down Expand Up @@ -164,7 +162,6 @@ async def _invoke_streaming(
case _:
assert_never(content_block)
case MessageStopEvent(message=message):
completion_tokens += message.usage.output_tokens
stop_reason = message.stop_reason
case (
InputJsonEvent()
Expand Down Expand Up @@ -247,16 +244,7 @@ async def count_prompt_tokens(self, prompt: ClaudePrompt) -> int:

@override
async def count_completion_tokens(self, string: str) -> int:
# FIXME: figure out if it's correct - add an integration test for it
message: ClaudeMessage = {"role": "user", "content": string}
return await self.count_prompt_tokens(
ClaudePrompt(
conversation=ClaudeConversation(
system=None,
messages=ListProjection.create([message]),
)
)
)
raise InternalServerError("Tokenization of strings is not supported")


def _project_to_original_indices(
Expand Down
3 changes: 1 addition & 2 deletions aidial_adapter_vertexai/chat/claude/prompt/claude_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
AttachmentProcessor,
AttachmentProcessorsBase,
max_count_validator,
seq_validators,
)
from aidial_adapter_vertexai.chat.claude.conversation_factory import (
SUPPORTED_IMAGE_TYPES,
Expand Down Expand Up @@ -96,7 +95,7 @@ def _create_image_processor(max_count: int) -> AttachmentProcessor:
# NOTE: not checked condition: The maximum allowed image file size is 5 MB
return AttachmentProcessor(
file_types=SUPPORTED_IMAGE_TYPES,
init_validator=seq_validators(None, max_count_validator(max_count)),
init_validator=max_count_validator(max_count),
)


Expand Down
6 changes: 6 additions & 0 deletions tests/integration_tests/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from aidial_adapter_vertexai.utils.resource import Resource

# Shared image fixture for the integration tests (used for the
# attachment-data, attachment-url and image-url request shapes).
# NOTE(review): appears to be a tiny solid-blue PNG — confirm by decoding.
# The base64 payload must remain byte-exact; do not reflow or reformat it.
BLUE_PNG_PICTURE = Resource.from_base64(
    type="image/png",
    data_base64="iVBORw0KGgoAAAANSUhEUgAAAAMAAAADCAIAAADZSiLoAAAAF0lEQVR4nGNkYPjPwMDAwMDAxAADCBYAG10BBdmz9y8AAAAASUVORK5CYII=",
)
36 changes: 16 additions & 20 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import json
import re
from dataclasses import dataclass
from typing import Callable, List, Mapping
Expand All @@ -16,13 +15,14 @@

from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.deployments import ChatCompletionDeployment
from tests.integration_tests.constants import BLUE_PNG_PICTURE
from tests.utils.dial import get_extra_headers
from tests.utils.openai import (
GET_WEATHER_FUNCTION,
ChatCompletionResult,
ai,
ai_function,
ai_tools,
blue_pic,
chat_completion,
for_all_choices,
function_request,
Expand Down Expand Up @@ -310,18 +310,22 @@ def test_case(
),
)

def _check_max_tokens_1(r: ChatCompletionResult) -> bool:
    """Verify a max_tokens=1 completion: every choice holds exactly the
    expected number of words and the reported usage agrees with it.

    For thinking-enabled deployments the single allowed token goes to
    reasoning, so zero visible words are expected; otherwise exactly one.
    """
    if support_thinking(deployment):
        expected_tokens = 0
    else:
        expected_tokens = 1

    def _choice_ok(text: str) -> bool:
        return len(text.split()) == expected_tokens

    assert for_all_choices(_choice_ok)(r)

    usage = r.usage
    assert usage is not None
    assert usage.completion_tokens == expected_tokens
    return True

test_case(
name="max tokens 1",
max_tokens=1,
messages=[user("tell me the full story of Pinocchio")],
expected=for_all_choices(
lambda s: (
len(s.split()) == 1
if not support_thinking(deployment)
else len(s.split()) == 0
)
),
expected=_check_max_tokens_1,
)

# Gemini 2.0 rate-limits always fail on such concurrency
candidates_count = 5 if not is_gemini_2(deployment) else 2
test_case(
Expand Down Expand Up @@ -357,9 +361,9 @@ def test_case(
content = "describe the image"
for idx, user_message in enumerate(
[
user_with_attachment_data(content, blue_pic),
user_with_attachment_url(content, blue_pic),
user_with_image_url(content, blue_pic),
user_with_attachment_data(content, BLUE_PNG_PICTURE),
user_with_attachment_url(content, BLUE_PNG_PICTURE),
user_with_image_url(content, BLUE_PNG_PICTURE),
]
):
test_case(
Expand Down Expand Up @@ -612,14 +616,6 @@ def _check(id: str) -> bool:
return test_cases


def get_extra_headers(region: str | None) -> Mapping[str, str]:
return (
{"x-upstream-extra-data": json.dumps({"region": region})}
if region is not None
else {}
)


@pytest.mark.parametrize(
"test",
[
Expand Down
Loading

0 comments on commit 2009c5e

Please sign in to comment.