Skip to content

Commit

Permalink
feat: made tool_call_id non-optional
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jan 23, 2025
1 parent 0c2adde commit b0d4367
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 30 deletions.
12 changes: 3 additions & 9 deletions aidial_adapter_vertexai/chat/claude/conversation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,8 @@ def create_text_part(self, text: str) -> ClaudePart:
return TextBlockParam(type="text", text=text)

def create_function_call_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> ClaudePart:
if tool_call_id is None:
raise InvalidRequestError("tool_call_id field must be present")

return ToolUseBlockParam(
id=tool_call_id,
input=json.loads(args),
Expand All @@ -78,15 +75,12 @@ def create_function_call_part(
)

def create_function_result_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> ClaudePart:
if tool_call_id is None:
raise InvalidRequestError("tool_call_id field must be present")

return ToolResultBlockParam(
tool_use_id=tool_call_id,
type="tool_result",
content=[{"type": "text", "text": args}],
content=[TextBlockParam(type="text", text=args)],
)

def create_content(
Expand Down
30 changes: 25 additions & 5 deletions aidial_adapter_vertexai/chat/conversation/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,32 @@
FunctionArgs = str


class Counter:
    """Mutable integer counter with post-increment semantics.

    A single instance is meant to be shared by reference between callers
    so that each of them observes the increments made by the others.
    """

    # Class-level default; the first ``post_inc`` call on an instance
    # rebinds ``count`` on that instance, so instances do not share state.
    count: int = 0

    def post_inc(self) -> int:
        """Increment the counter and return the value it held beforehand."""
        old_count = self.count
        self.count += 1
        return old_count


async def messages_to_conversation(
conversation_factory: ConversationFactoryBase[PartT, Any, ConversationT],
processors: AttachmentProcessorsBase[PartT],
tools: ToolsConfig,
messages: List[Message],
) -> ConversationT:
function_call_idx = Counter()

message_parts = [
(
message.role,
await _message_to_parts(
processors, tools, message, conversation_factory
processors,
tools,
message,
conversation_factory,
function_call_idx,
),
)
for message in messages
Expand All @@ -48,8 +63,8 @@ async def _message_to_parts(
tools: ToolsConfig,
message: Message,
conversation_factory: ConversationFactoryBase,
function_call_idx: Counter,
) -> List[PartT]:

content = message.content

match message.role:
Expand All @@ -65,18 +80,20 @@ async def _message_to_parts(

case Role.ASSISTANT:
if message.function_call is not None:
tool_call_id = f"function_call_{function_call_idx.count}"
return [
conversation_factory.create_function_call_part(
message.function_call.name,
message.function_call.arguments,
tool_call_id,
)
]
elif message.tool_calls is not None:
return [
conversation_factory.create_function_call_part(
call.function.name,
call.function.arguments,
tool_call_id=call.id,
call.id,
)
for call in message.tool_calls
]
Expand All @@ -99,8 +116,11 @@ async def _message_to_parts(
name = message.name
if name is None:
raise ValidationError("Function message name must be present")
tool_call_id = f"function_call_{function_call_idx.post_inc()}"
return [
conversation_factory.create_function_result_part(name, content)
conversation_factory.create_function_result_part(
name, content, tool_call_id
)
]

case Role.TOOL:
Expand All @@ -116,7 +136,7 @@ async def _message_to_parts(
name = tools.get_tool_name(tool_call_id)
return [
conversation_factory.create_function_result_part(
name, content, tool_call_id=tool_call_id
name, content, tool_call_id
)
]

Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_vertexai/chat/conversation/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ def create_text_part(self, text: str) -> PartT: ...

@abstractmethod
def create_function_call_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> PartT: ...

@abstractmethod
def create_function_result_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> PartT: ...

@abstractmethod
Expand Down
8 changes: 4 additions & 4 deletions aidial_adapter_vertexai/chat/gemini/conversation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create_text_part(self, text: str) -> Part:
return Part.from_text(text)

def create_function_call_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> Part:
try:
args = json.loads(args)
Expand All @@ -64,7 +64,7 @@ def create_function_call_part(
)

def create_function_result_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> Part:
try:
args = json.loads(args)
Expand Down Expand Up @@ -115,7 +115,7 @@ def create_text_part(self, text: str) -> GenAIPart:
return GenAIPart.from_text(text)

def create_function_call_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> GenAIPart:
try:
return GenAIPart.from_function_call(name, json.loads(args))
Expand All @@ -125,7 +125,7 @@ def create_function_call_part(
)

def create_function_result_part(
self, name: str, args: str, *, tool_call_id: str | None = None
self, name: str, args: str, tool_call_id: str
) -> GenAIPart:
try:
processed_args = json.loads(args)
Expand Down
22 changes: 12 additions & 10 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,19 +632,21 @@ async def test_chat_completion_openai(get_openai_client, test: TestCase):
)

async def run_chat_completion() -> ChatCompletionResult:
retries = 7
delay = 5
attempts = 7
delay = 1

async def _retry_wait(
retries: int, delay: int, e: APIError | RateLimitError
is_last_attempt: bool, e: APIError | RateLimitError
):
if attempt < retries - 1:
delay *= 2
await asyncio.sleep(delay)
else:
if is_last_attempt:
raise e

for attempt in range(retries):
nonlocal delay
await asyncio.sleep(delay)
delay *= 2

for attempt in range(attempts):
is_last_attempt = attempt == attempts - 1
try:
return await chat_completion(
client,
Expand All @@ -658,10 +660,10 @@ async def _retry_wait(
test.static_tools,
)
except RateLimitError as e:
await _retry_wait(retries, delay, e)
await _retry_wait(is_last_attempt, e)
except APIError as e:
if e.code == "429":
await _retry_wait(retries, delay, e)
await _retry_wait(is_last_attempt, e)
else:
raise e
raise RuntimeError("Failed to get a valid response")
Expand Down

0 comments on commit b0d4367

Please sign in to comment.