Skip to content

Commit

Permalink
[Bugfix] Fall back to outlines for grammars that can't convert to GBNF
Browse files Browse the repository at this point in the history
Signed-off-by: mgoin <[email protected]>
  • Loading branch information
mgoin committed Dec 21, 2024
1 parent d573aea commit ca5d141
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 81 deletions.
87 changes: 17 additions & 70 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum, current_platform

if TYPE_CHECKING:
Expand All @@ -15,76 +18,6 @@
logger = init_logger(__name__)


def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""

def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False

# Check for pattern restrictions
if "pattern" in obj:
return True

# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
"minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf"
]):
return True

# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True

return False

return check_object(schema)


def has_lmf_unsupported_json_features(schema: dict) -> bool:
"""
Check if JSON schema contains features unsupported
by lm_format_enforcer.
Known issues:
- Regex patterns:
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"""

def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False

# Check for pattern restrictions
if "pattern" in obj:
return True

# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True

return False

return check_object(schema)


def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
Expand Down Expand Up @@ -127,6 +60,20 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif (guided_params.grammar is not None
and grammar_is_likely_lark(guided_params.grammar)):
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
logger.warning(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
Expand Down
23 changes: 14 additions & 9 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@

import numpy as np
import torch
from lark import Lark
from outlines import grammars
from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
RegexGuide, Write)
from outlines.fsm.parsing import PartialLark
from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
Expand All @@ -34,7 +35,9 @@ class BaseLogitsProcessor:

def __init__(self, guide: Guide):
self._guide: Guide = guide
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)

def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
Expand All @@ -54,15 +57,13 @@ def __call__(self, input_ids: List[int],
# On the first time this is called, we simply re-create
# the Lark object.
if isinstance(self._guide, CFGGuide):
self._guide.parser = Lark(
self._guide.parser = PartialLark(
self._guide.cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
import_paths=[grammars.GRAMMAR_PATH],
)
self._fsm_state[seq_id] = CFGState(
parser_state=self._guide.parser.parse(""), prev_token=None)

instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
Expand Down Expand Up @@ -200,7 +201,8 @@ def convert_token_to_string(token: str) -> str:
string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
or token == "<0x20>"):
return " " + string

return string
Expand All @@ -211,6 +213,9 @@ def change_decoder(
"""Sync vLLM's decoder with the outlines by returning list."""

def new_decoder(inp_tokens: List[int]) -> List[str]:
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
and isinstance(inp_tokens[0], list)):
inp_tokens = inp_tokens[0]
return [decoder(inp_tokens)]

return new_decoder
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,74 @@
import re

def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""

def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False

# Check for pattern restrictions
if "pattern" in obj:
return True

# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
"minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf"
]):
return True

# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True

return False

return check_object(schema)


def has_lmf_unsupported_json_features(schema: dict) -> bool:
"""
Check if JSON schema contains features unsupported
by lm_format_enforcer.
Known issues:
- Regex patterns:
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"""

def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False

# Check for pattern restrictions
if "pattern" in obj:
return True

# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True

return False

return check_object(schema)


def grammar_is_likely_lark(grammar_str: str) -> bool:
"""
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
except ImportError:
pass

from vllm.model_executor.guided_decoding.xgrammar_utils import (
convert_lark_to_gbnf, grammar_is_likely_lark)
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

if TYPE_CHECKING:
Expand Down

0 comments on commit ca5d141

Please sign in to comment.