From dff1e494f3c27f0e94b6cfc1b35bb95548f06be5 Mon Sep 17 00:00:00 2001
From: Shubham Agrawal
Date: Wed, 8 Jan 2025 16:22:37 +0530
Subject: [PATCH] Add support for mxint8 KV cache in the QNN compilation path

Signed-off-by: Shubham Agrawal
---
 QEfficient/compile/qnn_compiler.py            | 36 ++++++++++-----
 QEfficient/utils/constants.py                 |  9 ++--
 ...erate_qnn_network_specialization_config.py | 44 ++++++++++++++++++-
 3 files changed, 72 insertions(+), 17 deletions(-)

diff --git a/QEfficient/compile/qnn_compiler.py b/QEfficient/compile/qnn_compiler.py
index ad5da9767..11926c9a1 100644
--- a/QEfficient/compile/qnn_compiler.py
+++ b/QEfficient/compile/qnn_compiler.py
@@ -11,7 +11,7 @@
 
 from QEfficient.utils._utils import create_json, execute_command, load_json
 from QEfficient.utils.constants import QnnConstants
-from QEfficient.utils.generate_qnn_network_specialization_config import fetch_nodes_info
+from QEfficient.utils.generate_qnn_network_specialization_config import fetch_nodes_info, generate_data_format_config
 from QEfficient.utils.logging_utils import logger
 
 
@@ -38,6 +38,8 @@ def __init__(
         qnn_target: str = QnnConstants.TARGET,
         qnn_config_path: Optional[str] = None,
         qnn_binary_dir: Optional[str] = None,
+        mxint8: Optional[bool] = False,
+        compiler_mxint8_mdp_io: Optional[bool] = False,
         **kwargs,
     ) -> None:
         self.onnx_path = onnx_path
@@ -52,6 +54,8 @@ def __init__(
         self.compiler_mxfp6_matmul_weights = compiler_mxfp6_matmul_weights
         self.qnn_config_path = qnn_config_path
         self.qnn_binary_dir = qnn_binary_dir
+        self.mxint8 = mxint8
+        self.compiler_mxint8_mdp_io = compiler_mxint8_mdp_io
         self.custom_io_path = custom_io_path
         self.dlc_model_path = os.path.join(qpc_base_path, f"{QnnConstants.MODEL_NAME}.dlc")
         self.qnn_target = qnn_target
@@ -148,6 +152,7 @@ def create_qnn_compile_backend_json(self) -> str:
             "compiler_stat_level": QnnConstants.COMPILER_STAT_LEVEL,
             "compiler_stats_batch_size": QnnConstants.COMPILER_STATS_BATCH_SIZE,
             "compiler_time_passes": QnnConstants.COMPILER_TIME_PASSES,
+            "compiler_mxint8_mdp_io": self.compiler_mxint8_mdp_io,
         }
         if self.compiler_max_out_channel_split > 0:
             qnn_compile_backend["compiler_max_out_channel_split"] = str(self.compiler_max_out_channel_split)
@@ -225,10 +230,10 @@ def converter(self) -> str:
         IMMUTABLE parameters which can not be overridden by the user using qnn_config.json:
             :input_network (str): Generated ``ONNX`` Model Path.
             :output_path (str): Path to generated DLC file, which is provided qpc_base_path/model.dlc
-            :io_config (str): Path to custom_io_config.yaml file created using GenerateQNNnetworkSpecializationconfig.py
+            :config (str): Path to custom_io_config.yaml file created using generate_qnn_network_specialization_config.py
             :float_bias_bitwidth (int): Bitwidth to use for float bias tensor
             :float_bitwidth (int): Converts the graph to the specified float bitwidth, either 32 or 16(Default).
-            :keep_int64_inputs(flag): Passed by default.
+            :preserve_io_datatype(flag): Passed by default.
 
         CONVERTOR_ARGS_EXTENSION passed in qnn_config.json is appended to the command created.
