From 7f23f35fbd9f1b28f45de35a2211825b19484529 Mon Sep 17 00:00:00 2001 From: Daiyi Yang Date: Tue, 10 Sep 2024 23:15:02 +0000 Subject: [PATCH] [BFCL] Add empower functions models and the supporting handler --- .../bfcl/eval_checker/model_metadata.py | 12 +++ .../bfcl/model_handler/constant.py | 2 + .../bfcl/model_handler/handler_map.py | 5 +- .../bfcl/model_handler/oss_model/empower.py | 101 ++++++++++++++++++ 4 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/empower.py diff --git a/berkeley-function-call-leaderboard/bfcl/eval_checker/model_metadata.py b/berkeley-function-call-leaderboard/bfcl/eval_checker/model_metadata.py index 6c16a01e0..1f7a6bdf9 100644 --- a/berkeley-function-call-leaderboard/bfcl/eval_checker/model_metadata.py +++ b/berkeley-function-call-leaderboard/bfcl/eval_checker/model_metadata.py @@ -497,6 +497,18 @@ "Microsoft", "MIT", ], + "empower-dev/llama3-empower-functions-small-v1.1": [ + "Empower-Fucntions-Small-v1.1 (FC)", + "https://huggingface.co/empower-dev/llama3-empower-functions-small-v1.1", + "Empower.dev", + "apache-2.0" + ], + "empower-dev/llama3-empower-functions-large-v1.1": [ + "Empower-Fucntions-Large-v1.1 (FC)", + "https://huggingface.co/empower-dev/llama3-empower-functions-large-v1.1", + "Empower.dev", + "apache-2.0" + ] } INPUT_PRICE_PER_MILLION_TOKEN = { diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/constant.py b/berkeley-function-call-leaderboard/bfcl/model_handler/constant.py index f70e6da8e..bdc532477 100644 --- a/berkeley-function-call-leaderboard/bfcl/model_handler/constant.py +++ b/berkeley-function-call-leaderboard/bfcl/model_handler/constant.py @@ -145,4 +145,6 @@ "THUDM/glm-4-9b-chat", "ibm-granite/granite-20b-functioncalling", "yi-large-fc", + "empower-dev/llama3-empower-functions-small-v1.1", + "empower-dev/llama3-empower-functions-large-v1.1", ] diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/handler_map.py b/berkeley-function-call-leaderboard/bfcl/model_handler/handler_map.py index 55de93e04..68d99f38b 100644 --- a/berkeley-function-call-leaderboard/bfcl/model_handler/handler_map.py +++ b/berkeley-function-call-leaderboard/bfcl/model_handler/handler_map.py @@ -1,4 +1,5 @@ from bfcl.model_handler.oss_model.deepseek import DeepseekHandler +from bfcl.model_handler.oss_model.empower import EmpowerHandler from bfcl.model_handler.oss_model.gemma import GemmaHandler from bfcl.model_handler.oss_model.glaive import GlaiveHandler from bfcl.model_handler.oss_model.glm import GLMHandler @@ -92,7 +93,9 @@ "ibm-granite/granite-20b-functioncalling": GraniteHandler, # "MadeAgents/Hammer-7b": HammerHandler, # TODO: Update handler once they have a multi-turn format "THUDM/glm-4-9b-chat": GLMHandler, - + "empower-dev/llama3-empower-functions-small-v1.1": EmpowerHandler, + "empower-dev/llama3-empower-functions-large-v1.1": EmpowerHandler, + # Deprecated/outdated models, no longer on the leaderboard # "gorilla-openfunctions-v0": GorillaHandler, # "gpt-4o-2024-05-13": OpenAIHandler, diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/empower.py b/berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/empower.py new file mode 100644 index 000000000..afa9e17ea --- /dev/null +++ b/berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/empower.py @@ -0,0 +1,101 @@ +from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler +from bfcl.model_handler.model_style import ModelStyle +import json +from bfcl.model_handler.utils import ( + convert_to_tool, +) +from bfcl.model_handler.constant import ( + GORILLA_TO_OPENAPI, +) + + +class EmpowerHandler(OSSHandler): + def __init__(self, model_name, temperature) -> None: + super().__init__(model_name, temperature) + + def _preprocess_messages(self, messages): + # remove system message + messages = [ + message for message in messages if message['role'] != "system"] + + # combine tool responses + result = [] + temp_tool_content = None + for message in messages: + if message['role'] == 'tool': + decoded_content = json.loads(message['content']) + if temp_tool_content: + temp_tool_content.append(decoded_content) + else: + temp_tool_content = [decoded_content] + else: + if temp_tool_content: + result.append({ + 'role': 'tool', + 'content': json.dumps(temp_tool_content, indent=2) + }) + temp_tool_content = None + result.append(message) + if temp_tool_content: + result.append({ + 'role': 'tool', + 'content': json.dumps(temp_tool_content, indent=2) + }) + + return result + + def _format_prompt(self, messages, functions): + formatted_prompt = "<|begin_of_text|>" + + for idx, message in enumerate(self._preprocess_messages(messages)): + if idx == 0: + tools = convert_to_tool( + functions, GORILLA_TO_OPENAPI, ModelStyle.OSSMODEL + ) + message['content'] = "In this environment you have access to a set of functions defined in the JSON format you can use to address user's requests, use them if needed.\nFunctions:\n" \ + + json.dumps(tools, indent=2) \ + + "\n\n" \ + + "User Message:\n" \ + + message['content'] + else: + if message['role'] == 'tool': + message['role'] = 'user' + message['content'] = '' + message['content'] + elif message['role'] == 'user' and not message['content'].startswith('') and not message['content'].startswith(''): + message['content'] = '' + message['content'] + + formatted_prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>" + + formatted_prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n" + + return formatted_prompt + + def decode_ast(self, result, language="Python"): + if not result.startswith(''): + return [] + + # strip the function/conversation tag / + result_stripped = result[3:] + + decoded_output = [] + for invoked_function in json.loads(result_stripped): + name = invoked_function["name"] + params = invoked_function["arguments"] if "arguments" in invoked_function else { + } + decoded_output.append({name: params}) + + return decoded_output + + def decode_execute(self, result): + execution_list = [] + + for function_call in self.decode_ast(result): + for key, value in function_call.items(): + argument_list = [] + for k, v in value.items(): + argument_list.append(f'{k}={repr(v)}') + execution_list.append( + f"{key}({','.join(argument_list)})" + ) + + return execution_list