Commit
* feat: added timings to debug prints
* feat: allowed empty messages for Bison models
* feat: supported streaming for Bison models
* fix: disabled the malfunctioning clustering and classification Gecko endpoints
* chore: bumped version of aidial-sdk to 0.1.2
* feat: supported history truncation via max_prompt_tokens/discarded_messages parameters
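To illustrate the last item, a minimal sketch of a request that exercises history truncation, assuming an OpenAI-compatible DIAL deployment of a Bison model; the endpoint URL, deployment name, and API key below are placeholders rather than values from this commit:

import requests

response = requests.post(
    # Hypothetical deployment URL; substitute your own adapter instance.
    "http://localhost:5000/openai/deployments/chat-bison/chat/completions",
    headers={"Api-Key": "<your-dial-api-key>"},
    json={
        "messages": [
            {"role": "user", "content": "First question"},
            {"role": "assistant", "content": "First answer"},
            {"role": "user", "content": "Second question"},
        ],
        # Truncate the oldest messages until the prompt fits this token budget.
        "max_prompt_tokens": 100,
    },
)

# On truncation, the response is expected to report the dropped messages
# via a discarded_messages field alongside the usual completion payload;
# its exact placement in the response body is an assumption here.
print(response.json())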
Showing 39 changed files with 1,287 additions and 769 deletions.
.gitignore

@@ -8,4 +8,4 @@ secret
 dist
 .vscode/launch.json
 ~*
-leftovers
+.idea/
@@ -0,0 +1,85 @@
from typing import Any, Dict, List, Optional

from typing_extensions import override

from aidial_adapter_vertexai.llm.chat_completion_adapter import (
    ChatCompletionAdapter,
)
from aidial_adapter_vertexai.llm.exceptions import ValidationError
from aidial_adapter_vertexai.llm.vertex_ai_chat import VertexAIMessage
from aidial_adapter_vertexai.universal_api.request import ModelParameters


class BisonChatAdapter(ChatCompletionAdapter):
    @override
    def _create_instance(
        self,
        context: Optional[str],
        messages: List[VertexAIMessage],
    ) -> Dict[str, Any]:
        return {
            "context": context or "",
            "messages": messages,
        }

    @override
    def _create_parameters(
        self,
        params: ModelParameters,
    ) -> Dict[str, Any]:
        # See chat playground: https://console.cloud.google.com/vertex-ai/generative/language/create/chat
        ret: Dict[str, Any] = {}

        if params.max_tokens is not None:
            ret["maxOutputTokens"] = params.max_tokens

        if params.temperature is not None:
            ret["temperature"] = params.temperature

        if params.stop is not None:
            ret["stopSequences"] = (
                [params.stop] if isinstance(params.stop, str) else params.stop
            )

        if params.top_p is not None:
            ret["topP"] = params.top_p

        return ret


class BisonCodeChatAdapter(ChatCompletionAdapter):
    @override
    def _create_instance(
        self,
        context: Optional[str],
        messages: List[VertexAIMessage],
    ) -> Dict[str, Any]:
        if context is not None:
            raise ValidationError("System message is not supported")

        return {
            "messages": messages,
        }

    @override
    def _create_parameters(
        self,
        params: ModelParameters,
    ) -> Dict[str, Any]:
        ret: Dict[str, Any] = {}

        if params.max_tokens is not None:
            ret["maxOutputTokens"] = params.max_tokens

        if params.temperature is not None:
            ret["temperature"] = params.temperature

        if params.stop is not None:
            raise ValidationError(
                "stop sequences are not supported for code chat model"
            )

        if params.top_p is not None:
            raise ValidationError("top_p is not supported for code chat model")

        return ret
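As a usage note, the two adapters diverge in what they accept: the chat adapter forwards stop and top_p to Vertex AI, while the code chat adapter rejects them (and any system context). A minimal sketch of that divergence, assuming ModelParameters can be constructed with these keyword arguments (its fields match the reads above); None is passed for self, which neither implementation uses:

params = ModelParameters(
    max_tokens=256, temperature=0.7, stop="STOP", top_p=0.95
)

print(BisonChatAdapter._create_parameters(None, params))
# -> {'maxOutputTokens': 256, 'temperature': 0.7,
#     'stopSequences': ['STOP'], 'topP': 0.95}

try:
    BisonCodeChatAdapter._create_parameters(None, params)
except ValidationError as e:
    print(e)  # -> stop sequences are not supported for code chat model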