@@ -240,7 +245,7 @@ def converter(self) -> str:
         cmd = (
             f"{converter_tool} --input_network {self.onnx_path} "
             f"--output_path {self.dlc_model_path} "
-            f"--io_config {self.custom_io_path} "
+            f"--config {self.custom_io_path} "
             f"--float_bias_bitwidth {QnnConstants.FLOAT_BIAS_BITWIDTH} "
             f"--float_bitwidth {QnnConstants.FLOAT_BITWIDTH} "
         )
@@ -287,6 +292,17 @@ def generate_context_binary(self) -> str:
             f"--config_file {config_file_path} "
         )
 
+        if self.mxint8:
+            data_format_file_path = os.path.join(self.qpc_base_path, QnnConstants.QNN_DATA_FORMAT_CONFIG_NAME)
+            generate_data_format_config(
+                self.onnx_path, model_dlc_name=QnnConstants.MODEL_NAME, file_path=data_format_file_path
+            )
+            if not os.path.isfile(data_format_file_path):
+                raise FileNotFoundError(
+                    f"File {data_format_file_path} needs to exist in the qpc_base_path for mxint8 compilation. Please rerun the infer/compile API."
+                )
+            cmd += f"--data_format_config {data_format_file_path} "
+
         if self.qnn_config and QnnConstants.CONTEXT_BIN_ARGS_EXTENSION_STR in self.qnn_config:
             if "--log_level " not in self.qnn_config[QnnConstants.CONTEXT_BIN_ARGS_EXTENSION_STR]:
                 cmd += f"--log_level {QnnConstants.LOG_LEVEL} "
