Merge pull request #59 from nsosio/feat/onnx-runtime

Added ONNX Runtime

nsosio authored Nov 23, 2023
2 parents d1520a8 + c653ef9 commit 95ddca6
Showing 5 changed files with 323 additions and 4 deletions.
11 changes: 7 additions & 4 deletions README.md.template
@@ -71,7 +71,7 @@ Model: LLAMA-2-7B

CUDA Version: 11.7

- Command: `./benchmark.sh --repetitions 10 --max_tokens 100 --device gpu --nvidia --prompt 'Explain what is a transformer'`
+ Command: `./benchmark.sh --repetitions 10 --max_tokens 100 --device cuda --prompt 'Explain what is a transformer'`

| Engine | float32 | float16 | int8 | int4 |
|-------------|--------------|---------------|---------------|---------------|
@@ -80,6 +80,7 @@ Command: `./benchmark.sh --repetitions 10 --max_tokens 100 --device gpu --nvidia
| llama.cpp | - | - | 84.48 ± 3.76 | 106.76 ± 1.29 |
| ctranslate | - | 51.38 ± 16.01 | 36.12 ± 11.93 | - |
| tinygrad | - | 20.32 ± 0.06 | - | - |
+ | onnx        | -            | 54.16 ± 3.15  | -             | -             |

*(data updated: <LAST_UPDATE>)

@@ -94,24 +95,26 @@ CUDA Version: NA

Command: `./benchmark.sh --repetitions 10 --max_tokens 100 --device cpu --prompt 'Explain what is a transformer'`

- | Engine      | float32      | float16       | int8          | int4          |
+ | Engine      | float32      | float16      | int8         | int4         |
|-------------|--------------|--------------|--------------|--------------|
| burn | 0.30 ± 0.09 | - | - | - |
| candle | - | 3.43 ± 0.02 | - | - |
| llama.cpp | - | - | 14.41 ± 1.59 | 20.96 ± 1.94 |
| ctranslate | - | - | 2.11 ± 0.73 | - |
| tinygrad | - | 4.21 ± 0.38 | - | - |
+ | onnx        | -            | -            | -            | -            |

#### GPU (Metal)

- Command: `./benchmark.sh --repetitions 10 --max_tokens 100 --device gpu --prompt 'Explain what is a transformer'`
+ Command: `./benchmark.sh --repetitions 10 --max_tokens 100 --device metal --prompt 'Explain what is a transformer'`

- | Engine      | float32      | float16       | int8          | int4          |
+ | Engine      | float32      | float16      | int8         | int4         |
|-------------|--------------|--------------|--------------|--------------|
| burn | - | - | - | - |
| candle | - | - | - | - |
| llama.cpp | - | - | 31.24 ± 7.82 | 46.75 ± 9.55 |
| ctranslate | - | - | - | - |
| tinygrad | - | 29.78 ± 1.18 | - | - |
+ | onnx        | -            | -            | -            | -            |

*(data updated: <LAST_UPDATE>)
111 changes: 111 additions & 0 deletions bench_onnxruntime/bench.py
@@ -0,0 +1,111 @@
import argparse
import logging
import sys
import time
from collections import defaultdict

import numpy as np
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


class ONNXBenchmark:
    def __init__(self, model_path, device="cpu"):
        self.model_path = model_path
        self.device = device
        # Map the benchmark device to the matching ONNX Runtime execution provider.
        self.provider = (
            "CUDAExecutionProvider" if device == "cuda" else "CPUExecutionProvider"
        )
        self.results = []

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = ORTModelForCausalLM.from_pretrained(
            self.model_path,
            use_cache=False,
            use_io_binding=False,
            provider=self.provider,
        )
        return self

    def run_model(self, prompt, max_tokens) -> float:
        device_str = "cuda" if self.device == "cuda" else "cpu"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(device_str)
        start = time.time()
        # Note: max_length caps prompt plus generated tokens combined.
        gen_tokens = self.model.generate(**inputs, max_length=max_tokens)
        # Tokens per second, counting only the newly generated tokens.
        tokens_per_second = (gen_tokens.shape[1] - inputs["input_ids"].shape[1]) / (
            time.time() - start
        )
        return tokens_per_second

    def benchmark(self, prompt, max_tokens, repetitions):
        for i in range(repetitions):
            logging.info(
                f"Running repetition [{str(i+1).zfill(len(str(repetitions)))}/{repetitions}]"
            )
            tokens_per_second = self.run_model(prompt, max_tokens)
            self.results.append(tokens_per_second)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="ONNX Runtime benchmark for the Llama model."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="The prompt for the model.",
    )
    parser.add_argument("--max_tokens", type=int, help="The maximum number of tokens.")
    parser.add_argument(
        "--repetitions",
        type=int,
        help="The number of repetitions for the benchmark.",
    )
    parser.add_argument(
        "--device",
        help="Device to use for the benchmark.",
    )
    parser.add_argument(
        "--log_file",
        type=str,
        help="Path to the log file for writing logs (in append mode).",
    )
    parser.add_argument(
        "--models_dir",
        type=str,
        help="Path to the models directory.",
    )
    args = parser.parse_args()
    logging.info(
        f"Running benchmark with: max_tokens={args.max_tokens} prompt={args.prompt} "
        + f"repetitions={args.repetitions} device={args.device}"
    )
    report = defaultdict(lambda: defaultdict(float))
    onnx_bench = ONNXBenchmark(
        f"{args.models_dir}/llama-2-7b-onnx",
        device=args.device,
    ).load_model()
    onnx_bench.benchmark(
        max_tokens=args.max_tokens, prompt=args.prompt, repetitions=args.repetitions
    )
    # The export in setup.sh is fp16, so results are filed under float16.
    report["onnx"]["float16"] = {
        "mean": np.mean(onnx_bench.results),
        "std": np.std(onnx_bench.results),
    }

    logging.info("Benchmark report")
    with open(args.log_file, "a") as file:
        for framework, quantizations in report.items():
            for quantization, stats in quantizations.items():
                logging.info(
                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}"
                )
                print(
                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}",
                    file=file,
                )
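
For reference, the new script can also be invoked directly, bypassing the wrapper below. A minimal sketch, assuming the ONNX export already exists under ./models/llama-2-7b-onnx (setup.sh creates it) and with an arbitrary log file name:

    python bench_onnxruntime/bench.py \
        --prompt 'Explain what is a transformer' \
        --max_tokens 100 \
        --repetitions 10 \
        --device cuda \
        --log_file benchmark.log \
        --models_dir ./models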
161 changes: 161 additions & 0 deletions bench_onnxruntime/bench.sh
@@ -0,0 +1,161 @@
#!/bin/bash

########################################################################################################
# Script: bench.sh
# Description: This script runs the ONNX Runtime Llama benchmark.
#
# Usage: ./bench.sh [OPTIONS]
# OPTIONS:
#   -p, --prompt      Prompt for benchmarks (default: 'Explain what is a transformer')
#   -r, --repetitions Number of repetitions for benchmarks (default: 10)
#   -m, --max_tokens  Maximum number of tokens for benchmarks (default: 100)
#   -d, --device      Device for benchmarks (possible values: 'cuda', 'metal', and 'cpu', default: 'cpu')
#   -lf, --log_file   Logging file name.
#   -md, --models_dir Models directory.
#   -h, --help        Show this help message
########################################################################################################

set -euo pipefail

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

print_usage() {
    echo "Usage: $0 [OPTIONS]"
    echo "OPTIONS:"
    echo "  -p, --prompt      Prompt for benchmarks (default: 'Explain what is a transformer')"
    echo "  -r, --repetitions Number of repetitions for benchmarks (default: 10)"
    echo "  -m, --max_tokens  Maximum number of tokens for benchmarks (default: 100)"
    echo "  -d, --device      Device for benchmarks (possible values: 'cuda', 'metal', and 'cpu', default: 'cpu')"
    echo "  -lf, --log_file   Logging file name."
    echo "  -md, --models_dir Models directory."
    echo "  -h, --help        Show this help message"
    exit 1
}

check_cuda() {
    if command -v nvcc &> /dev/null
    then
        echo -e "\nUsing CUDA"
        nvcc --version
    else
        echo -e "\nCUDA is not available."
        exit 1
    fi
}

check_platform() {
    local platform
    platform=$(uname -s)
    if [[ "$platform" == "Linux" ]]; then
        echo "Running on Linux."
    elif [[ "$platform" == "Darwin" ]]; then
        echo "Running on Mac OS."
    else
        echo "Unknown platform."
        exit 1
    fi
}

check_python() {
    if command -v python &> /dev/null
    then
        echo -e "\nUsing $(python --version)."
    else
        echo -e "\nPython does not exist."
        exit 1
    fi
}

setup() {
    echo -e "\nSetting up with $SCRIPT_DIR/setup.sh..."
    bash "$SCRIPT_DIR"/setup.sh "$1"
}

run_benchmarks() {
    local PROMPT="$1"
    local REPETITIONS="$2"
    local MAX_TOKENS="$3"
    local DEVICE="$4"
    local LOG_FILENAME="$5"
    local MODELS_DIR="$6"

    # shellcheck disable=SC1091
    source "$SCRIPT_DIR/venv/bin/activate"

    python "$SCRIPT_DIR"/bench.py \
        --prompt "$PROMPT" \
        --repetitions "$REPETITIONS" \
        --max_tokens "$MAX_TOKENS" \
        --log_file "$LOG_FILENAME" \
        --models_dir "$MODELS_DIR" \
        --device "$DEVICE"
}

# Parse command-line arguments
while [ "$#" -gt 0 ]; do
    case "$1" in
        -p|--prompt)
            PROMPT="$2"
            shift 2
            ;;
        -r|--repetitions)
            REPETITIONS="$2"
            shift 2
            ;;
        -m|--max_tokens)
            MAX_TOKENS="$2"
            shift 2
            ;;
        -d|--device)
            DEVICE="$2"
            case "$DEVICE" in
                "cuda" | "metal" | "cpu")
                    ;;
                *)
                    echo "Invalid value for --device. Please use 'cuda', 'metal' or 'cpu'."
                    print_usage
                    ;;
            esac
            if [ "$DEVICE" == "cuda" ]; then
                check_cuda
            fi
            # Only the CUDA path actually runs the benchmark; the other
            # devices exit early.
            if [ "$DEVICE" == "metal" ]; then
                echo "Metal not supported!"
                exit 0
            fi
            if [ "$DEVICE" == "cpu" ]; then
                echo "cpu not supported!"
                exit 0
            fi
            shift 2
            ;;
        -lf|--log_file)
            LOG_FILENAME="$2"
            shift 2
            ;;
        -md|--models_dir)
            MODELS_DIR="$2"
            shift 2
            ;;
        -h|--help)
            print_usage
            ;;
        *)
            echo "Unknown option: $1"
            print_usage
            ;;
    esac
done

# Set default values if not provided
PROMPT="${PROMPT:-"Explain what is a transformer"}"
REPETITIONS="${REPETITIONS:-10}"
MAX_TOKENS="${MAX_TOKENS:-100}"
DEVICE="${DEVICE:-cpu}"
LOG_FILENAME="${LOG_FILENAME:-"benchmark_$(date +'%Y%m%d%H%M%S').log"}"
MODELS_DIR="${MODELS_DIR:-"./models"}"

check_platform
check_python
setup "$MODELS_DIR"
run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$LOG_FILENAME" "$MODELS_DIR"
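
As a usage sketch, the wrapper takes the same options the header documents:

    ./bench_onnxruntime/bench.sh --device cuda --repetitions 10 --max_tokens 100 --prompt 'Explain what is a transformer'

Only the 'cuda' path reaches bench.py; 'metal' and 'cpu' exit early, which is consistent with the README tables where onnx reports a number only in the CUDA float16 column.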
2 changes: 2 additions & 0 deletions bench_onnxruntime/requirements.txt
@@ -0,0 +1,2 @@
torch --index-url https://download.pytorch.org/whl/cu116
optimum[onnxruntime-gpu]==1.14
42 changes: 42 additions & 0 deletions bench_onnxruntime/setup.sh
@@ -0,0 +1,42 @@
#!/bin/bash

################################################################################
# Script: setup.sh <MODELS_FOLDER>
# Description: Automates the setup of a virtual environment, installs project
# requirements, and handles model conversion.
################################################################################

set -euo pipefail

if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <models_folder>"
    exit 1
fi

# Define directory paths
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV_DIR="$SCRIPT_DIR/venv"
MODELS_FOLDER="$1"
LLAMA_HF_MODEL_DIR="$MODELS_FOLDER/llama-2-7b-hf"
LLAMA_ONNX_MODEL_DIR="$MODELS_FOLDER/llama-2-7b-onnx"

if [ ! -d "$VENV_DIR" ]; then
    python -m venv "$VENV_DIR"
    echo "Virtual environment '$VENV_DIR' created."
    # shellcheck disable=SC1091
    source "$VENV_DIR/bin/activate"
    pip install --upgrade pip > /dev/null
    pip install -r "$SCRIPT_DIR"/requirements.txt > /dev/null
else
    # shellcheck disable=SC1091
    source "$VENV_DIR/bin/activate"
fi

# Check and create llama-2-7b-onnx model
if [ ! -d "$LLAMA_ONNX_MODEL_DIR" ]; then
    optimum-cli export onnx \
        --model "$LLAMA_HF_MODEL_DIR" --task text-generation --framework pt \
        --opset 17 --sequence_length 1024 --batch_size 1 --device cuda --fp16 \
        "$LLAMA_ONNX_MODEL_DIR" > /dev/null
else
    echo "Model llama-2-7b-onnx already exists!"
fi
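
The export passes --fp16 together with --device cuda, so the exported graph is half-precision, which is why bench.py files its results under the float16 column. If needed, the setup step can be run on its own:

    bash bench_onnxruntime/setup.sh ./models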
