Skip to content

Commit

Permalink
Support for mxint8 kv-cache added in QNN Compilation path.
Browse files Browse the repository at this point in the history
Signed-off-by: Shubham Agrawal <[email protected]>
  • Loading branch information
shubhagr-quic committed Jan 9, 2025
1 parent 41cf878 commit dff1e49
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 17 deletions.
36 changes: 25 additions & 11 deletions QEfficient/compile/qnn_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from QEfficient.utils._utils import create_json, execute_command, load_json
from QEfficient.utils.constants import QnnConstants
from QEfficient.utils.generate_qnn_network_specialization_config import fetch_nodes_info
from QEfficient.utils.generate_qnn_network_specialization_config import fetch_nodes_info, generate_data_format_config
from QEfficient.utils.logging_utils import logger


Expand All @@ -38,6 +38,8 @@ def __init__(
qnn_target: str = QnnConstants.TARGET,
qnn_config_path: Optional[str] = None,
qnn_binary_dir: Optional[str] = None,
mxint8: Optional[bool] = False,
compiler_mxint8_mdp_io: Optional[bool] = False,
**kwargs,
) -> None:
self.onnx_path = onnx_path
Expand All @@ -52,6 +54,8 @@ def __init__(
self.compiler_mxfp6_matmul_weights = compiler_mxfp6_matmul_weights
self.qnn_config_path = qnn_config_path
self.qnn_binary_dir = qnn_binary_dir
self.mxint8 = mxint8
self.compiler_mxint8_mdp_io = compiler_mxint8_mdp_io
self.custom_io_path = custom_io_path
self.dlc_model_path = os.path.join(qpc_base_path, f"{QnnConstants.MODEL_NAME}.dlc")
self.qnn_target = qnn_target
Expand Down Expand Up @@ -148,6 +152,7 @@ def create_qnn_compile_backend_json(self) -> str:
"compiler_stat_level": QnnConstants.COMPILER_STAT_LEVEL,
"compiler_stats_batch_size": QnnConstants.COMPILER_STATS_BATCH_SIZE,
"compiler_time_passes": QnnConstants.COMPILER_TIME_PASSES,
"compiler_mxint8_mdp_io": self.compiler_mxint8_mdp_io,
}
if self.compiler_max_out_channel_split > 0:
qnn_compile_backend["compiler_max_out_channel_split"] = str(self.compiler_max_out_channel_split)
Expand Down Expand Up @@ -225,10 +230,10 @@ def converter(self) -> str:
IMMUTABLE parameters which can not be overridden by the user using qnn_config.json:
:input_network (str): Generated ``ONNX`` Model Path.
:output_path (str): Path to generated DLC file, which is provided qpc_base_path/model.dlc
:io_config (str): Path to custom_io_config.yaml file created using GenerateQNNnetworkSpecializationconfig.py
:config (str): Path to custom_io_config.yaml file created using GenerateQNNnetworkSpecializationconfig.py
:float_bias_bitwidth (int): Bitwidth to use for float bias tensor
:float_bitwidth (int): Converts the graph to the specified float bitwidth, either 32 or 16(Default).
:keep_int64_inputs(flag): Passed by default.
:preserve_io_datatype(flag): Passed by default.
CONVERTOR_ARGS_EXTENSION passed in qnn_config.json is appended to the command created.
Expand All @@ -240,7 +245,7 @@ def converter(self) -> str:
cmd = (
f"{converter_tool} --input_network {self.onnx_path} "
f"--output_path {self.dlc_model_path} "
f"--io_config {self.custom_io_path} "
f"--config {self.custom_io_path} "
f"--float_bias_bitwidth {QnnConstants.FLOAT_BIAS_BITWIDTH} "
f"--float_bitwidth {QnnConstants.FLOAT_BITWIDTH} "
)
Expand Down Expand Up @@ -287,6 +292,17 @@ def generate_context_binary(self) -> str:
f"--config_file {config_file_path} "
)

