Commit

feat: migrated latest fixes (#21)
* feat: added timings to debug prints
* feat: allowed empty messages for Bison models
* feat: supported streaming for Bison models
* fix: disabled malfunctioning clustering- and classification- gecko endpoints
* chore: bumped version of aidial-sdk to 0.1.2
* feat: supported history truncation via max_prompt_tokens/discarded_messages parameters
adubovik authored Nov 14, 2023
1 parent 423c56e commit f6c21ea
Showing 39 changed files with 1,287 additions and 769 deletions.
.gitignore (1 addition, 1 deletion)

@@ -8,4 +8,4 @@ secret
 dist
 .vscode/launch.json
 ~*
-leftovers
+.idea/
Makefile (8 additions, 5 deletions)

@@ -2,6 +2,7 @@ PORT ?= 5001
 IMAGE_NAME ?= ai-dial-adapter-vertexai
 PLATFORM ?= linux/amd64
 DEV_PYTHON ?= 3.11
+DOCKER ?= docker
 ARGS=

 .PHONY: all install build serve clean lint format test integration_tests docker_build docker_run
@@ -16,7 +17,9 @@ build: install
 	poetry build

 serve: install
-	poetry run uvicorn "aidial_adapter_vertexai.app:app" --reload --host "0.0.0.0" --port $(PORT) --workers=1 --env-file ./.env
+	poetry run uvicorn "aidial_adapter_vertexai.app:app" \
+		--reload --host "0.0.0.0" --port $(PORT) \
+		--workers=1 --env-file ./.env

 clean:
 	poetry run python -m scripts.clean
@@ -35,12 +38,12 @@ integration_tests: install
 	poetry run nox -s integration_tests

 docker_test:
-	docker build --platform $(PLATFORM) -f Dockerfile.test -t $(IMAGE_NAME):test .
-	docker run --platform $(PLATFORM) --rm $(IMAGE_NAME):test
+	$(DOCKER) build --platform $(PLATFORM) -f Dockerfile.test -t $(IMAGE_NAME):test .
+	$(DOCKER) run --platform $(PLATFORM) --rm $(IMAGE_NAME):test

 docker_serve:
-	docker build --platform $(PLATFORM) -t $(IMAGE_NAME):dev .
-	docker run --platform $(PLATFORM) --env-file ./.env --rm -p $(PORT):5000 $(IMAGE_NAME):dev
+	$(DOCKER) build --platform $(PLATFORM) -t $(IMAGE_NAME):dev .
+	$(DOCKER) run --platform $(PLATFORM) --env-file ./.env --rm -p $(PORT):5000 $(IMAGE_NAME):dev

 help:
 	@echo '===================='
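Usage note: the new `DOCKER ?= docker` variable lets the container runtime be overridden from the command line, e.g. `make docker_serve DOCKER=podman`; whether a substitute runtime is flag-compatible with `docker` here is an assumption, since only `docker` itself is exercised by this Makefile.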
README.md (1 addition, 0 deletions)

@@ -55,6 +55,7 @@ Copy `.env.example` to `.env` and customize it for your environment:
 |DEFAULT_REGION||Default region for Vertex AI (e.g. "us-central1")|
 |GCP_PROJECT_ID||GCP project ID|
 |LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
+|AIDIAL_LOG_LEVEL|WARNING|AI DIAL SDK log level|
 |WEB_CONCURRENCY|1|Number of workers for the server|
 |TEST_SERVER_URL|http://0.0.0.0:5001|Server URL used in the integration tests|
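For local development, a minimal `.env` typically needs only `GCP_PROJECT_ID`, `DEFAULT_REGION`, and GCP credentials; `LOG_LEVEL=DEBUG` turns on the adapter's debug prints (which this commit extends with timings), while the new `AIDIAL_LOG_LEVEL` independently controls the AI DIAL SDK's own logger.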
aidial_adapter_vertexai/chat_completion.py (110 additions, 12 deletions)

@@ -1,17 +1,105 @@
 import asyncio
-from typing import List
+from typing import List, Optional, Tuple

-from aidial_sdk.chat_completion import ChatCompletion, Request, Response
+from aidial_sdk.chat_completion import (
+    ChatCompletion,
+    Message,
+    Request,
+    Response,
+    Role,
+)

+from aidial_adapter_vertexai.llm.consumer import ChoiceConsumer
+from aidial_adapter_vertexai.llm.exceptions import ValidationError
+from aidial_adapter_vertexai.llm.history_trimming import (
+    get_discarded_messages_count,
+)
 from aidial_adapter_vertexai.llm.vertex_ai_adapter import (
     get_chat_completion_model,
 )
+from aidial_adapter_vertexai.llm.vertex_ai_chat import (
+    VertexAIAuthor,
+    VertexAIMessage,
+)
 from aidial_adapter_vertexai.llm.vertex_ai_deployments import (
     ChatCompletionDeployment,
 )
 from aidial_adapter_vertexai.server.exceptions import dial_exception_decorator
 from aidial_adapter_vertexai.universal_api.request import ModelParameters
 from aidial_adapter_vertexai.universal_api.token_usage import TokenUsage
 from aidial_adapter_vertexai.utils.log_config import app_logger as log

+_SUPPORTED_ROLES = {Role.SYSTEM, Role.USER, Role.ASSISTANT}
+
+
+def _parse_message(message: Message) -> VertexAIMessage:
+    author = (
+        VertexAIAuthor.BOT
+        if message.role == Role.ASSISTANT
+        else VertexAIAuthor.USER
+    )
+    return VertexAIMessage(author=author, content=message.content)  # type: ignore
+
+
+def _validate_messages_and_split(
+    messages: List[Message],
+) -> Tuple[Optional[str], List[Message]]:
+    if len(messages) == 0:
+        raise ValidationError("The chat history must have at least one message")
+
+    for message in messages:
+        if message.content is None:
+            raise ValidationError("Message content must be present")
+
+        if message.role not in _SUPPORTED_ROLES:
+            raise ValidationError(
+                f"Message role must be one of {_SUPPORTED_ROLES}"
+            )
+
+    context: Optional[str] = None
+    if len(messages) > 0 and messages[0].role == Role.SYSTEM:
+        context = messages[0].content or ""
+        context = context if context.strip() else None
+        messages = messages[1:]
+
+    if len(messages) == 0 and context is not None:
+        raise ValidationError(
+            "The chat history must have at least one non-system message"
+        )
+
+    role: Optional[Role] = None
+    for message in messages:
+        if message.role == Role.SYSTEM:
+            raise ValidationError(
+                "System messages other than the initial system message are not allowed"
+            )
+
+        # Bison doesn't support empty messages,
+        # so we replace it with a single space.
+        message.content = message.content or " "
+
+        if role == message.role:
+            raise ValidationError("Messages must alternate between authors")
+
+        role = message.role
+
+    if len(messages) % 2 == 0:
+        raise ValidationError(
+            "There should be odd number of messages for correct alternating turn"
+        )
+
+    if messages[-1].role != Role.USER:
+        raise ValidationError("The last message must be a user message")
+
+    return context, messages
+
+
+def _parse_history(
+    history: List[Message],
+) -> Tuple[Optional[str], List[VertexAIMessage]]:
+    context, history = _validate_messages_and_split(history)
+
+    return context, list(map(_parse_message, history))
+
+
 class VertexAIChatCompletion(ChatCompletion):
@@ -28,21 +116,31 @@ async def chat_completion(self, request: Request, response: Response):
             deployment=ChatCompletionDeployment(request.deployment_id),
             project_id=self.project_id,
             location=self.region,
-            model_params=ModelParameters.create(request),
         )

-        streaming = bool(request.stream)
-
-        async def generate_response(idx: int) -> TokenUsage:
-            content, usage = await model.chat(streaming, request.messages)
+        params = ModelParameters.create(request)
+        context, messages = _parse_history(request.messages)
+        discarded_messages_count: Optional[int] = None
+        if params.max_prompt_tokens is not None:
+            discarded_messages_count = await get_discarded_messages_count(
+                model, context, messages, params.max_prompt_tokens
+            )
+            messages = messages[discarded_messages_count:]
+
+        async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
             with response.create_choice() as choice:
-                choice.append_content(content)
-                return usage
+                consumer = ChoiceConsumer(choice)
+                await model.chat(consumer, context, messages, params)
+                usage.accumulate(consumer.usage)

-        usages: List[TokenUsage] = await asyncio.gather(
-            *(generate_response(idx) for idx in range(request.n or 1))
+        usage = TokenUsage()
+
+        await asyncio.gather(
+            *(generate_response(usage, idx) for idx in range(request.n or 1))
         )

-        usage = sum(usages, TokenUsage())
         log.debug(f"usage: {usage}")
         response.set_usage(usage.prompt_tokens, usage.completion_tokens)
+
+        if discarded_messages_count is not None:
+            response.set_discarded_messages(discarded_messages_count)
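To make the new history-truncation flow concrete, here is a minimal client-side sketch. It assumes the adapter is running locally via `make serve` on port 5001 and exposes the DIAL/OpenAI-compatible route below; the exact URL path, the deployment name, and the location of `discarded_messages` in the response payload are assumptions based on DIAL conventions, not taken from this diff.

# Minimal sketch: exercising max_prompt_tokens against a locally running
# adapter. The route, the "chat-bison@001" deployment name, and the response
# layout are assumptions; adjust them to your deployment.
import requests

resp = requests.post(
    "http://0.0.0.0:5001/openai/deployments/chat-bison@001/chat/completions",
    json={
        # The history follows the rules enforced by _validate_messages_and_split:
        # optional system message first, then strictly alternating user and
        # assistant turns, an odd number of them, ending with a user message.
        "messages": [
            {"role": "system", "content": "Answer briefly."},
            {"role": "user", "content": "What is Vertex AI?"},
            {"role": "assistant", "content": "Google Cloud's ML platform."},
            {"role": "user", "content": "Which Bison models does it offer?"},
        ],
        # New in this commit: cap the prompt size; the adapter drops the
        # oldest messages that do not fit and reports how many were dropped.
        "max_prompt_tokens": 256,
    },
    timeout=60,
)
body = resp.json()
print(body["usage"])
# Mirrors response.set_discarded_messages(...) above; where exactly the field
# appears in the payload depends on the aidial-sdk version.
print(body.get("statistics", {}).get("discarded_messages"))

Each of the `request.n` choices is generated concurrently via `asyncio.gather`, and all of them accumulate into the single shared `TokenUsage` that feeds `response.set_usage`.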
aidial_adapter_vertexai/llm/bison_adapter.py (85 additions, 0 deletions, new file)

@@ -0,0 +1,85 @@
+from typing import Any, Dict, List, Optional
+
+from typing_extensions import override
+
+from aidial_adapter_vertexai.llm.chat_completion_adapter import (
+    ChatCompletionAdapter,
+)
+from aidial_adapter_vertexai.llm.exceptions import ValidationError
+from aidial_adapter_vertexai.llm.vertex_ai_chat import VertexAIMessage
+from aidial_adapter_vertexai.universal_api.request import ModelParameters
+
+
+class BisonChatAdapter(ChatCompletionAdapter):
+    @override
+    def _create_instance(
+        self,
+        context: Optional[str],
+        messages: List[VertexAIMessage],
+    ) -> Dict[str, Any]:
+        return {
+            "context": context or "",
+            "messages": messages,
+        }
+
+    @override
+    def _create_parameters(
+        self,
+        params: ModelParameters,
+    ) -> Dict[str, Any]:
+        # See chat playground: https://console.cloud.google.com/vertex-ai/generative/language/create/chat
+        ret: Dict[str, Any] = {}
+
+        if params.max_tokens is not None:
+            ret["maxOutputTokens"] = params.max_tokens
+
+        if params.temperature is not None:
+            ret["temperature"] = params.temperature
+
+        if params.stop is not None:
+            ret["stopSequences"] = (
+                [params.stop] if isinstance(params.stop, str) else params.stop
+            )
+
+        if params.top_p is not None:
+            ret["topP"] = params.top_p
+
+        return ret
+
+
+class BisonCodeChatAdapter(ChatCompletionAdapter):
+    @override
+    def _create_instance(
+        self,
+        context: Optional[str],
+        messages: List[VertexAIMessage],
+    ) -> Dict[str, Any]:
+        if context is not None:
+            raise ValidationError("System message is not supported")
+
+        return {
+            "messages": messages,
+        }
+
+    @override
+    def _create_parameters(
+        self,
+        params: ModelParameters,
+    ) -> Dict[str, Any]:
+        ret: Dict[str, Any] = {}
+
+        if params.max_tokens is not None:
+            ret["maxOutputTokens"] = params.max_tokens
+
+        if params.temperature is not None:
+            ret["temperature"] = params.temperature
+
+        if params.stop is not None:
+            raise ValidationError(
+                "stop sequences are not supported for code chat model"
+            )
+
+        if params.top_p is not None:
+            raise ValidationError("top_p is not supported for code chat model")
+
+        return ret
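The camelCase mapping that `_create_parameters` performs is easy to miss in review, so here is a self-contained sketch of the same logic (a standalone reimplementation for illustration, not the adapter's actual code path):

from typing import Any, Dict, List, Optional, Union

def bison_parameters(
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    stop: Optional[Union[str, List[str]]] = None,
    top_p: Optional[float] = None,
) -> Dict[str, Any]:
    """Mirror of BisonChatAdapter._create_parameters: translate OpenAI-style
    parameters into the camelCase keys Vertex AI Bison models expect."""
    ret: Dict[str, Any] = {}
    if max_tokens is not None:
        ret["maxOutputTokens"] = max_tokens
    if temperature is not None:
        ret["temperature"] = temperature
    if stop is not None:
        # A single stop string is normalized to a one-element list.
        ret["stopSequences"] = [stop] if isinstance(stop, str) else stop
    if top_p is not None:
        ret["topP"] = top_p
    return ret

print(bison_parameters(max_tokens=128, stop="END"))
# {'maxOutputTokens': 128, 'stopSequences': ['END']}

`BisonCodeChatAdapter` keeps only the `maxOutputTokens` and `temperature` branches and rejects `stop` and `top_p` with a `ValidationError`, matching the code chat model's narrower parameter surface.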
(Only 5 of the 39 changed files are shown above.)