@@ -353,20 +369,15 @@ def compile(
     if kwargs:
         logger.warning("Extra arguments to QNN compilation are not supported as of now!")
-        raise NotImplementedError("Can't handle extra compilation args now!")
-    if allow_mxint8_mdp_io:
-        logger.warning("QNN doesn't support allow_mxint8_mdp_io. Bypassing the value passed for allow_mxint8_mdp_io")
-
-    if mxint8:
-        logger.warning("QNN doesn't support mxint8. Bypassing the value passed for mxint8")
-
     os.makedirs(qpc_base_path, exist_ok=True)
 
     # Created custom_io_config.yaml file for QNN-Convertor stage.
     # TODO To make custom_io_config.yaml configurable as not all models need it.
     custom_io_file_path = os.path.join(qpc_base_path, "custom_io_config.yaml")
+
+    kv_precision = "uint8" if mxint8 else "float16"
     fetch_nodes_info(
         onnx_graph_path=onnx_path,
         batch_size=batch_size,
@@ -374,6 +385,7 @@ def compile(
         context_length=ctx_len,
         file_path=custom_io_file_path,
         full_batch_size=full_batch_size,
+        kv_precision=kv_precision,
     )
 
     if not os.path.isfile(custom_io_file_path):
@@ -395,6 +407,8 @@ def compile(
         ctx_len=ctx_len,
         compiler_mxfp6_matmul_weights=mxfp6,
         qnn_binary_dir=qnn_binary_dir,
+        mxint8=mxint8,
+        compiler_mxint8_mdp_io=allow_mxint8_mdp_io,
     )
     compiled_binary_path = qnn_obj.compile()

diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index bfbac905f..ab861a788 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -81,6 +81,7 @@ class QnnConstants:
 
     # QNN Compilation target names
     MODEL_NAME = "model"
+    QNN_DATA_FORMAT_CONFIG_NAME = "qnn_data_format_config.json"
     CONTEXT_BIN_NAME = "qnngraph.serialized"
     CONTEXT_BIN_QPC_NAME = "programqpc.bin"
 
@@ -90,7 +91,7 @@ class QnnConstants:
     # Convertor Arguments
     FLOAT_BITWIDTH = 16
     FLOAT_BIAS_BITWIDTH = 32
-    CONVERTOR_DEFAULT_ARGS = "--keep_int64_inputs --onnx_no_simplification "
+    CONVERTOR_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification "
 
     # Context-Binary-Generator Arguments
     LOG_LEVEL = "error"
@@ -118,11 +119,11 @@ class QnnConstants:
     IMMUTABLE_CONVERTOR_ARGS = [
         "--input_network ",
         "--output_path ",
-        "--io_config ",
+        "--config ",
         "--float_bias_bitwidth ",
         "--float_bitwidth ",
-        "--keep_int64_inputs",
-        "--onnx_no_simplification",
+        "--preserve_io_datatype",
+        "--onnx_skip_simplification",
         "--onnx_defer_loading",
     ]
 
diff --git a/QEfficient/utils/generate_qnn_network_specialization_config.py b/QEfficient/utils/generate_qnn_network_specialization_config.py
index 0e5e17c08..ca78c658c 100644
--- a/QEfficient/utils/generate_qnn_network_specialization_config.py
+++ b/QEfficient/utils/generate_qnn_network_specialization_config.py
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+import json
 from typing import Optional
 
 import onnx
@@ -24,6 +25,7 @@ def fetch_nodes_info(
     file_path: str = "custom_io_config.yaml",
     full_batch_size: Optional[int] = None,
     decode_only: Optional[bool] = False,
+    kv_precision: Optional[str] = "float16",
 ) -> None:
     # Load the ONNX model
     onnx_model = onnx.load(onnx_graph_path)
@@ -38,7 +40,7 @@ def fetch_nodes_info(
             input_info = {}
             input_info["DataType"] = str(helper.tensor_dtype_to_np_dtype(node.type.tensor_type.elem_type))
             if "past_key" in node.name or "past_value" in node.name:
-                input_info["DataType"] = "float16"
+                input_info["DataType"] = kv_precision
 
             if "batch_index" in node.name:
                 if full_batch_size:
@@ -128,7 +130,7 @@ def fetch_nodes_info(
             output_info = {}
             output_info["DataType"] = str(helper.tensor_dtype_to_np_dtype(output.type.tensor_type.elem_type))
             if "past_key" in output.name or "past_value" in output.name:
-                output_info["DataType"] = "float16"
+                output_info["DataType"] = kv_precision
             elif "logits" in output.name:
                 output_info["DataType"] = "float32"
             output_nodes_info.append({"Name": output.name, "Desired Model Parameters": output_info})
@@ -142,3 +144,41 @@ def fetch_nodes_info(
             yaml.dump(final_dict, yaml_file, default_flow_style=False, sort_keys=False)
     except Exception as e:
         print(f"Failed to create YAML File for QNN Network Specialization Configuration{file_path}: {e}")
+
+
+def generate_data_format_config(
+    onnx_graph_path: str,
+    *,
+    data_format: Optional[str] = "QNN_TENSOR_DATA_FORMAT_MX",
"QNN_TENSOR_DATA_FORMAT_MX", + model_dlc_name: Optional[str] = "model", + file_path: str = "qnn_data_format_config.json", +) -> None: + # Load the ONNX model + onnx_model = onnx.load(onnx_graph_path) + + kv_nodes: list = [] + + for input in onnx_model.graph.input: + if "past_key" in input.name or "past_value" in input.name: + kv_nodes.append((input.name).replace(".", "_")) + for output in onnx_model.graph.output: + if "past_key" in output.name or "past_value" in output.name: + kv_nodes.append((output.name).replace(".", "_")) + kv_overrides = {} + + kv_overrides["graphs"] = [ + { + "graph_name": model_dlc_name + "_configuration_1", + "tensors": [{"tensor_name": node, "dataFormat": data_format} for node in kv_nodes], + }, + { + "graph_name": model_dlc_name + "_configuration_2", + "tensors": [{"tensor_name": node, "dataFormat": data_format} for node in kv_nodes], + }, + ] + + try: + with open(file_path, "w") as json_file: + json.dump(kv_overrides, json_file, indent=4) + except Exception as e: + print(f"Failed to create JSON File for QNN Data Format Configuration{file_path}: {e}")