Skip to content

Commit

Permalink
feat: Add Google Search Grounding static tool (#147)
Browse files Browse the repository at this point in the history
* add google search static tool

* Fix google search tool

* fix linter

* remove dead code

* Fixes due to PR comments

* fix tests

* fix test helper

* replace long comment with issue in GH

* update integration test for static tools

* minor refactor
  • Loading branch information
roman-romanov-o authored Nov 15, 2024
1 parent 9d3d351 commit 801bedf
Show file tree
Hide file tree
Showing 17 changed files with 339 additions and 48 deletions.
7 changes: 6 additions & 1 deletion aidial_adapter_vertexai/chat/bison/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import ModelParameters
Expand All @@ -36,9 +37,13 @@ def send_message_async(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> BisonPrompt:
tools.not_supported()
static_tools.not_supported()
return BisonPrompt.parse(messages)

@override
Expand Down
6 changes: 5 additions & 1 deletion aidial_adapter_vertexai/chat/chat_completion_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import ModelParameters
Expand All @@ -16,7 +17,10 @@
class ChatCompletionAdapter(ABC, Generic[P]):
@abstractmethod
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> P | UserError:
pass

Expand Down
32 changes: 26 additions & 6 deletions aidial_adapter_vertexai/chat/gemini/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_1_5 import (
Gemini_1_5_Prompt,
)
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.deployments import (
Expand Down Expand Up @@ -105,14 +106,19 @@ def __init__(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> GeminiPrompt | UserError:
match self.deployment:
case ChatCompletionDeployment.GEMINI_PRO_1:
return await Gemini_1_0_Pro_Prompt.parse(tools, messages)
return await Gemini_1_0_Pro_Prompt.parse(
tools, static_tools, messages
)
case ChatCompletionDeployment.GEMINI_PRO_VISION_1:
return await Gemini_1_0_Pro_Vision_Prompt.parse(
self.file_storage, tools, messages
self.file_storage, tools, static_tools, messages
)
case (
ChatCompletionDeployment.GEMINI_PRO_1_5_PREVIEW
Expand All @@ -122,7 +128,7 @@ async def parse_prompt(
| ChatCompletionDeployment.GEMINI_FLASH_1_5_V2
):
return await Gemini_1_5_Prompt.parse(
self.file_storage, tools, messages
self.file_storage, tools, static_tools, messages
)
case _:
assert_never(self.deployment)
Expand All @@ -136,7 +142,7 @@ def _get_model(
parameters = create_generation_config(params) if params else None

if prompt is not None:
tools = prompt.tools.to_gemini_tools()
tools = prompt.to_gemini_tools() or None
tool_config = prompt.tools.to_gemini_tool_config()
system_instruction = cast(
List[str | Part | Image] | None,
Expand Down Expand Up @@ -190,6 +196,7 @@ async def process_chunks(
yield content

await create_function_calls(candidate, consumer, tools)
await create_grounding(candidate, consumer)
await create_attachments_from_citations(candidate, consumer)
await set_finish_reason(candidate, consumer)

Expand All @@ -211,7 +218,6 @@ async def chat(
)

completion = ""

async for content in generate_with_retries(
lambda: self.process_chunks(
consumer,
Expand Down Expand Up @@ -321,6 +327,20 @@ async def create_function_calls(
)


async def create_grounding(candidate: Candidate, consumer: Consumer) -> None:
    """Attach web sources from the candidate's grounding metadata.

    For every grounding chunk carrying a web reference with a URI, adds an
    attachment pointing at that URI (with its title, possibly None) to the
    consumer. Does nothing when no grounding metadata or chunks are present.
    """
    metadata = candidate.grounding_metadata
    chunks = metadata.grounding_chunks if metadata else None
    if not chunks:
        return

    for grounding_chunk in chunks:
        web = grounding_chunk.web
        if not web or not web.uri:
            continue
        await consumer.add_attachment(Attachment(url=web.uri, title=web.title))


def to_openai_finish_reason(
finish_reason: GenFinishReason, retriable: bool
) -> FinishReason | None:
Expand Down
11 changes: 11 additions & 0 deletions aidial_adapter_vertexai/chat/gemini/prompt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from pydantic import BaseModel, Field
from vertexai.preview.generative_models import Content, Part
from vertexai.preview.generative_models import Tool as GeminiTool

from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatablePrompt

Expand All @@ -20,6 +22,9 @@ class GeminiPrompt(BaseModel, TruncatablePrompt, ABC):
system_instruction: List[Part] | None = None
contents: List[Content]
tools: ToolsConfig = Field(default_factory=ToolsConfig.noop)
static_tools: StaticToolsConfig = Field(
default_factory=StaticToolsConfig.noop
)

class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -68,4 +73,10 @@ def select(self, indices: Set[int]) -> "GeminiPrompt":
system_instruction=system_instruction,
contents=contents,
tools=self.tools,
static_tools=self.static_tools,
)

def to_gemini_tools(self) -> List[GeminiTool]:
    """All Gemini tools for this prompt: regular tools plus static ones."""
    return self.tools.to_gemini_tools() + self.static_tools.to_gemini_tools()
7 changes: 6 additions & 1 deletion aidial_adapter_vertexai/chat/gemini/prompt/gemini_1_0_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
)
from aidial_adapter_vertexai.chat.gemini.processor import AttachmentProcessors
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig


class Gemini_1_0_Pro_Prompt(GeminiPrompt):
@classmethod
async def parse(
cls, tools: ToolsConfig, messages: List[Message]
cls,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Self | UserError:
if len(messages) == 0:
raise ValidationError(
Expand All @@ -34,4 +38,5 @@ async def parse(
system_instruction=conversation.system_instruction,
contents=conversation.contents,
tools=tools,
static_tools=static_tools,
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_video_processor,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.dial_api.request import get_attachments
from aidial_adapter_vertexai.dial_api.storage import FileStorage
Expand All @@ -28,9 +29,11 @@ async def parse(
cls,
file_storage: FileStorage | None,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Union[Self, UserError]:
tools.not_supported()
static_tools.not_supported()

if len(messages) == 0:
raise ValidationError(
Expand Down
3 changes: 3 additions & 0 deletions aidial_adapter_vertexai/chat/gemini/prompt/gemini_1_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_video_processor,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.dial_api.storage import FileStorage

Expand All @@ -25,6 +26,7 @@ async def parse(
cls,
file_storage: Optional[FileStorage],
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Self | UserError:
if len(messages) == 0:
Expand Down Expand Up @@ -55,6 +57,7 @@ async def parse(
system_instruction=conversation.system_instruction,
contents=conversation.contents,
tools=tools,
static_tools=static_tools,
)


Expand Down
8 changes: 6 additions & 2 deletions aidial_adapter_vertexai/chat/imagen/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import ValidationError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import (
Expand Down Expand Up @@ -43,10 +44,13 @@ def __init__(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> ImagenPrompt:
tools.not_supported()

static_tools.not_supported()
if len(messages) == 0:
raise ValidationError("The list of messages must not be empty")

Expand Down
87 changes: 87 additions & 0 deletions aidial_adapter_vertexai/chat/static_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, NoReturn, Self

from aidial_sdk.chat_completion.request import (
AzureChatCompletionRequest,
StaticFunction,
StaticTool,
)
from pydantic import BaseModel
from vertexai.preview.generative_models import Tool as GeminiTool
from vertexai.preview.generative_models import grounding

from aidial_adapter_vertexai.chat.errors import ValidationError


class ToolName(str, Enum):
    """Names of the static tools recognized by this adapter.

    Values are matched against ``StaticFunction.name`` from the request.
    """

    # Grounding with Google Search, see:
    # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding
    GOOGLE_SEARCH = "google_search"


class StaticToolProcessor(ABC):
    """Interface for translating a static function into Gemini tools."""

    @staticmethod
    @abstractmethod
    def parse_gemini_tools(
        static_function: StaticFunction,
    ) -> List[GeminiTool] | None: ...


class GoogleSearchGroundingTool(StaticToolProcessor):
    """Maps the ``google_search`` static function to Gemini's Google Search
    retrieval grounding tool."""

    @staticmethod
    def parse_gemini_tools(
        static_function: StaticFunction,
    ) -> List[GeminiTool] | None:
        """Return the Google Search retrieval tool for a matching function.

        Returns None when the function name is not ``google_search``.
        Raises ValidationError when a configuration is supplied, since this
        tool accepts none.
        """
        if static_function.name != ToolName.GOOGLE_SEARCH:
            return None

        if static_function.configuration:
            raise ValidationError(
                "Google search tool doesn't support configuration"
            )

        retrieval = grounding.GoogleSearchRetrieval()
        return [GeminiTool.from_google_search_retrieval(retrieval)]


def unknown_tool_name(static_function: StaticFunction) -> NoReturn:
    """Reject a static function whose name no processor recognizes."""
    message = f"Unsupported static function: {static_function.name}"
    raise ValidationError(message)


class StaticToolsConfig(BaseModel):
    """Configuration of the static (built-in) tools requested by a client."""

    # Static functions extracted from the request's tool list.
    functions: List[StaticFunction]

    @classmethod
    def from_request(cls, request: AzureChatCompletionRequest) -> Self:
        """Collect static tools from a chat completion request; empty when
        the request declares no tools."""
        tools = request.tools
        if tools is None:
            return cls(functions=[])

        static_functions = [
            tool.static_function
            for tool in tools
            if isinstance(tool, StaticTool)
        ]
        return cls(functions=static_functions)

    @classmethod
    def noop(cls) -> Self:
        """An empty configuration containing no static tools."""
        return cls(functions=[])

    def to_gemini_tools(self) -> List[GeminiTool]:
        """Translate every static function into Gemini tools.

        Raises ValidationError (via unknown_tool_name) for a function no
        processor recognizes.
        """
        gemini_tools: List[GeminiTool] = []
        for function in self.functions:
            parsed = GoogleSearchGroundingTool.parse_gemini_tools(function)
            if not parsed:
                unknown_tool_name(function)
            gemini_tools.extend(parsed)
        return gemini_tools

    def not_supported(self) -> None:
        """Raise when static tools were requested but the model doesn't
        support them."""
        if self.functions:
            raise ValidationError("Static tools aren't supported")
12 changes: 8 additions & 4 deletions aidial_adapter_vertexai/chat/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Role,
ToolChoice,
)
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest, Tool
from pydantic import BaseModel
from vertexai.preview.generative_models import (
FunctionDeclaration as GeminiFunction,
Expand Down Expand Up @@ -117,7 +117,11 @@ def from_request(cls, request: AzureChatCompletionRequest) -> Self:
tool_ids = None

elif request.tools is not None:
functions = [tool.function for tool in request.tools]
functions = [
tool.function
for tool in request.tools
if isinstance(tool, Tool)
]
function_call = ToolsConfig.tool_choice_to_function_call(
request.tool_choice
)
Expand All @@ -137,9 +141,9 @@ def from_request(cls, request: AzureChatCompletionRequest) -> Self:

return cls(functions=selected, required=required, tool_ids=tool_ids)

def to_gemini_tools(self) -> List[GeminiTool] | None:
def to_gemini_tools(self) -> List[GeminiTool]:
if not self.functions:
return None
return []

return [
GeminiTool(
Expand Down
Loading

0 comments on commit 801bedf

Please sign in to comment.