Skip to content

Commit

Permalink
fix: fixed typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jan 23, 2025
1 parent 2009c5e commit 0c2adde
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions aidial_adapter_vertexai/chat/claude/conversation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
)

ClaudePart = (
str
| TextBlockParam
TextBlockParam
| ImageBlockParam
| ToolUseBlockParam
| ToolResultBlockParam
Expand Down Expand Up @@ -68,8 +67,11 @@ def create_text_part(self, text: str) -> ClaudePart:
def create_function_call_part(
self, name: str, args: str, *, tool_call_id: str | None = None
) -> ClaudePart:
if tool_call_id is None:
raise InvalidRequestError("tool_call_id field must be present")

return ToolUseBlockParam(
id=tool_call_id or "123", # fixme
id=tool_call_id,
input=json.loads(args),
name=name,
type="tool_use",
Expand All @@ -78,8 +80,11 @@ def create_function_call_part(
def create_function_result_part(
self, name: str, args: str, *, tool_call_id: str | None = None
) -> ClaudePart:
if tool_call_id is None:
raise InvalidRequestError("tool_call_id field must be present")

return ToolResultBlockParam(
tool_use_id=tool_call_id or "123", # fixme
tool_use_id=tool_call_id,
type="tool_result",
content=[{"type": "text", "text": args}],
)
Expand All @@ -89,9 +94,9 @@ def create_content(
) -> MessageParam:
match role:
case Role.USER | Role.FUNCTION | Role.TOOL:
return MessageParam(content=parts, role="user") # type: ignore
return MessageParam(content=parts, role="user")
case Role.ASSISTANT:
return MessageParam(content=parts, role="assistant") # type: ignore
return MessageParam(content=parts, role="assistant")
case Role.SYSTEM:
raise InvalidRequestError(
"System message is only allowed as the first message"
Expand All @@ -105,6 +110,24 @@ def create_conversation(
contents: List[MessageParam],
) -> ClaudeConversation:
return ClaudeConversation.create(
system_instruction, # type: ignore # FIXME
_sanitize_system_instruction(system_instruction),
contents,
)


def _sanitize_system_instruction(
parts: List[ClaudePart] | None,
) -> List[TextBlockParam] | None:
if parts is None:
return None

ret: List[TextBlockParam] = []
for part in parts:
if isinstance(part, dict) and part["type"] == "text":
ret.append(part)
else:
raise InvalidRequestError(
"Only text parts are allowed in the system message"
)

return ret

0 comments on commit 0c2adde

Please sign in to comment.