Merge pull request #85 from Anindyadeep/anindya/pytorch
Transformers PyTorch benchmark.
nsosio authored Dec 8, 2023
2 parents fb2bfcb + a7c068d commit 6401851
Showing 5 changed files with 358 additions and 27 deletions.
148 changes: 148 additions & 0 deletions bench_pytorch/bench.py
@@ -0,0 +1,148 @@
import argparse
import logging
import sys
import time
from collections import defaultdict
from typing import Optional

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.getLogger("transformers").setLevel(logging.ERROR)
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)


class LlamaPyTorchBenchmark:
def __init__(
self, model_path: str, precision: str, device: Optional[str] = "cpu"
) -> None:
self.model_path = model_path
self.precision = precision
self.results = []
self.precision_to_dtype_map = {
"fp16": torch.float16,
"fp32": torch.float32,
"bf16": torch.bfloat16,
}

        # Validate configuration early: reject unsupported precision/device combos.
        # (Plain asserts with a ValueError as the message never raise that error
        # and vanish under `python -O`, so raise explicitly instead.)
        if precision not in ("fp16", "fp32", "bf16"):
            raise ValueError("Supported precisions are: 'fp16', 'fp32', 'bf16'")
        if device not in ("cpu", "cuda", "metal"):
            raise ValueError("Supported devices are: 'cpu', 'cuda', 'metal'")

if device == "cpu" and precision != "fp32":
raise ValueError(
"When device is set to CPU, fp32 is the only supported precision."
)

        # PyTorch exposes Apple's Metal backend as "mps", so map "metal" to it.
        self.device = (
            "cuda:0" if device == "cuda" else ("mps" if device == "metal" else device)
        )
# build the params
self.model_args = {
"device_map": self.device,
"torch_dtype": self.precision_to_dtype_map[self.precision],
}

def load_model(self):
"""Loads the model into various formats and device."""
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, **self.model_args
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
return self

def run_model(self, prompt: str, max_tokens: int) -> float:
tokenized_input = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.device
)
start = time.time()
output = (
self.model.generate(input_ids=tokenized_input, max_new_tokens=max_tokens)
.detach()
.cpu()
.numpy()
)
        delta = time.time() - start
        # generate() echoes the prompt in its output; count only the new tokens
        # so the reported figure is generation throughput (tokens/second).
        generated_tokens = len(output[0]) - tokenized_input.shape[1]
        return generated_tokens / delta

def benchmark(self, prompt: str, max_tokens: int, repetitions: int) -> None:
for i in range(repetitions):
logging.info(
f"Running repetition [{str(i+1).zfill(len(str(repetitions)))}/{repetitions}]"
)
tokens_per_second = self.run_model(prompt, max_tokens)
self.results.append(tokens_per_second)
        del self.model
        # self.device is "cuda:0" on GPU, so match the prefix, not the exact string.
        if self.device.startswith("cuda"):
            torch.cuda.synchronize()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Transformers PyTorch Benchmark.")
parser.add_argument(
"--prompt",
type=str,
help="The prompt for the model.",
)
parser.add_argument("--max_tokens", type=int, help="The maximum number of tokens.")
parser.add_argument(
"--repetitions",
type=int,
help="The number of repetitions for the benchmark.",
)
parser.add_argument(
"--device",
help="Device to use for the benchmark.",
)
parser.add_argument(
"--log_file",
type=str,
help="Path to the log file for writing logs (in append mode).",
)
parser.add_argument(
"--models_dir",
type=str,
help="Path to the models directory.",
)
args = parser.parse_args()
logging.info(
f"Running benchmark with: max_tokens={args.max_tokens} prompt={args.prompt} "
+ f"repetitions={args.repetitions} device={args.device}"
)
    report = defaultdict(dict)

for precision in ("fp16", "fp32") if args.device != "cpu" else ("fp32",):
logging.info(
f"Running Transformer benchmark (pytorch backend) on Llama with precision: {precision}"
)
llama_transformers_pytorch_benchmark = LlamaPyTorchBenchmark(
model_path=f"{args.models_dir}/llama-2-7b-hf",
device=args.device,
precision=precision,
).load_model()
llama_transformers_pytorch_benchmark.benchmark(
max_tokens=args.max_tokens, prompt=args.prompt, repetitions=args.repetitions
)

report["llama_transformers_pytorch"][precision] = {
"mean": np.mean(llama_transformers_pytorch_benchmark.results),
"std": np.mean(llama_transformers_pytorch_benchmark.results),
}
logging.info("Benchmark Report")
with open(args.log_file, "a") as file:
for framework, quantizations in report.items():
for quantization, stats in quantizations.items():
logging.info(
f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}"
)
print(
f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}",
file=file,
)
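
For reference, a direct invocation of bench.py (bypassing the wrapper script below) might look like the following sketch. The model location under ./models/llama-2-7b-hf is inferred from the code above, and all flag values are illustrative, not prescribed by this diff.

# Illustrative direct run; assumes the virtualenv from setup.sh is active
# and ./models/llama-2-7b-hf holds HF-format Llama 2 7B weights.
python bench_pytorch/bench.py \
    --prompt "Explain what is a transformer" \
    --max_tokens 100 \
    --repetitions 10 \
    --device cuda \
    --log_file benchmark.log \
    --models_dir ./models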
151 changes: 151 additions & 0 deletions bench_pytorch/bench.sh
@@ -0,0 +1,151 @@
#!/bin/bash

########################################################################################################
# Script: bench.sh
# Description: This script runs the Transformers (PyTorch backend) Llama benchmark.
#
# Usage: ./bench.sh [OPTIONS]
# OPTIONS:
# -p, --prompt Prompt for benchmarks (default: 'Explain what is a transformer')
# -r, --repetitions Number of repetitions for benchmarks (default: 10)
# -m, --max_tokens Maximum number of tokens for benchmarks (default: 100)
# -d, --device Device for benchmarks (possible values: 'cuda', 'metal', and 'cpu', default: 'cpu')
# -lf, --log_file Logging file name.
# -md, --models_dir Models directory.
# -h, --help Show this help message
########################################################################################################

set -euo pipefail

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

print_usage() {
echo "Usage: $0 [OPTIONS]"
echo "OPTIONS:"
echo " -p, --prompt Prompt for benchmarks (default: 'Explain what is a transformer')"
echo " -r, --repetitions Number of repetitions for benchmarks (default: 2)"
echo " -m, --max_tokens Maximum number of tokens for benchmarks (default: 100)"
echo " -d, --device Device for benchmarks (possible values: 'metal', 'gpu', and 'cpu', default: 'cpu')"
echo " -lf, --log_file Logging file name."
echo " -md, --models_dir Models directory."
echo " -h, --help Show this help message"
exit 1
}

check_cuda() {
if command -v nvcc &> /dev/null
then
echo -e "\nUsing CUDA"
nvcc --version
else
echo -e "\nCUDA is not available."
exit 1
fi
}

check_platform() {
local platform
platform=$(uname -s)
if [[ "$platform" == "Linux" ]]; then
echo "Running on Linux."
elif [[ "$platform" == "Darwin" ]]; then
echo "Running on Mac OS."
else
echo "Unknown platform."
exit 1
fi
}

check_python() {
if command -v python &> /dev/null
then
echo -e "\nUsing $(python --version)."
else
echo -e "\nPython does not exist."
exit 1
fi
}

setup() {
echo -e "\nSetting up with $SCRIPT_DIR/setup.sh..."
bash "$SCRIPT_DIR"/setup.sh
}

run_benchmarks() {
local PROMPT="$1"
local REPETITIONS="$2"
local MAX_TOKENS="$3"
local DEVICE="$4"
local LOG_FILENAME="$5"
local MODELS_DIR="$6"

# shellcheck disable=SC1091
source "$SCRIPT_DIR/venv/bin/activate"
python "$SCRIPT_DIR"/bench.py \
--prompt "$PROMPT" \
--repetitions "$REPETITIONS" \
--max_tokens "$MAX_TOKENS" \
--log_file "$LOG_FILENAME" \
--models_dir "$MODELS_DIR" \
--device "$DEVICE"
}

# Parse command-line arguments
while [ "$#" -gt 0 ]; do
case "$1" in
-p|--prompt)
PROMPT="$2"
shift 2
;;
-r|--repetitions)
REPETITIONS="$2"
shift 2
;;
-m|--max_tokens)
MAX_TOKENS="$2"
shift 2
;;
-d|--device)
DEVICE="$2"
case "$DEVICE" in
"cuda" | "metal" | "cpu")
;;
*)
echo "Invalid value for --device. Please use 'cuda', 'gpu' or 'cpu'."
print_usage
;;
esac
if [ "$DEVICE" == "cuda" ]; then
check_cuda
fi
shift 2
;;
-lf|--log_file)
LOG_FILENAME="$2"
shift 2
;;
-md|--models_dir)
MODELS_DIR="$2"
shift 2
;;
-h|--help)
print_usage
;;
*)
echo "Unknown option: $1"
print_usage
;;
esac
done
# Set default values if not provided
PROMPT="${PROMPT:-"Explain what is a transformer"}"
REPETITIONS="${REPETITIONS:-10}"
MAX_TOKENS="${MAX_TOKENS:-100}"
DEVICE="${DEVICE:-cpu}"
LOG_FILENAME="${LOG_FILENAME:-"benchmark_$(date +'%Y%m%d%H%M%S').log"}"
MODELS_DIR="${MODELS_DIR:-"./models"}"

check_platform
check_python
setup
run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$LOG_FILENAME" "$MODELS_DIR"
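
A minimal end-to-end run through the wrapper might look like this sketch; the values shown are illustrative, and any omitted flag falls back to the defaults set above.

# Illustrative wrapper invocation using only flags defined in this script.
./bench_pytorch/bench.sh \
    --prompt "Explain what is a transformer" \
    --repetitions 10 \
    --max_tokens 100 \
    --device cuda \
    --log_file benchmark.log \
    --models_dir ./models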
3 changes: 3 additions & 0 deletions bench_pytorch/requirements.txt
@@ -0,0 +1,3 @@
transformers==4.34.1
torch==2.0.0
accelerate
25 changes: 25 additions & 0 deletions bench_pytorch/setup.sh
@@ -0,0 +1,25 @@
#!/bin/bash

################################################################################
# Script: setup.sh
# Description: Automates the setup of a virtual environment and installs project
# requirements.
################################################################################

set -euo pipefail

# Main script starts here.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV_DIR="$SCRIPT_DIR/venv"

if [ ! -d "$VENV_DIR" ]; then
python -m venv "$VENV_DIR"
echo "Virtual environment '$VENV_DIR' created."
# shellcheck disable=SC1091
source "$VENV_DIR/bin/activate"
pip install --upgrade pip > /dev/null
pip install -r "$SCRIPT_DIR/requirements.txt" --no-cache-dir > /dev/null
else
# shellcheck disable=SC1091
source "$VENV_DIR/bin/activate"
fi
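
If you want to prepare the environment without running a benchmark, the setup script can also be invoked on its own (path assumed relative to the repository root):

# Hypothetical standalone setup; creates bench_pytorch/venv and installs
# the pinned requirements into it.
bash bench_pytorch/setup.sh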