Skip to content

Commit

Permalink
Lint & Format fixed II
Browse files Browse the repository at this point in the history
Signed-off-by: Abukhoyer Shaik <[email protected]>
  • Loading branch information
abukhoy committed Jan 16, 2025
1 parent 84a551a commit 8ffbba1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
12 changes: 6 additions & 6 deletions tests/qnn_tests/test_causal_lm_models_qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(

pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)

assert (
pytorch_hf_tokens == pytorch_kv_tokens
).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
)

onnx_model_path = qeff_model.export()
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path)
Expand All @@ -112,9 +112,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size
gen_len = ort_tokens.shape[-1]
assert (
ort_tokens == cloud_ai_100_tokens[:, :gen_len]
).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output."
assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
"Tokens don't match for ONNXRT output and Cloud AI 100 output."
)

# testing for CB models
model_hf, _ = load_causal_lm_model(model_config)
Expand Down
18 changes: 9 additions & 9 deletions tests/transformers/models/test_causal_lm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(

pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)

assert (
pytorch_hf_tokens == pytorch_kv_tokens
).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
)

onnx_model_path = qeff_model.export()
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
Expand All @@ -135,9 +135,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size
gen_len = ort_tokens.shape[-1]
assert (
ort_tokens == cloud_ai_100_tokens[:, :gen_len]
).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output."
assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
"Tokens don't match for ONNXRT output and Cloud AI 100 output."
)

# testing for CB models
model_hf, _ = load_causal_lm_model(model_config)
Expand Down Expand Up @@ -206,9 +206,9 @@ def test_causal_lm_export_with_deprecated_api(model_name):
new_api_ort_tokens = api_runner.run_kv_model_on_ort(new_api_onnx_model_path)
old_api_ort_tokens = api_runner.run_kv_model_on_ort(old_api_onnx_model_path)

assert (
new_api_ort_tokens == old_api_ort_tokens
).all(), "New API output does not match old API output for ONNX export function"
assert (new_api_ort_tokens == old_api_ort_tokens).all(), (
"New API output does not match old API output for ONNX export function"
)


@pytest.mark.on_qaic
Expand Down

0 comments on commit 8ffbba1

Please sign in to comment.