
Commit

[Doc][V1] Update model implementation guide for V1 support (vllm-project#11998)

Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: ice-tong <[email protected]>
2 people authored and ice-tong committed Jan 18, 2025
1 parent e1a2399 commit 8ab4b0d
Showing 2 changed files with 83 additions and 16 deletions.
12 changes: 11 additions & 1 deletion docs/source/contributing/model/basic.md
@@ -57,7 +57,17 @@ class MyModelForCausalLM(nn.Module):

### Computation Code

Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.

```python
class MyModel(nn.Module):
...

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
...
```
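
For a plain decoder-only language model, the body of this method is usually just a call to the token embedding layer. A minimal sketch, assuming the embedding layer is stored as `self.embed_tokens` (the attribute name varies between models):

```python
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    # Delegate to the token embedding layer. The attribute name
    # (`embed_tokens` here) is an assumption and depends on your model.
    return self.embed_tokens(input_ids)
```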

- Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.

```python
def forward(
    ...
```
87 changes: 72 additions & 15 deletions docs/source/contributing/model/multimodal.md
@@ -9,7 +9,78 @@ This document walks you through the steps to extend a basic model so that it accepts multi-modal inputs.
It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic).
Further update the model as follows:

- Implement the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
- Reserve a keyword parameter in {meth}`~torch.nn.Module.forward` for each input tensor that corresponds to a multi-modal input, as shown in the following example:

```diff
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
```

More conveniently, you can simply pass `**kwargs` to the {meth}`~torch.nn.Module.forward` method and retrieve the keyword parameters for multimodal inputs from it.
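
For example, a sketch of this `**kwargs` variant, keeping the same signature as the snippet above (the way `pixel_values` is retrieved here is illustrative, not a required implementation):

```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    **kwargs: object,
) -> SamplerOutput:
    # Multimodal tensors (e.g. pixel_values) arrive as keyword arguments
    # and can be parsed and validated from `kwargs` inside the model.
    ...
```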

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings` that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.

```python
class YourModelForImage2Seq(nn.Module):
...

def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:

assert self.vision_encoder is not None
image_features = self.vision_encoder(image_input)
return self.multi_modal_projector(image_features)

def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]:

# Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

# Run multimodal inputs through encoder and projector
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
```

```{important}
The returned `multimodal_embeddings` must be either a **3D {class}`torch.Tensor`** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D {class}`torch.Tensor`'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g. image) of the request.
```
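
For instance, with two image items whose encoder produces 576 patch embeddings each (the numbers are illustrative only, and `hidden_size` stands for the language model's hidden dimension), either of the following satisfies this contract:

```python
# Option 1: a single 3D tensor of shape (num_items, feature_size, hidden_size)
multimodal_embeddings = torch.zeros(2, 576, hidden_size)

# Option 2: a list/tuple of 2D tensors, one per item, each of shape
# (feature_size, hidden_size); feature_size may differ between items.
multimodal_embeddings = [
    torch.zeros(576, hidden_size),
    torch.zeros(336, hidden_size),
]
```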

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings` to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.

```python
from .utils import merge_multimodal_embeddings

class YourModelForImage2Seq(nn.Module):
...

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:

# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index)

return inputs_embeds
```
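
Conceptually, the utility scatters the multimodal embeddings into the positions of `inputs_embeds` whose token id equals the placeholder token. A simplified, non-authoritative equivalent (ignoring validation and nested structures):

```python
# Simplified illustration of the merge, not the actual implementation.
mask = input_ids == self.config.image_token_index
# Flatten the per-item embeddings and write them into the placeholder slots.
inputs_embeds[mask] = torch.cat([item for item in multimodal_embeddings], dim=0)
```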

- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

```diff
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
@@ -23,20 +94,6 @@ Further update the model as follows:
Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples.
```

- If you haven't already done so, reserve a keyword parameter in {meth}`~torch.nn.Module.forward`
for each input tensor that corresponds to a multi-modal input, as shown in the following example:

```diff
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
```

## 2. Specify processing information

Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo`
