Skip to content

Commit

Permalink
Bring OGA under test and fix OGA server. Improve llm-prompt. (#272)
Browse files Browse the repository at this point in the history
Co-authored-by: Akshay Sonawane <[email protected]>
Co-authored-by: amd-pworfolk <[email protected]>
  • Loading branch information
3 people authored Jan 21, 2025
1 parent bc33e79 commit c1463b0
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 4 deletions.
60 changes: 60 additions & 0 deletions .github/workflows/test_lemonade_oga_cpu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Continuous integration for Lemonade with the OGA CPU backend.
# On a Windows runner with a single Python version (3.10) this workflow:
#   1. installs the package with the `llm-oga-cpu` extra,
#   2. lints the sources with Black and PyLint,
#   3. exercises the OGA+CPU server through the local server-testing action,
#   4. runs the lemonade CLI once and then the OGA CPU API tests.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Lint and Test Lemonade for OGA on CPU

on:
  push:
    branches:
      - "main"
  pull_request:
    branches:
      - "main"

permissions:
  contents: read

jobs:
  make-oga-cpu-lemonade:
    env:
      LEMONADE_CI_MODE: "True"
    runs-on: windows-latest
    steps:
      - uses: actions/checkout@v3

      - name: Set up Miniconda with 64-bit Python
        uses: conda-incubator/setup-miniconda@v2
        with:
          miniconda-version: "latest"
          activate-environment: lemon
          python-version: "3.10"
          run-post: "false"

      - name: Install dependencies
        shell: bash -el {0}
        run: |
          python -m pip install --upgrade pip
          conda install pylint
          python -m pip check
          pip install -e .[llm-oga-cpu]

      - name: Lint with Black
        uses: psf/black@stable
        with:
          options: "--check --verbose"
          src: "./src"

      - name: Lint with PyLint
        shell: bash -el {0}
        run: |
          pylint src/lemonade --rcfile .pylintrc --disable E0401

      - name: Test OGA+CPU server
        if: runner.os == 'Windows'
        timeout-minutes: 10
        uses: ./.github/actions/server-testing
        with:
          conda_env: -n lemon
          load_command: -i TinyPixel/small-llama2 oga-load --device cpu --dtype int4
          # Required by OGA model_builder in OGA 0.4.0 but not future versions
          hf_token: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}"

      - name: Run lemonade tests
        shell: bash -el {0}
        env:
          # Required by OGA model_builder in OGA 0.4.0 but not future versions
          HF_TOKEN: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}"
        run: |
          lemonade -i TinyPixel/small-llama2 oga-load --device cpu --dtype int4 llm-prompt -p "tell me a story" --max-new-tokens 5
          python test/oga_cpu_api.py
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@
"fastapi",
"uvicorn[standard]",
],
"llm-oga-cpu": [
"onnxruntime-genai>=0.5.2",
"torch>=2.0.0,<2.4",
"transformers<4.45.0",
"turnkeyml[llm]",
],
"llm-oga-igpu": [
"onnxruntime-genai-directml==0.4.0",
"torch>=2.0.0,<2.4",
Expand Down
3 changes: 3 additions & 0 deletions src/lemonade/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ class Keys:
STD_DEV_SECONDS_TO_FIRST_TOKEN = "std_dev_seconds_to_first_token"
CHECKPOINT = "checkpoint"
DTYPE = "dtype"
PROMPT = "prompt"
PROMPT_TOKENS = "prompt_tokens"
RESPONSE = "response"
RESPONSE_TOKENS = "response_tokens"
CACHE_DIR = "cache_dir"
DEVICE = "device"
OGA_MODELS_SUBFOLDER = "oga_models_subfolder"
34 changes: 31 additions & 3 deletions src/lemonade/tools/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from turnkeyml.state import State
from turnkeyml.tools import Tool
from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
from lemonade.cache import Keys

DEFAULT_GENERATE_PARAMS = {
"do_sample": True,
Expand All @@ -25,6 +26,10 @@
END_OF_STREAM = "</s>"


def sanitize_string(input_string: str) -> str:
    """
    Return a copy of `input_string` that is safe to encode as UTF-8.

    The text is round-tripped through UTF-8 using the "ignore" error handler,
    so any code point that cannot be encoded (for example a lone surrogate)
    is silently dropped. All other characters are returned unchanged.
    """
    # Encoding is the lossy step: unencodable characters are discarded here
    utf8_bytes = input_string.encode("utf-8", errors="ignore")

    # The bytes produced above are valid UTF-8, so decoding cannot fail
    return utf8_bytes.decode("utf-8")


class LLMPrompt(Tool):
"""
Send a prompt to an LLM instance and print the response to the screen.
Expand All @@ -43,7 +48,12 @@ class LLMPrompt(Tool):
def __init__(self):
super().__init__(monitor_message="Prompting LLM")

self.status_stats = ["response"]
self.status_stats = [
Keys.PROMPT_TOKENS,
Keys.PROMPT,
Keys.RESPONSE_TOKENS,
Keys.RESPONSE,
]

@staticmethod
def parser(add_help: bool = True) -> argparse.ArgumentParser:
Expand Down Expand Up @@ -75,13 +85,31 @@ def run(
tokenizer: TokenizerAdapter = state.tokenizer

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
if isinstance(input_ids, list):
# OGA models return a list of tokens
len_tokens_in = len(input_ids)
else:
# HF models return a 2-D tensor
len_tokens_in = input_ids.shape[1]

response = model.generate(
input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
)
response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
len_tokens_out = len(response[0]) - len_tokens_in
input_ids = input_ids if isinstance(input_ids, list) else input_ids[0]
i = 0
while i < len_tokens_in and input_ids[i] == response[0][i]:
i += 1
response_text = tokenizer.decode(
response[0][i:], skip_special_tokens=True
).strip()

state.response = response_text
state.save_stat("response", response_text)

state.save_stat(Keys.PROMPT_TOKENS, len_tokens_in)
state.save_stat(Keys.PROMPT, prompt)
state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
state.save_stat(Keys.RESPONSE, sanitize_string(response_text))

return state

Expand Down
21 changes: 20 additions & 1 deletion src/lemonade/tools/ort_genai/oga.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shutil
from fnmatch import fnmatch
from queue import Queue
from packaging.version import Version
from huggingface_hub import snapshot_download
import onnxruntime_genai as og
import onnxruntime_genai.models.builder as model_builder
Expand Down Expand Up @@ -120,11 +121,22 @@ def generate(
):
params = og.GeneratorParams(self.model)

# There is a breaking API change in OGA 0.6.0
# Determine whether we should use the old or new APIs
# This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
use_oga_post_6_api = (
Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
)
use_oga_pre_6_api = not use_oga_post_6_api

if pad_token_id:
params.pad_token_id = pad_token_id

max_length = len(input_ids) + max_new_tokens

if use_oga_pre_6_api:
params.input_ids = input_ids

if self.config and "search" in self.config:
search_config = self.config["search"]
params.set_search_options(
Expand Down Expand Up @@ -158,10 +170,13 @@ def generate(
params.try_graph_capture_with_max_batch_size(1)

generator = og.Generator(self.model, params)
generator.append_tokens(input_ids)
if use_oga_post_6_api:
generator.append_tokens(input_ids)

if streamer is None:
prompt_start_time = time.perf_counter()
if use_oga_pre_6_api:
generator.compute_logits()
generator.generate_next_token()
prompt_end_time = time.perf_counter()

Expand All @@ -172,6 +187,8 @@ def generate(
token_gen_times = []
while not generator.is_done():
token_gen_start_time = time.perf_counter()
if use_oga_pre_6_api:
generator.compute_logits()
generator.generate_next_token()
token_gen_end_time = time.perf_counter()

Expand All @@ -192,6 +209,8 @@ def generate(
stop_early = False

while not generator.is_done() and not stop_early:
if use_oga_pre_6_api:
generator.compute_logits()
generator.generate_next_token()

new_token = generator.get_next_tokens()[0]
Expand Down
100 changes: 100 additions & 0 deletions test/oga_cpu_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import unittest
import shutil
import os
import urllib3
from turnkeyml.state import State
import turnkeyml.common.test_helpers as common
import turnkeyml.common.filesystem as fs
from lemonade.tools.ort_genai.oga import OgaLoad
from lemonade.tools.chat import LLMPrompt
from lemonade.tools.mmlu import AccuracyMMLU
from lemonade.tools.humaneval import AccuracyHumaneval

# Value of the LEMONADE_CI_MODE environment variable (False when it is unset)
ci_mode = os.getenv("LEMONADE_CI_MODE", False)

# Settings shared by the tests in this module
checkpoint = "TinyPixel/small-llama2"
device = "cpu"
dtype = "int4"
force = False
prompt = "Alice and Bob"

# Probe the dataset host once at import time; the result decides whether the
# MMLU test is skipped. Any status outside 2xx/3xx, or an HTTP-level failure,
# counts as "cannot be reached".
try:
    url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
    resp = urllib3.request("GET", url, preload_content=False)
    eecs_berkeley_edu_cannot_be_reached = not (200 <= resp.status < 400)
    resp.release_conn()
except urllib3.exceptions.HTTPError:
    eecs_berkeley_edu_cannot_be_reached = True


class Testing(unittest.TestCase):
    """
    Tests for the OGA tools running on CPU.

    Every test loads the module-level `checkpoint` with `OgaLoad` on the
    module-level `device`/`dtype` and then runs one lemonade tool against it.
    The module-level `cache_dir` is used as the build cache.
    """

    def setUp(self) -> None:
        # Start every test from an empty cache directory
        shutil.rmtree(cache_dir, ignore_errors=True)

    def test_001_ogaload(self):
        """Test the OgaLoad and LLMPrompt tools on a CPU model"""

        state = State(cache_dir=cache_dir, build_name="test")

        state = OgaLoad().run(
            state, input=checkpoint, device=device, dtype=dtype, force=force
        )
        state = LLMPrompt().run(state, prompt=prompt, max_new_tokens=5)

        # The response text itself is used as the failure message
        self.assertGreater(len(state.response), len(prompt), state.response)

    @unittest.skipIf(
        eecs_berkeley_edu_cannot_be_reached,
        "eecs.berkeley.edu cannot be reached for dataset download",
    )
    def test_002_accuracy_mmlu(self):
        """Test MMLU benchmarking with known model"""
        subject = ["management"]

        state = State(
            cache_dir=cache_dir,
            build_name="test",
        )

        state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
        state = AccuracyMMLU().run(state, ntrain=5, tests=subject)

        # Only check that an accuracy stat was recorded and is non-negative
        stats = fs.Stats(state.cache_dir, state.build_name).stats
        self.assertGreaterEqual(stats[f"mmlu_{subject[0]}_accuracy"], 0)

    def test_003_accuracy_humaneval(self):
        """Test HumanEval benchmarking with known model"""

        state = State(
            cache_dir=cache_dir,
            build_name="test",
        )

        # Enable code evaluation for HumanEval
        os.environ["HF_ALLOW_CODE_EVAL"] = "1"

        state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
        state = AccuracyHumaneval().run(
            state,
            first_n_samples=1,  # Test only one problem for speed
            k_samples=1,  # Single attempt per problem
            timeout=30.0,
        )

        # Verify results
        stats = fs.Stats(state.cache_dir, state.build_name).stats
        self.assertIn(
            "humaneval_pass@1", stats, "HumanEval pass@1 metric not found"
        )
        self.assertIsInstance(
            stats["humaneval_pass@1"],
            (int, float),
            "HumanEval pass@1 metric should be numeric",
        )


if __name__ == "__main__":
    # cache_dir is read as a module-level global by the tests, so it has to be
    # bound before unittest.main() starts running them
    here = os.path.abspath(os.curdir)
    cache_dir, _ = common.create_test_dir("lemonade_oga_cpu_api", base_dir=here)
    unittest.main()

0 comments on commit c1463b0

Please sign in to comment.