From 0eef957e0314f731b5927ea2d5c585d5470bf22a Mon Sep 17 00:00:00 2001 From: Maximilien de Bayser Date: Thu, 7 Nov 2024 12:09:02 -0300 Subject: [PATCH] [Frontend] Tool calling parser for Granite 3.0 models (#9027) Signed-off-by: Max de Bayser --- .../serving/openai_compatible_server.md | 44 ++-- examples/tool_chat_template_granite.jinja | 40 ++++ tests/tool_use/conftest.py | 6 + tests/tool_use/utils.py | 37 +-- .../openai/tool_parsers/__init__.py | 5 +- .../tool_parsers/granite_tool_parser.py | 215 ++++++++++++++++++ 6 files changed, 314 insertions(+), 33 deletions(-) create mode 100644 examples/tool_chat_template_granite.jinja create mode 100644 vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 0b5f75caf2475..a196f8b1e574e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -160,14 +160,7 @@ this, unless explicitly specified. :func: create_parser_for_docs :prog: vllm serve ``` -## Tool Calling in the Chat Completion API -### Named Function Calling -vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is -enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a -high-quality one. -To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and -specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. ### Config file @@ -196,12 +189,22 @@ The order of priorities is `command line > config file values > defaults`. --- ## Tool calling in the chat completion API -vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. + +vLLM supports named function calling and `auto` tool choice in the chat completion API. The `tool_choice` options `required` is **not yet supported** but on the roadmap. It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. + +### Named Function Calling +vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. + vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. + ### Automatic Function Calling To enable this feature, you should set the following flags: @@ -275,6 +278,21 @@ it works better with vLLM. Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` +#### IBM Granite + +Supported models: +* `ibm-granite/granite-3.0-8b-instruct` + +Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` + +`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. + +* `ibm-granite/granite-20b-functioncalling` + +Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` + +`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. + #### InternLM Models (`internlm`) @@ -297,16 +315,6 @@ AI21's Jamba-1.5 models are supported. Flags: `--tool-call-parser jamba` -#### IBM Granite (`granite-20b-fc`) - -Supported models: -* `ibm-granite/granite-20b-functioncalling` - -Flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` - -The example chat template deviates slightly from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. - - ### How to write a tool parser plugin A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. diff --git a/examples/tool_chat_template_granite.jinja b/examples/tool_chat_template_granite.jinja new file mode 100644 index 0000000000000..2cc19e77188dc --- /dev/null +++ b/examples/tool_chat_template_granite.jinja @@ -0,0 +1,40 @@ +{%- if tools %} + {{- '<|start_of_role|>available_tools<|end_of_role|> +' }} + {%- for tool in tools %} + {{- tool | tojson(indent=4) }} + {%- if not loop.last %} + {{- ' + +' }} + {%- endif %} + {%- endfor %} + {{- '<|end_of_text|> +' }} +{%- endif %} + +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'user' %} + {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {% for tc in message.tool_calls %} + {{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }} + {% endfor %} + {{- '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant' %} + {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'tool_response' or message['role'] == 'tool' %} + {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- endif %} + {%- if loop.last and add_generation_prompt %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {%- endif %} +{%- endfor %} diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index ab6a29eba1b3f..294acf202a232 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -3,6 +3,7 @@ from huggingface_hub import snapshot_download from tests.utils import RemoteOpenAIServer +from vllm.platforms import current_platform from .utils import ARGS, CONFIGS, ServerConfig @@ -11,6 +12,11 @@ @pytest.fixture(scope="session", params=CONFIGS.keys()) def server_config(request): config = CONFIGS[request.param] + + if current_platform.is_rocm() and not config.get("supports_rocm", True): + pytest.skip("The {} model can't be tested on the ROCm platform".format( + config["model"])) + # download model and tokenizer using transformers snapshot_download(config["model"]) yield CONFIGS[request.param] diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index d9ee0b1d54b0a..576555b368afe 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -13,6 +13,7 @@ class ServerConfig(TypedDict, total=False): arguments: List[str] system_prompt: Optional[str] supports_parallel: Optional[bool] + supports_rocm: Optional[bool] def patch_system_prompt(messages: List[Dict[str, Any]], @@ -36,7 +37,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] +ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] CONFIGS: Dict[str, ServerConfig] = { "hermes": { @@ -88,18 +89,28 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally." }, - ## FIXME: temporary disabled due to lack of hardware specification - ## for individual runs - #"granite20b": { - # "model": - # "ibm-granite/granite-20b-functioncalling", - # "arguments": [ - # "--tool-call-parser", "granite-20b-fc", "--chat-template", - # str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja") - # ], - # "supports_parallel": - # False, - #}, + "granite20b": { + "model": + "mbayser/granite-20b-functioncalling-FP8-KV", + "arguments": [ + "--tool-call-parser", "granite-20b-fc", "--chat-template", + str(VLLM_PATH / + "examples/tool_chat_template_granite_20b_fc.jinja"), + "--max_num_seqs", "1", "--enforce-eager", "--cpu-offload-gb", "20" + ], + "supports_parallel": + False, + "supports_rocm": + False, + }, + "granite8b": { + "model": + "ibm-granite/granite-3.0-8b-instruct", + "arguments": [ + "--tool-call-parser", "granite", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_granite.jinja") + ], + }, "internlm": { "model": "internlm/internlm2_5-7b-chat", diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 1b299ce655570..2187862e8380b 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,5 +1,6 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .granite_20b_fc_tool_parser import Granite20bFCToolParser +from .granite_tool_parser import GraniteToolParser from .hermes_tool_parser import Hermes2ProToolParser from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser @@ -8,6 +9,6 @@ __all__ = [ "ToolParser", "ToolParserManager", "Granite20bFCToolParser", - "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", - "Llama3JsonToolParser", "JambaToolParser" + "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", + "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py new file mode 100644 index 0000000000000..b5854ca39ab47 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -0,0 +1,215 @@ +import json +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite") +class GraniteToolParser(ToolParser): + """ + Tool call parser for the granite 3.0 models. Intended + for use with the examples/tool_chat_template_granite.jinja + template. + + Used when --enable-auto-tool-choice --tool-call-parser granite + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + stripped = model_output.strip() + if not stripped or stripped[0] != '[': + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + raw_function_calls = json.loads(stripped) + if not isinstance(raw_function_calls, list): + raise Exception( + f"Expected dict or list, got {type(raw_function_calls)}") + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]), + ), + ) for function_call in raw_function_calls + ] + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + start_idx = consume_space(0, current_text) + if not current_text or current_text[start_idx] != '[': + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = None + is_complete = None + try: + tool_calls, end_idx = partial_json_loads( + current_text[start_idx:], flags) + if type(tool_calls) is list: + tool_call_arr = tool_calls + else: + return DeltaMessage(content=delta_text) + + is_complete = [True] * len(tool_calls) + if not is_complete_json( + current_text[start_idx:start_idx + end_idx]): + is_complete[-1] = False + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if not tool_call_arr: + return None + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] + + delta = None + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + if len(tool_call_arr) > self.current_tool_id + 1: + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None