Skip to content

Commit

Permalink
Better compatibility for OGA
Browse files (browse the repository at this point in the history)
  • Loading branch information
jeremyfowers committed Jan 20, 2025
1 parent b43c0e9 commit ac1a740
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"uvicorn[standard]",
],
"llm-oga-cpu": [
-"onnxruntime-genai==0.5.2",
+"onnxruntime-genai>=0.5.2",
"torch>=2.0.0,<2.4",
"transformers<4.45.0",
"turnkeyml[llm]",
Expand Down
21 changes: 17 additions & 4 deletions src/lemonade/tools/ort_genai/oga.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shutil
from fnmatch import fnmatch
from queue import Queue
from packaging.version import Version
from huggingface_hub import snapshot_download
import onnxruntime_genai as og
import onnxruntime_genai.models.builder as model_builder
Expand Down Expand Up @@ -120,12 +121,19 @@ def generate(
):
params = og.GeneratorParams(self.model)

# There is a breaking API change in OGA 0.6.0
# Determine whether we should use the old or new APIs
use_oga_pre_6_api = Version(og.__version__) < Version("0.6.0")
use_oga_post_6_api = not use_oga_pre_6_api

if pad_token_id:
params.pad_token_id = pad_token_id

max_length = len(input_ids) + max_new_tokens

-params.input_ids = input_ids
+if use_oga_pre_6_api:
+    params.input_ids = input_ids

if self.config and "search" in self.config:
search_config = self.config["search"]
params.set_search_options(
Expand Down Expand Up @@ -159,10 +167,13 @@ def generate(
params.try_graph_capture_with_max_batch_size(1)

generator = og.Generator(self.model, params)
if use_oga_post_6_api:
generator.append_tokens(input_ids)

if streamer is None:
prompt_start_time = time.perf_counter()
-generator.compute_logits()
+if use_oga_pre_6_api:
+    generator.compute_logits()
generator.generate_next_token()
prompt_end_time = time.perf_counter()

Expand All @@ -173,7 +184,8 @@ def generate(
token_gen_times = []
while not generator.is_done():
token_gen_start_time = time.perf_counter()
-generator.compute_logits()
+if use_oga_pre_6_api:
+    generator.compute_logits()
generator.generate_next_token()
token_gen_end_time = time.perf_counter()

Expand All @@ -194,7 +206,8 @@ def generate(
stop_early = False

while not generator.is_done() and not stop_early:
-generator.compute_logits()
+if use_oga_pre_6_api:
+    generator.compute_logits()
generator.generate_next_token()

new_token = generator.get_next_tokens()[0]
Expand Down

0 comments on commit ac1a740

Please sign in to comment.