diff --git a/src/lemonade/tools/ort_genai/oga.py b/src/lemonade/tools/ort_genai/oga.py index 99eaf9e..db3d68e 100644 --- a/src/lemonade/tools/ort_genai/oga.py +++ b/src/lemonade/tools/ort_genai/oga.py @@ -123,8 +123,11 @@ def generate( # There is a breaking API change in OGA 0.6.0 # Determine whether we should use the old or new APIs - use_oga_pre_6_api = Version(og.__version__) < Version("0.6.0") - use_oga_post_6_api = not use_oga_pre_6_api + # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version + use_oga_post_6_api = ( + Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__ + ) + use_oga_pre_6_api = not use_oga_post_6_api if pad_token_id: params.pad_token_id = pad_token_id