Skip to content

Commit

Permalink
[V1] Initial support of multimodal models for V1 re-arch (vllm-projec…
Browse files Browse the repository at this point in the history
…t#10699)

Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 authored and BKitor committed Dec 30, 2024
1 parent 64561f6 commit c812fd4
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 69 deletions.
16 changes: 8 additions & 8 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,9 +1050,12 @@ def create_engine_config(self,
# long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase.

# Chunked prefill is currently disabled for multimodal models by
# default.
if use_long_context and not model_config.is_multimodal_model:
# For multimodal models, chunked prefill is disabled by default in
# V0, but enabled by design in V1
if model_config.is_multimodal_model:
self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

elif use_long_context:
is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window()
is not None)
Expand Down Expand Up @@ -1241,12 +1244,9 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
# TODO (ywang96): Enable APC by default when VLM supports it.
if engine_config.model_config.is_multimodal_model:
logger.warning(
"Prefix caching is currently not supported for multimodal "
"models and has been disabled.")
engine_config.cache_config.enable_prefix_caching = False
# TODO (ywang96): Enable APC by default when VLM supports it.
assert not engine_config.cache_config.enable_prefix_caching


@dataclass
Expand Down
5 changes: 5 additions & 0 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input image.
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
...

Expand Down
68 changes: 57 additions & 11 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
Expand All @@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
patches_per_image: List[int]
"""
List of number of total patches for each image in the batch.
"""


class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
data: NestedTensors
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Expand Down Expand Up @@ -349,10 +355,32 @@ def input_processor(
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
img_context_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False)
assert len(img_context_token_id) == 1, \
(f"Invalid image token '{self.img_context_token}': A valid image "
f"token encodes to a single token ID, got {img_context_token_id}.")
img_context_token_id = img_context_token_id[0]

# Get precise tracking of placeholder positions
token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
if new_prompt_token_ids[token_idx] == img_context_token_id:
curr_image_featue_size = image_feature_sizes[image_idx]
placeholder_ranges.append(
PlaceholderRange(offset=token_idx,
length=curr_image_featue_size))
image_idx += 1
token_idx += curr_image_featue_size
else:
token_idx += 1

return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})

def input_mapper(
self,
Expand Down Expand Up @@ -614,26 +642,46 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

patches_per_image = []
for request_pixel_values in pixel_values:
for image_pixel_values in request_pixel_values:
patches_per_image.append(image_pixel_values.shape[0])
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)),
)
patches_per_image=patches_per_image)

raise AssertionError("This line should be unreachable.")

def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
) -> Tuple[torch.Tensor]:
if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_model is not None

image_embeds = self.extract_feature(image_input["data"])

patches_per_image = image_input["patches_per_image"]
if len(patches_per_image) == 1:
image_embeds = image_embeds.unsqueeze(0)
return image_embeds

# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in patches_per_image
]
image_embeds = image_embeds.split(image_feature_sizes)
return image_embeds

def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -696,13 +744,11 @@ def forward(
"inputs_embeds": inputs_embeds,
}

# Only required if the model is mono-architecture
if self.visual_token_mask is not None:
# overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
self.img_context_token_id = None

hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
Expand Down
72 changes: 63 additions & 9 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
Expand All @@ -46,12 +46,16 @@
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
maybe_prefix, merge_multimodal_embeddings)

# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128
DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066
DEFAULT_IM_START_TOKEN_ID = 152067
DEFAULT_IM_END_TOKEN_ID = 152064
DEFAULT_IM_COL_TOKEN_ID = 152065


class MolmoImageInputs(TypedDict):
Expand All @@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict):
`(batch_size, num_crops, num_patch)`
"""

image_start_end: Tuple[int, int]
"""Starting and ending index of placeholder
tokens
"""


@dataclass
class VisionBackboneConfig:
Expand Down Expand Up @@ -918,6 +927,8 @@ def image_input_mapper_for_molmo(
ctx: InputContext,
data: object,
):
if isinstance(data, list):
data = data[0]
return MultiModalKwargs(data)


Expand Down Expand Up @@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
if "image_masks" in out:
dummy_imgdata["image_masks"] = out["image_masks"]
dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
return DummyData(dummy_seqdata, {"image": dummy_imgdata})
size = 0
offset = -1
for i in range(len(token_ids)):
if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID,
DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID):
if offset < 0:
offset = i
size += 1
dummy_imgdata["image_start_end"] = (offset, offset + size)
return DummyData(seq_data=dummy_seqdata,
multi_modal_data={"image": dummy_imgdata},
multi_modal_placeholders={
"image":
[PlaceholderRange(offset=offset, length=size)]
})


def pad_images(
Expand Down Expand Up @@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
if image_masks is not None:
image_data["image_masks"] = image_masks

image_data["seq_len"] = torch.tensor(len(out["input_ids"]),
new_prompt_token_ids = out["input_ids"].tolist()
image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids),
dtype=torch.long)

multi_modal_data = dict(image=image_data)
size = 0
offset = -1
for i in range(len(new_prompt_token_ids)):
if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID,
DEFAULT_IM_START_TOKEN_ID,
DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID):
if offset < 0:
offset = i
size += 1
image_data["image_start_end"] = (offset, offset + size)

prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(out["input_ids"])
prompt = tokenizer.decode(new_prompt_token_ids)

return token_inputs(
prompt_token_ids=out["input_ids"],
prompt_token_ids=new_prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={
"image": [PlaceholderRange(offset=offset, length=size)]
},
)


Expand Down Expand Up @@ -1113,6 +1154,7 @@ def _parse_and_validate_image_input(
) -> Optional[MolmoImageInputs]:
images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None)
image_start_end = kwargs.pop("image_start_end", None)
if images is None:
return None

Expand All @@ -1130,6 +1172,7 @@ def _parse_and_validate_image_input(
image_input_idx=image_input_idx,
seq_len=seq_len,
image_masks=image_masks,
image_start_end=image_start_end,
)

def _process_image_input(
Expand Down Expand Up @@ -1178,9 +1221,16 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:

# Note: In this original implementation from AI2, the final
# vision_embeddings will be always be the same length
# of input embedddings, which is not very efficient.
# TODO(ywang96): see if this can be optimized.
# of input embeddings.
vision_embeddings = torch.einsum('nd,nm->md', image_features, mat)

# Split by the sizes of the input sequences. For each full embedding,
# extract the actual vision embeddings to be merged.
vision_embeddings = list(vision_embeddings.split(seq_len.tolist()))
for i in range(len(vision_embeddings)):
start, end = image_input['image_start_end'][i]
vision_embeddings[i] = vision_embeddings[i][start:end]

return vision_embeddings

def get_input_embeddings(
Expand All @@ -1190,7 +1240,11 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = inputs_embeds + multimodal_embeddings
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID,
DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID
])
return inputs_embeds

def forward(
Expand Down
Loading

0 comments on commit c812fd4

Please sign in to comment.