From 0c2adde337b7ff1accd3bf6c0cc1aa7e8dc064b9 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Thu, 23 Jan 2025 17:07:34 +0000 Subject: [PATCH] fix: fixed typing issues --- .../chat/claude/conversation_factory.py | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/aidial_adapter_vertexai/chat/claude/conversation_factory.py b/aidial_adapter_vertexai/chat/claude/conversation_factory.py index aa38742..f4d5922 100644 --- a/aidial_adapter_vertexai/chat/claude/conversation_factory.py +++ b/aidial_adapter_vertexai/chat/claude/conversation_factory.py @@ -22,8 +22,7 @@ ) ClaudePart = ( - str - | TextBlockParam + TextBlockParam | ImageBlockParam | ToolUseBlockParam | ToolResultBlockParam @@ -68,8 +67,11 @@ def create_text_part(self, text: str) -> ClaudePart: def create_function_call_part( self, name: str, args: str, *, tool_call_id: str | None = None ) -> ClaudePart: + if tool_call_id is None: + raise InvalidRequestError("tool_call_id field must be present") + return ToolUseBlockParam( - id=tool_call_id or "123", # fixme + id=tool_call_id, input=json.loads(args), name=name, type="tool_use", @@ -78,8 +80,11 @@ def create_function_call_part( def create_function_result_part( self, name: str, args: str, *, tool_call_id: str | None = None ) -> ClaudePart: + if tool_call_id is None: + raise InvalidRequestError("tool_call_id field must be present") + return ToolResultBlockParam( - tool_use_id=tool_call_id or "123", # fixme + tool_use_id=tool_call_id, type="tool_result", content=[{"type": "text", "text": args}], ) @@ -89,9 +94,9 @@ def create_content( ) -> MessageParam: match role: case Role.USER | Role.FUNCTION | Role.TOOL: - return MessageParam(content=parts, role="user") # type: ignore + return MessageParam(content=parts, role="user") case Role.ASSISTANT: - return MessageParam(content=parts, role="assistant") # type: ignore + return MessageParam(content=parts, role="assistant") case Role.SYSTEM: raise InvalidRequestError( "System message is only allowed as the first message" @@ -105,6 +110,24 @@ def create_conversation( contents: List[MessageParam], ) -> ClaudeConversation: return ClaudeConversation.create( - system_instruction, # type: ignore # FIXME + _sanitize_system_instruction(system_instruction), contents, ) + + +def _sanitize_system_instruction( + parts: List[ClaudePart] | None, +) -> List[TextBlockParam] | None: + if parts is None: + return None + + ret: List[TextBlockParam] = [] + for part in parts: + if isinstance(part, dict) and part["type"] == "text": + ret.append(part) + else: + raise InvalidRequestError( + "Only text parts are allowed in the system message" + ) + + return ret