Fix token mismatch in Phi3V and Ultravox
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 23, 2024
1 parent 10ae755 commit 85c5e2c
Showing 3 changed files with 94 additions and 43 deletions.
47 changes: 37 additions & 10 deletions vllm/model_executor/models/phi3v.py
@@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@@ -36,7 +36,9 @@
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

@@ -342,12 +344,13 @@ def _call_hf_processor(
mm_kwargs=mm_kwargs,
)

input_ids = processed_outputs["input_ids"]
assert isinstance(input_ids, torch.Tensor)

# Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
# which will cause OverflowError when decoding the prompt_ids.
# Therefore, we need to do an early replacement here
token_ids = processed_outputs['input_ids']
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)

return processed_outputs
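
The new code asserts that `input_ids` is a `torch.Tensor` and replaces the negative placeholder IDs with an in-place `masked_fill_` instead of re-assigning the dict entry. A minimal, self-contained sketch of the same replacement (the token ID value below is illustrative, not read from the model config):

```python
import torch

IMAGE_TOKEN_ID = 32044  # illustrative value; the real constant is _IMAGE_TOKEN_ID in phi3v.py

def replace_negative_placeholders(input_ids: torch.Tensor) -> torch.Tensor:
    """Replace the HF processor's negative placeholders (-1, -2, ...) with the
    image token ID, in place, so later decoding cannot overflow."""
    assert isinstance(input_ids, torch.Tensor)
    input_ids.masked_fill_(input_ids < 0, IMAGE_TOKEN_ID)
    return input_ids

ids = torch.tensor([[1, -1, -1, -2, 29871, 13]])
print(replace_negative_placeholders(ids))
# every negative ID is now IMAGE_TOKEN_ID; the rest of the prompt is untouched
```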

@@ -372,8 +375,9 @@ def _get_prompt_replacements(
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore

mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
tokenizer = self._get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)

def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
@@ -382,16 +386,39 @@ def get_replacement_phi3v(item_idx: int):
height=image_size.height,
)

return [_IMAGE_TOKEN_ID] * num_tokens
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]

return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:max_images]
) for image_token in image_tokens[:len(mm_items.images)]
]

def _apply_prompt_replacements(
self,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
prompt_repls=prompt_repls,
mm_item_counts=mm_item_counts,
)

# Keep the behavior in line with HF processor
if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
for p in placeholders
]

return token_ids, text, placeholders
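
The new `_apply_prompt_replacements` override keeps the output aligned with the HF processor, which emits `<s><|image|>` without a space: it drops the stray space token after BOS and shifts the recorded placeholder offsets left by one. A toy sketch of the same bookkeeping, using a simplified placeholder record and made-up token IDs:

```python
from dataclasses import dataclass

@dataclass
class Placeholder:
    """Simplified stand-in for _PlaceholderInfo (illustration only)."""
    modality: str
    start_idx: int
    replacement: list[int]

def drop_space_after_bos(token_ids, text, placeholders):
    """If detokenization yields '<s> <|image|>', drop the stray space token at
    index 1 and shift every placeholder one position to the left."""
    if text.startswith("<s> <|image|>"):
        text = text.replace("<s> <|image|>", "<s><|image|>", 1)
        token_ids = [token_ids[0], *token_ids[2:]]
        placeholders = [
            Placeholder(p.modality, p.start_idx - 1, p.replacement)
            for p in placeholders
        ]
    return token_ids, text, placeholders

# Toy values: 1 = <s>, 29871 = the stray space token, 32044 = <|image|>, 2 = </s>
ids, text, phs = drop_space_after_bos(
    [1, 29871, 32044, 32044, 2],
    "<s> <|image|><|image|></s>",
    [Placeholder("image", 2, [32044, 32044])],
)
# ids  == [1, 32044, 32044, 2]
# text == "<s><|image|><|image|></s>"
# phs  == [Placeholder("image", 1, [32044, 32044])]
```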

def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
16 changes: 10 additions & 6 deletions vllm/model_executor/models/ultravox.py
@@ -102,16 +102,20 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self._get_tokenizer()

prompt_ids = tokenizer.encode(
prompt,
add_special_tokens=False, # type: ignore
)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])

if not audios:
if not mm_data:
# Text-only input not supported in composite processor
prompt_ids = self._get_tokenizer().encode(prompt)
return BatchFeature(dict(input_ids=[prompt_ids]),
tensor_type="pt")

return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
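
In ultravox.py, the text-only branch now runs before any audio handling and encodes the prompt with `add_special_tokens=False`, so no extra BOS/EOS is inserted on that path. A rough sketch of the early-return shape with a plain HF tokenizer (the vLLM wrapper types are omitted; "gpt2" is just a small stand-in model for the demo):

```python
from transformers import AutoTokenizer, BatchFeature

def encode_text_only(prompt: str, tokenizer) -> BatchFeature:
    """Text-only requests bypass the composite audio processor entirely:
    encode the prompt directly, without special tokens."""
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

tok = AutoTokenizer.from_pretrained("gpt2")
print(encode_text_only("Hello there", tok)["input_ids"])
```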
74 changes: 47 additions & 27 deletions vllm/multimodal/processing.py
@@ -813,6 +813,9 @@ def _apply_hf_processor(
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
"""
Apply the HF processor on the full prompt text and multi-modal data.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

processed_data = self._call_hf_processor(
@@ -832,12 +835,51 @@

return prompt_ids, mm_kwargs

def _apply_hf_processor_missing(
self,
prompt_text: str,
mm_missing_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
):
"""
Apply the HF processor on the full prompt text, but only on the
multi-modal data that are missing from the cache.
Note: We pass prompt text and multi-modal data into the HF processor
in separate calls to avoid HF prompt replacement being done for
cached items; instead, we rely on our own prompt replacement logic
for the full text.
"""
mm_missing_counts = mm_missing_data_items.get_item_counts()

prompt_ids, _ = self._apply_hf_processor(
prompt_text=prompt_text,
mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={},
)

# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts)

_, mm_missing_kwargs = self._apply_hf_processor(
prompt_text=dummy_inputs.prompt_text,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)

return prompt_ids, mm_missing_kwargs
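
The key idea of `_apply_hf_processor_missing` is the two-call split: one HF call on the text alone to get the prompt IDs, and one on a dummy prompt plus only the uncached items to get their kwargs. A schematic sketch of that split, where `run_hf` and `make_dummy_prompt` are stand-in callables, not the vLLM API:

```python
from typing import Callable, Mapping, Sequence

def apply_hf_processor_missing(
    prompt_text: str,
    mm_missing_items: Mapping[str, Sequence[object]],
    run_hf: Callable[[str, Mapping[str, Sequence[object]]], tuple[list[int], dict]],
    make_dummy_prompt: Callable[[Mapping[str, int]], str],
) -> tuple[list[int], dict]:
    # Call 1: full prompt text with *no* multi-modal data, so the HF processor
    # never expands placeholders for items that are already cached.
    prompt_ids, _ = run_hf(prompt_text, {})

    # Call 2: a dummy prompt plus only the missing items. Some HF processors
    # (e.g. Qwen2-VL) require matching tokens in the text, hence a dummy
    # prompt rather than an empty string.
    missing_counts = {mod: len(items) for mod, items in mm_missing_items.items()}
    _, mm_missing_kwargs = run_hf(make_dummy_prompt(missing_counts),
                                  mm_missing_items)

    return prompt_ids, mm_missing_kwargs
```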

def _cached_apply_hf_processor(
self,
prompt_text: str,
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
"""
cache = self.cache

if cache is None:
@@ -864,35 +906,12 @@ def _cached_apply_hf_processor(
for modality, idxs in mm_missing_idxs.items()
}
mm_missing_data_items = self._get_mm_items(mm_missing_data)
mm_missing_counts = mm_missing_data_items.get_item_counts()

if any(count > 0 for count in mm_missing_counts.values()):
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts)

_, mm_missing_kwargs = self._apply_hf_processor(
prompt_text=dummy_inputs.prompt_text,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
else:
# Avoid unnecessary tokenization of the prompt text
mm_missing_kwargs = MultiModalKwargs({})

# NOTE: Some HF processors insert BOS/EOS while others don't.
# We try to maintain consistent behavior when calling only
# the tokenizer vs when calling its parent processor
empty_ids, _ = self._apply_hf_processor(
prompt_text="",
mm_items=self._get_mm_items({}),
hf_processor_mm_kwargs={},
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
prompt_text=prompt_text,
mm_missing_data_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
# Rely on our placeholder replacement logic instead of HF
# to insert the placeholder tokens
prompt_ids = _encode(self._get_tokenizer(),
prompt_text,
add_special_tokens=len(empty_ids) > 0)

mm_missing_next_idx = {
modality: 0
@@ -925,6 +944,7 @@ def _cached_apply_hf_processor(
mm_merged_field_items[modality] = merged_modal_items_lst

if self.enable_sanity_checks:
mm_missing_counts = mm_missing_data_items.get_item_counts()
assert all(
item_count == mm_missing_counts[modality]
for modality, item_count in mm_missing_next_idx.items()), dict(
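
After this refactor, `_cached_apply_hf_processor` only computes kwargs for the uncached items and then interleaves them back with the cached ones in prompt order, advancing a per-modality "next missing" counter and sanity-checking that every freshly processed item was consumed. A much-simplified sketch of that merge step with a plain dict cache (keys and structures here are illustrative):

```python
def merge_cached_and_new(item_keys, cache, new_items):
    """Walk the items in prompt order: reuse the cached entry when the key is
    known, otherwise consume the next freshly processed entry and cache it."""
    merged = []
    next_new = 0
    for key in item_keys:
        if key in cache:
            merged.append(cache[key])
        else:
            entry = new_items[next_new]
            next_new += 1
            cache[key] = entry
            merged.append(entry)
    # Mirrors the enable_sanity_checks assertion: every freshly processed
    # item must be used exactly once.
    assert next_new == len(new_items)
    return merged

cache = {"img-A": {"pixel_values": "cached-A"}}
out = merge_cached_and_new(["img-A", "img-B"], cache,
                           [{"pixel_values": "fresh-B"}])
# out == [{'pixel_values': 'cached-A'}, {'pixel_values': 'fresh-B'}]
# cache now also holds "img-B" for the next request.
```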
