Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Developer Applications Demo using Transformers Library #10

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 56 additions & 17 deletions QEfficient/cloud/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def onnx_exists(onnx_file_path: str) -> bool:
)


def main(
def infer_api(
model_name: str,
num_cores: int,
prompt: str,
prompt: str = Constants.input_str,
aic_enable_depth_first: bool = False,
mos: int = -1,
cache_dir: str = Constants.CACHE_DIR,
Expand All @@ -60,7 +60,8 @@ def main(
device_group: List[int] = [
0,
],
) -> None:
skip_stats : bool = False,
hupreti marked this conversation as resolved.
Show resolved Hide resolved
) -> str:
# Make
model_card_dir = os.path.join(QEFF_MODELS_DIR, str(model_name))
os.makedirs(model_card_dir, exist_ok=True)
Expand All @@ -76,21 +77,24 @@ def main(
onnx_dir_path = os.path.join(model_card_dir, "onnx")
onnx_model_path = os.path.join(onnx_dir_path, model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")

# skip model download if qpc exits and we do not need stats
if not qpc_exists(qpc_dir_path) or not skip_stats:
# Get tokenizer
if hf_token is not None:
login(hf_token)
model_hf_path = hf_download(
repo_id=model_name,
cache_dir=cache_dir,
ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf"],
)
tokenizer = AutoTokenizer.from_pretrained(model_hf_path, use_cache=True, padding_side="left")
if hf_token is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code (line 83-90) can also be moved under main function.
and pass model_hf_path as input to infer_api, and return qpc_path as output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why should user downloads hf_model files from model_card and provide you model_hf_path ? please keep in mind that this is high_level api for user

login(hf_token)
model_hf_path = hf_download(
repo_id=model_name,
cache_dir=cache_dir,
ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf"],
)
tokenizer = AutoTokenizer.from_pretrained(model_hf_path, use_cache=True, padding_side="left")

if qpc_exists(qpc_dir_path):
# execute
logger.info("Pre-compiled qpc found! Trying to execute with given prompt")
latency_stats_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
return
if not skip_stats:
latency_stats_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
return qpc_dir_path

if onnx_exists(onnx_model_path):
# Compile -> execute
Expand All @@ -110,8 +114,9 @@ def main(
assert (
generated_qpc_path == qpc_dir_path
), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
return
if not skip_stats:
latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
return generated_qpc_path

#############################################
# hf model -> export -> compile -> execute
Expand Down Expand Up @@ -156,9 +161,43 @@ def main(
), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
logger.info(f"Compiled qpc files can be found at : {generated_qpc_path}")

# Execute
latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
if not skip_stats:
latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)

return generated_qpc_path

def main(
model_name: str,
num_cores: int,
prompt: str,
aic_enable_depth_first: bool = False,
mos: int = -1,
cache_dir: str = Constants.CACHE_DIR,
hf_token: str = None,
batch_size: int = 1,
prompt_len: int = 32,
ctx_len: int = 128,
mxfp6: bool = False,
device_group: List[int] = [
0,
],
) -> None:
_ = infer_api(
model_name=model_name,
num_cores=num_cores,
prompt=prompt,
aic_enable_depth_first=aic_enable_depth_first,
mos=mos,
cache_dir=cache_dir,
hf_token=hf_token,
batch_size=batch_size,
prompt_len=prompt_len,
ctx_len=ctx_len,
mxfp6=mxfp6,
device_group=device_group
)

return

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down
Loading