if self.mxint8:
data_format_file_path = os.path.join(self.qpc_base_path, QnnConstants.QNN_DATA_FORMAT_CONFIG_NAME)
generate_data_format_config(
self.onnx_path, model_dlc_name=QnnConstants.MODEL_NAME, file_path=data_format_file_path
)
if not os.path.isfile(data_format_file_path):
raise FileNotFoundError(
f"file {data_format_file_path} needs to exist in the qpc_base_path for mxint8 compilation. Please rerun infer/compile Api"
)
cmd += f"--data_format_config {data_format_file_path} "

if self.qnn_config and QnnConstants.CONTEXT_BIN_ARGS_EXTENSION_STR in self.qnn_config:
if "--log_level " not in self.qnn_config[QnnConstants.CONTEXT_BIN_ARGS_EXTENSION_STR]:
cmd += f"--log_level {QnnConstants.LOG_LEVEL} "
Expand Down Expand Up @@ -353,27 +369,23 @@ def compile(

if kwargs:
logger.warning("Extra arguments to QNN compilation are not supported as of now!")

raise NotImplementedError("Can't handle extra compilation args now!")

if allow_mxint8_mdp_io:
logger.warning("QNN doesn't support allow_mxint8_mdp_io. Bypassing the value passed for allow_mxint8_mdp_io")

if mxint8:
logger.warning("QNN doesn't support mxint8. Bypassing the value passed for mxint8")

os.makedirs(qpc_base_path, exist_ok=True)

# Created custom_io_config.yaml file for QNN-Convertor stage.
# TODO To make custom_io_config.yaml configurable as not all models need it.
custom_io_file_path = os.path.join(qpc_base_path, "custom_io_config.yaml")

kv_precision = "uint8" if mxint8 else "float16"
fetch_nodes_info(
onnx_graph_path=onnx_path,
batch_size=batch_size,
sequence_length=prompt_len,
context_length=ctx_len,
file_path=custom_io_file_path,
full_batch_size=full_batch_size,
kv_precision=kv_precision,
)

if not os.path.isfile(custom_io_file_path):
Expand All @@ -395,6 +407,8 @@ def compile(
ctx_len=ctx_len,
compiler_mxfp6_matmul_weights=mxfp6,
qnn_binary_dir=qnn_binary_dir,
mxint8=mxint8,
compiler_mxint8_mdp_io=allow_mxint8_mdp_io,
)

compiled_binary_path = qnn_obj.compile()
Expand Down
9 changes: 5 additions & 4 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class QnnConstants:

# QNN Compilation target names
MODEL_NAME = "model"
QNN_DATA_FORMAT_CONFIG_NAME = "qnn_data_format_config.json"
CONTEXT_BIN_NAME = "qnngraph.serialized"
CONTEXT_BIN_QPC_NAME = "programqpc.bin"

Expand All @@ -90,7 +91,7 @@ class QnnConstants:
# Convertor Arguments
FLOAT_BITWIDTH = 16
FLOAT_BIAS_BITWIDTH = 32
CONVERTOR_DEFAULT_ARGS = "--keep_int64_inputs --onnx_no_simplification "
CONVERTOR_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification "

# Context-Binary-Generator Arguments
LOG_LEVEL = "error"
Expand Down Expand Up @@ -118,11 +119,11 @@ class QnnConstants:
IMMUTABLE_CONVERTOR_ARGS = [
"--input_network ",
"--output_path ",
"--io_config ",
"--config ",
"--float_bias_bitwidth ",
"--float_bitwidth ",
"--keep_int64_inputs",
"--onnx_no_simplification",
"--preserve_io_datatype",
"--onnx_skip_simplification",
"--onnx_defer_loading",
]

Expand Down
44 changes: 42 additions & 2 deletions QEfficient/utils/generate_qnn_network_specialization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
# -----------------------------------------------------------------------------

import json
from typing import Optional

import onnx
Expand All @@ -24,6 +25,7 @@ def fetch_nodes_info(
file_path: str = "custom_io_config.yaml",
full_batch_size: Optional[int] = None,
decode_only: Optional[bool] = False,
kv_precision: Optional[str] = "float16",
) -> None:
# Load the ONNX model
onnx_model = onnx.load(onnx_graph_path)
Expand All @@ -38,7 +40,7 @@ def fetch_nodes_info(
input_info = {}
input_info["DataType"] = str(helper.tensor_dtype_to_np_dtype(node.type.tensor_type.elem_type))
if "past_key" in node.name or "past_value" in node.name:
input_info["DataType"] = "float16"
input_info["DataType"] = kv_precision

if "batch_index" in node.name:
if full_batch_size:
Expand Down Expand Up @@ -128,7 +130,7 @@ def fetch_nodes_info(
output_info = {}
output_info["DataType"] = str(helper.tensor_dtype_to_np_dtype(output.type.tensor_type.elem_type))
if "past_key" in output.name or "past_value" in output.name:
output_info["DataType"] = "float16"
output_info["DataType"] = kv_precision
elif "logits" in output.name:
output_info["DataType"] = "float32"
output_nodes_info.append({"Name": output.name, "Desired Model Parameters": output_info})
Expand All @@ -142,3 +144,41 @@ def fetch_nodes_info(
yaml.dump(final_dict, yaml_file, default_flow_style=False, sort_keys=False)
except Exception as e:
print(f"Failed to create YAML File for QNN Network Specialization Configuration{file_path}: {e}")


def generate_data_format_config(
    onnx_graph_path: str,
    *,
    data_format: Optional[str] = "QNN_TENSOR_DATA_FORMAT_MX",
    model_dlc_name: Optional[str] = "model",
    file_path: str = "qnn_data_format_config.json",
) -> None:
    """
    Create a QNN data-format-config JSON file that overrides the data format
    of every KV-cache tensor in the given ONNX graph.

    KV-cache tensors are identified by name: any graph input or output whose
    name contains ``past_key`` or ``past_value``. QNN tensor names use ``_``
    where ONNX uses ``.``, so names are normalized before being written.

    :onnx_graph_path (str): Path to the ONNX model to scan.
    :data_format (str): QNN data format applied to each KV tensor
        (default: MX, i.e. mxint8 KV cache).
    :model_dlc_name (str): Base DLC model name; the config covers the two
        graph specializations ``<name>_configuration_1`` and
        ``<name>_configuration_2``.
    :file_path (str): Destination path for the generated JSON file.

    On failure to write the file, prints a message instead of raising,
    matching the best-effort style of ``fetch_nodes_info`` above; the caller
    (``QNN.generate_context_binary``) checks for the file's existence.
    """
    # Load the ONNX model
    onnx_model = onnx.load(onnx_graph_path)

    # Collect normalized names of all KV-cache inputs and outputs.
    # (Avoid shadowing the builtins `input`; iterate both tensor lists at once.)
    graph = onnx_model.graph
    kv_nodes = [
        node.name.replace(".", "_")
        for node in list(graph.input) + list(graph.output)
        if "past_key" in node.name or "past_value" in node.name
    ]

    # The same per-tensor override applies to both graph specializations,
    # so build the tensor list once instead of duplicating the comprehension.
    tensors = [{"tensor_name": name, "dataFormat": data_format} for name in kv_nodes]
    kv_overrides = {
        "graphs": [
            {"graph_name": f"{model_dlc_name}_configuration_{index}", "tensors": tensors}
            for index in (1, 2)
        ]
    }

    try:
        with open(file_path, "w") as json_file:
            json.dump(kv_overrides, json_file, indent=4)
    except Exception as e:
        # Fixed: original message had no space before the path
        # ("...Configuration{file_path}" ran the words together).
        print(f"Failed to create JSON File for QNN Data Format Configuration {file_path}: {e}")

0 comments on commit dff1e49

Please sign in to comment.