Merge pull request #59 from nsosio/feat/onnx-runtime
Added ONNX Runtime
Showing 5 changed files with 323 additions and 4 deletions.
bench.py
@@ -0,0 +1,111 @@
import argparse
import logging
import sys
import time
from collections import defaultdict

import numpy as np
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


class ONNXBenchmark:
    def __init__(self, model_path, device="cpu"):
        self.model_path = model_path
        self.device = device
        # Pick the ONNX Runtime execution provider matching the target device.
        self.provider = (
            "CUDAExecutionProvider" if device == "cuda" else "CPUExecutionProvider"
        )
        self.results = []

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = ORTModelForCausalLM.from_pretrained(
            self.model_path,
            use_cache=False,
            use_io_binding=False,
            provider=self.provider,
        )
        return self

    def run_model(self, prompt, max_tokens) -> float:
        device_str = "cuda" if self.device == "cuda" else "cpu"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(device_str)
        start = time.time()
        gen_tokens = self.model.generate(**inputs, max_length=max_tokens)
        # max_length counts the prompt, so subtract the prompt length to get
        # the number of newly generated tokens.
        tokens_per_second = (gen_tokens.shape[1] - inputs["input_ids"].shape[1]) / (
            time.time() - start
        )
        return tokens_per_second

    def benchmark(self, prompt, max_tokens, repetitions):
        for i in range(repetitions):
            logging.info(
                f"Running repetition [{str(i+1).zfill(len(str(repetitions)))}/{repetitions}]"
            )
            tokens_per_second = self.run_model(prompt, max_tokens)
            self.results.append(tokens_per_second)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="ONNX Runtime Benchmark for Llama model."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="The prompt for the model.",
    )
    parser.add_argument("--max_tokens", type=int, help="The maximum number of tokens.")
    parser.add_argument(
        "--repetitions",
        type=int,
        help="The number of repetitions for the benchmark.",
    )
    parser.add_argument(
        "--device",
        help="Device to use for the benchmark.",
    )
    parser.add_argument(
        "--log_file",
        type=str,
        help="Path to the log file for writing logs (in append mode).",
    )
    parser.add_argument(
        "--models_dir",
        type=str,
        help="Path to the models directory.",
    )
    args = parser.parse_args()
    logging.info(
        f"Running benchmark with: max_tokens={args.max_tokens} prompt={args.prompt} "
        + f"repetitions={args.repetitions} device={args.device}"
    )
    report = defaultdict(lambda: defaultdict(float))
    onnx_bench = ONNXBenchmark(
        f"{args.models_dir}/llama-2-7b-onnx",
        device=args.device,
    ).load_model()
    onnx_bench.benchmark(
        max_tokens=args.max_tokens, prompt=args.prompt, repetitions=args.repetitions
    )
    report["onnx"]["float16"] = {
        "mean": np.mean(onnx_bench.results),
        "std": np.std(onnx_bench.results),
    }

    logging.info("Benchmark report")
    with open(args.log_file, "a") as file:
        for framework, quantizations in report.items():
            for quantization, stats in quantizations.items():
                logging.info(
                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}"
                )
                print(
                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}",
                    file=file,
                )
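For a quick check outside the wrapper script, bench.py can also be invoked directly once setup.sh has created the virtual environment and the ONNX export. A hypothetical invocation (the prompt, log file name, and models path below are illustrative, not fixed by the script):

source venv/bin/activate
python bench.py \
    --prompt "Explain what is a transformer" \
    --max_tokens 100 \
    --repetitions 10 \
    --device cuda \
    --log_file benchmark.log \
    --models_dir ./models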
bench.sh
@@ -0,0 +1,161 @@
#!/bin/bash

########################################################################################################
# Script: bench.sh
# Description: This script runs the ONNX Runtime Llama benchmark.
#
# Usage: ./bench.sh [OPTIONS]
# OPTIONS:
#   -p, --prompt       Prompt for benchmarks (default: 'Explain what is a transformer')
#   -r, --repetitions  Number of repetitions for benchmarks (default: 10)
#   -m, --max_tokens   Maximum number of tokens for benchmarks (default: 100)
#   -d, --device       Device for benchmarks (possible values: 'cuda', 'metal', and 'cpu', default: 'cpu')
#   -lf, --log_file    Logging file name.
#   -md, --models_dir  Models directory.
#   -h, --help         Show this help message
########################################################################################################

set -euo pipefail

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

print_usage() {
    echo "Usage: $0 [OPTIONS]"
    echo "OPTIONS:"
    echo "  -p, --prompt       Prompt for benchmarks (default: 'Explain what is a transformer')"
    echo "  -r, --repetitions  Number of repetitions for benchmarks (default: 10)"
    echo "  -m, --max_tokens   Maximum number of tokens for benchmarks (default: 100)"
    echo "  -d, --device       Device for benchmarks (possible values: 'cuda', 'metal', and 'cpu', default: 'cpu')"
    echo "  -lf, --log_file    Logging file name."
    echo "  -md, --models_dir  Models directory."
    echo "  -h, --help         Show this help message"
    exit 1
}

check_cuda() {
    if command -v nvcc &> /dev/null
    then
        echo -e "\nUsing CUDA"
        nvcc --version
    else
        echo -e "\nCUDA is not available."
        exit 1
    fi
}

check_platform() {
    local platform
    platform=$(uname -s)
    if [[ "$platform" == "Linux" ]]; then
        echo "Running on Linux."
    elif [[ "$platform" == "Darwin" ]]; then
        echo "Running on Mac OS."
    else
        echo "Unknown platform."
        exit 1
    fi
}

check_python() {
    if command -v python &> /dev/null
    then
        echo -e "\nUsing $(python --version)."
    else
        echo -e "\nPython does not exist."
        exit 1
    fi
}

setup() {
    echo -e "\nSetting up with $SCRIPT_DIR/setup.sh..."
    bash "$SCRIPT_DIR"/setup.sh "$1"
}

run_benchmarks() {
    local PROMPT="$1"
    local REPETITIONS="$2"
    local MAX_TOKENS="$3"
    local DEVICE="$4"
    local LOG_FILENAME="$5"
    local MODELS_DIR="$6"

    # shellcheck disable=SC1091
    source "$SCRIPT_DIR/venv/bin/activate"

    python "$SCRIPT_DIR"/bench.py \
        --prompt "$PROMPT" \
        --repetitions "$REPETITIONS" \
        --max_tokens "$MAX_TOKENS" \
        --log_file "$LOG_FILENAME" \
        --models_dir "$MODELS_DIR" \
        --device "$DEVICE"
}

# Parse command-line arguments
while [ "$#" -gt 0 ]; do
    case "$1" in
        -p|--prompt)
            PROMPT="$2"
            shift 2
            ;;
        -r|--repetitions)
            REPETITIONS="$2"
            shift 2
            ;;
        -m|--max_tokens)
            MAX_TOKENS="$2"
            shift 2
            ;;
        -d|--device)
            DEVICE="$2"
            case "$DEVICE" in
                "cuda" | "metal" | "cpu")
                    ;;
                *)
                    echo "Invalid value for --device. Please use 'cuda', 'metal' or 'cpu'."
                    print_usage
                    ;;
            esac
            if [ "$DEVICE" == "cuda" ]; then
                check_cuda
            fi
            if [ "$DEVICE" == "metal" ]; then
                echo "Metal not supported!"
                exit 0
            fi
            if [ "$DEVICE" == "cpu" ]; then
                echo "cpu not supported!"
                exit 0
            fi
            shift 2
            ;;
        -lf|--log_file)
            LOG_FILENAME="$2"
            shift 2
            ;;
        -md|--models_dir)
            MODELS_DIR="$2"
            shift 2
            ;;
        -h|--help)
            print_usage
            ;;
        *)
            echo "Unknown option: $1"
            print_usage
            ;;
    esac
done
# Set default values if not provided
PROMPT="${PROMPT:-"Explain what is a transformer"}"
REPETITIONS="${REPETITIONS:-10}"
MAX_TOKENS="${MAX_TOKENS:-100}"
DEVICE="${DEVICE:-cpu}"
LOG_FILENAME="${LOG_FILENAME:-"benchmark_$(date +'%Y%m%d%H%M%S').log"}"
MODELS_DIR="${MODELS_DIR:-"./models"}"

check_platform
check_python
setup "$MODELS_DIR"
run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$LOG_FILENAME" "$MODELS_DIR"
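Putting the options together, a typical CUDA run of the wrapper might look like the following sketch (the paths are illustrative; anything omitted falls back to the defaults listed in the header):

./bench.sh --device cuda \
    --repetitions 10 \
    --max_tokens 100 \
    --models_dir ./models \
    --log_file benchmark.log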
requirements.txt
@@ -0,0 +1,2 @@
torch --index-url https://download.pytorch.org/whl/cu116
optimum[onnxruntime-gpu]==1.14
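The torch line pulls wheels from the PyTorch CUDA 11.6 index rather than PyPI, and optimum is pinned with the GPU ONNX Runtime extras. setup.sh installs both automatically; the equivalent manual step inside the activated venv would be roughly:

pip install --upgrade pip
pip install -r requirements.txt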
setup.sh
@@ -0,0 +1,42 @@
#!/bin/bash

################################################################################
# Script: setup.sh <MODELS_FOLDER>
# Description: Automates the setup of a virtual environment, installs project
# requirements, and handles model conversion.
################################################################################

set -euo pipefail

if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <models_folder>"
    exit 1
fi

# Define directory paths
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV_DIR="$SCRIPT_DIR/venv"
MODELS_FOLDER="$1"
LLAMA_HF_MODEL_DIR="$MODELS_FOLDER/llama-2-7b-hf"
LLAMA_ONNX_MODEL_DIR="$MODELS_FOLDER/llama-2-7b-onnx"

if [ ! -d "$VENV_DIR" ]; then
    python -m venv "$VENV_DIR"
    echo "Virtual environment '$VENV_DIR' created."
    # shellcheck disable=SC1091
    source "$VENV_DIR/bin/activate"
    pip install --upgrade pip > /dev/null
    pip install -r "$SCRIPT_DIR"/requirements.txt > /dev/null
else
    # shellcheck disable=SC1091
    source "$VENV_DIR/bin/activate"
fi
# Check and create llama-2-7b-onnx model
if [ ! -d "$LLAMA_ONNX_MODEL_DIR" ]; then
    optimum-cli export onnx \
        --model "$LLAMA_HF_MODEL_DIR" --task text-generation --framework pt \
        --opset 17 --sequence_length 1024 --batch_size 1 --device cuda --fp16 \
        "$LLAMA_ONNX_MODEL_DIR" > /dev/null
else
    echo "Model llama-2-7b-onnx already exists!"
fi
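Assuming the Hugging Face checkpoint already sits at <models_folder>/llama-2-7b-hf, a first run is a single command (the ./models path is illustrative):

bash setup.sh ./models

On the first invocation this creates the venv and exports the FP16 ONNX model to ./models/llama-2-7b-onnx; later runs detect the existing export and skip the conversion.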