Commit 84a551a
Lint & Format fixed
Signed-off-by: Abukhoyer Shaik <[email protected]>
abukhoy committed Jan 16, 2025
1 parent 1788b49 commit 84a551a
Showing 4 changed files with 82 additions and 25 deletions.
33 changes: 33 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -381,6 +381,39 @@ def compile(
             aic_num_cores=num_cores,
             **compiler_options,
         )
+
+        # Construct the qconfig json file path
+        qconfig_file_path = os.path.join(os.path.dirname(self.qpc_path), "qconfig.json")
+        huggingface_config = self.model.config.__dict__
+
+        pytorch_transforms = [cls.__name__ for cls in self._pytorch_transforms]
+        onnx_transforms = [cls.__name__ for cls in self._onnx_transforms]
+
+        onnx_path = str(self.onnx_path)
+        specializations_file_path = str(os.path.join(os.path.dirname(self.qpc_path), "specializations.json"))
+        compile_dir = str(os.path.dirname(self.qpc_path))
+
+        create_and_dump_configs(
+            qconfig_file_path,
+            specializations_file_path,
+            huggingface_config,
+            pytorch_transforms,
+            onnx_transforms,
+            onnx_path,
+            compile_dir,
+            prefill_seq_len,
+            ctx_len,
+            batch_size,
+            full_batch_size,
+            num_devices,
+            num_cores,
+            mxfp6_matmul,
+            mxint8_kv_cache,
+            num_speculative_tokens,
+            enable_qnn,
+            qnn_config,
+        )
+
         return qpc_path

     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
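Worth noting: every artifact path in the new block is derived from self.qpc_path, so qconfig.json and specializations.json always land in the directory that holds the compiled QPC. A minimal sketch of that layout, using a made-up qpc_path value:

import os

# Hypothetical compile output; the real value is returned by compile().
qpc_path = "/tmp/qeff_models/MyModel/qpc_16cores/qpcs"

compile_dir = os.path.dirname(qpc_path)
print(compile_dir)                                        # /tmp/qeff_models/MyModel/qpc_16cores
print(os.path.join(compile_dir, "qconfig.json"))          # .../qconfig.json
print(os.path.join(compile_dir, "specializations.json"))  # .../specializations.json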
35 changes: 27 additions & 8 deletions QEfficient/utils/_utils.py
@@ -12,11 +12,12 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import requests
+import yaml
 from huggingface_hub import login, snapshot_download
 from requests.exceptions import HTTPError
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

-from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants
+from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants
 from QEfficient.utils.logging_utils import logger


@@ -414,15 +415,15 @@ def create_and_dump_configs(
     mxfp6_matmul,
     mxint8_kv_cache,
     num_speculative_tokens,
+    enable_qnn,
+    qnn_config,
 ):
     try:
         # Parse the XML file
         tree = ET.parse(Constants.SDK_APPS_XML)
         root = tree.getroot()
         # Try to find the base_version element and get its text
-        version = root.find(".//base_version").text
+        qaic_version = root.find(".//base_version").text
     except (FileNotFoundError, ET.ParseError, AttributeError):
-        version = None
+        qaic_version = None

     # Ensure all objects in the configs dictionary are JSON serializable
     def make_serializable(obj):
@@ -433,7 +434,18 @@ def make_serializable(obj):
         elif isinstance(obj, dict):
             return {key: make_serializable(value) for key, value in obj.items()}
         else:
-            return str(obj)  # Convert non-serializable objects to strings
+            return str(obj)
+
+    qnn_config_path = (
+        (qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json") if enable_qnn else None
+    )
+    yaml_file_path = os.path.join(os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME), "sdk.yaml")
+    yaml_data = {}
+    try:
+        with open(yaml_file_path, "r") as file:
+            yaml_data = yaml.safe_load(file)
+    except Exception:
+        yaml_data = None

     configs = {
         "huggingface_config": make_serializable(huggingface_config),
@@ -444,7 +456,7 @@ def make_serializable(obj):
"onnx_path": onnx_path,
},
"compilation_config": {
"apps_sdk_version": version,
"apps_sdk_version": qaic_version,
"compile_dir": compile_dir,
"specializtions_file_path": specializations_file_path,
"prefill_seq_len": prefill_seq_len,
@@ -457,8 +469,15 @@
"mxint8_kv_cache": mxint8_kv_cache,
"num_speculative_tokens": num_speculative_tokens,
},
"qnn_config": {
"enable_qnn": enable_qnn,
"qnn_config_path": qnn_config_path,
},
},
}
# Dump the configs dictionary to a JSON file

if yaml_data:
configs["qpc_config"]["qnn_config"].update(yaml_data)

with open(config_file_path, "w") as file:
json.dump(configs, file, indent=4)
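To make the serialization step concrete, here is a self-contained sketch of how a make_serializable-style helper flattens a Hugging Face config before the dump. The branches above elif isinstance(obj, dict) are collapsed in this diff, so the primitive and list handling below is an assumption; only the dict and fallback branches are taken from the commit:

import json

def make_serializable(obj):
    # JSON-native values pass through unchanged (assumed; this branch is
    # hidden behind the collapsed region of the diff).
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    # Lists and tuples are converted element-wise (assumed, as above).
    elif isinstance(obj, (list, tuple)):
        return [make_serializable(item) for item in obj]
    # Visible in the diff: dicts recurse, everything else becomes a string.
    elif isinstance(obj, dict):
        return {key: make_serializable(value) for key, value in obj.items()}
    else:
        return str(obj)

# Toy stand-in for self.model.config.__dict__
huggingface_config = {
    "model_type": "llama",
    "torch_dtype": float,  # non-JSON type; becomes "<class 'float'>"
    "architectures": ["LlamaForCausalLM"],
}

print(json.dumps({"huggingface_config": make_serializable(huggingface_config)}, indent=4))

One caveat in the committed code: os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME) returns None when the variable is unset, and the os.path.join runs before the try block, so that case raises TypeError instead of falling back to yaml_data = None.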
17 changes: 10 additions & 7 deletions tests/qnn_tests/test_causal_lm_models_qnn.py
@@ -5,6 +5,8 @@
 #
 # -----------------------------------------------------------------------------

+import os
+
 import numpy as np
 import pytest
 from transformers import AutoModelForCausalLM
@@ -86,9 +88,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(

pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)

assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
)
assert (
pytorch_hf_tokens == pytorch_kv_tokens
).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output"

onnx_model_path = qeff_model.export()
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path)
@@ -98,20 +100,21 @@
     if not get_available_device_id():
         pytest.skip("No available devices to run model on Cloud AI 100")

-    _ = qeff_model.compile(
+    qpc_path = qeff_model.compile(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
         mxfp6=False,
         aic_enable_depth_first=False,
         enable_qnn=True,
     )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
     gen_len = ort_tokens.shape[-1]
-    assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
-        "Tokens don't match for ONNXRT output and Cloud AI 100 output."
-    )
+    assert (
+        ort_tokens == cloud_ai_100_tokens[:, :gen_len]
+    ).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output."

     # testing for CB models
     model_hf, _ = load_causal_lm_model(model_config)
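Beyond the test harness, the qpc_path return value captured here makes the dumped config easy to locate from user code. A rough usage sketch, assuming a small causal LM (the model name is illustrative):

import os

from QEfficient import QEFFAutoModelForCausalLM

qeff_model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model
qpc_path = qeff_model.compile(
    prefill_seq_len=32,
    ctx_len=128,
    num_cores=14,
    enable_qnn=True,  # requires a QNN SDK; drop this for the default flow
)

# The compile-time configuration is dumped next to the QPC, as the new test asserts.
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))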
22 changes: 12 additions & 10 deletions tests/transformers/models/test_causal_lm_models.py
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------

+import os
 from typing import Optional

 import numpy as np
@@ -110,9 +111,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(

pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)

assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
)
assert (
pytorch_hf_tokens == pytorch_kv_tokens
).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output"

onnx_model_path = qeff_model.export()
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
@@ -122,20 +123,21 @@
     if not get_available_device_id():
         pytest.skip("No available devices to run model on Cloud AI 100")

-    _ = qeff_model.compile(
+    qpc_path = qeff_model.compile(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
         mxfp6=False,
         aic_enable_depth_first=False,
         num_speculative_tokens=num_speculative_tokens,
     )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
     gen_len = ort_tokens.shape[-1]
-    assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
-        "Tokens don't match for ONNXRT output and Cloud AI 100 output."
-    )
+    assert (
+        ort_tokens == cloud_ai_100_tokens[:, :gen_len]
+    ).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output."

     # testing for CB models
     model_hf, _ = load_causal_lm_model(model_config)
@@ -204,9 +206,9 @@ def test_causal_lm_export_with_deprecated_api(model_name):
     new_api_ort_tokens = api_runner.run_kv_model_on_ort(new_api_onnx_model_path)
     old_api_ort_tokens = api_runner.run_kv_model_on_ort(old_api_onnx_model_path)

-    assert (new_api_ort_tokens == old_api_ort_tokens).all(), (
-        "New API output does not match old API output for ONNX export function"
-    )
+    assert (
+        new_api_ort_tokens == old_api_ort_tokens
+    ).all(), "New API output does not match old API output for ONNX export function"


 @pytest.mark.on_qaic
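Both test files end with the same verification pattern: trim the on-device output to the ONNX Runtime reference length, then compare token ids. Condensed, with made-up arrays standing in for run_kv_model_on_ort(...) and exec_info.generated_ids[0]:

import numpy as np

ort_tokens = np.array([[101, 2023, 2003, 102]])
cloud_ai_100_tokens = np.array([[101, 2023, 2003, 102, 0, 0]])  # device output may run longer

gen_len = ort_tokens.shape[-1]
assert (
    ort_tokens == cloud_ai_100_tokens[:, :gen_len]
).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output."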
