Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/N] Initial prototype for multi-modal processor #10044

Merged
merged 22 commits into from
Nov 13, 2024
Merged

Conversation

DarkLight1337
Copy link
Member

@DarkLight1337 DarkLight1337 commented Nov 5, 2024

This PR adds the core code for multi-modal processor while maintaining backward compatibility. The main purpose of this PR is to "reserve" code changes (mainly related to the dependencies of multi-modal processor) to reduce the risk of merge conflicts caused by subsequent PRs.

Note that currently there are no models that use the new multi-modal processor - I will implement a few of them in the next PR. As such, the details of multi-modal processor are still subject to change.

Part of #10114

Copy link

github-actions bot commented Nov 5, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added documentation Improvements or additions to documentation frontend labels Nov 5, 2024
@DarkLight1337 DarkLight1337 force-pushed the mm-processor branch 3 times, most recently from ce2c9ef to 1c699eb Compare November 5, 2024 17:20
Copy link

mergify bot commented Nov 6, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @DarkLight1337 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Nov 9, 2024
Comment on lines +127 to +168
def candidate_placeholders(
tokenizer: AnyTokenizer,
placeholder_text: str,
) -> Collection[List[int]]:
"""Generate token ID sequences that may represent a placeholder text."""
# When the placeholder text is not mapped to a special token ID,
# it may be tokenized differently based on whether it is at the start/end
# of the string. So, we go through each combination of whether the text
# is at the start and end boundaries of the string

# Matches the placeholder when it is in the middle of the string
start_id, = encode_no_special_tokens(tokenizer, "a")
end_id, = encode_no_special_tokens(tokenizer, "b")

candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)

start_id_, *candidate_a = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}",
)
assert start_id == start_id_

start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}b",
)
assert start_id == start_id_ and end_id == end_id_

*candidate_b, end_id_ = encode_no_special_tokens(
tokenizer,
f"{placeholder_text}b",
)
assert end_id == end_id_

# Remove duplicates (need to convert to tuple to be hashable)
unique_candidates = {
tuple(c)
for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
}

# Convert back to list
return [list(c) for c in unique_candidates]
Copy link
Member Author

@DarkLight1337 DarkLight1337 Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A generalization of our existing code for Phi-3V

Comment on lines +171 to +191
def apply_placeholders(
token_ids: List[int],
placeholder_ids: List[int],
get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
"""
Find the first occurrence of :code:`placeholder_ids`,
and replace it with the output of :code:`get_replacement_ids`.

This function updates :code:`token_ids` in place.
"""
placeholder_length = len(placeholder_ids)

for start_idx in range(len(token_ids) - placeholder_length + 1):
if token_ids[start_idx:placeholder_length] == placeholder_ids:
token_ids[start_idx:placeholder_length] = get_replacement_ids()

return PlaceholderRange(offset=start_idx,
length=placeholder_length)

return None
Copy link
Member Author

@DarkLight1337 DarkLight1337 Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A generalization of our existing code for HF Pixtral

Signed-off-by: DarkLight1337 <[email protected]>
Comment on lines -49 to -52
# Processed by input processor
if isinstance(data, BatchFeature):
return MultiModalKwargs(data.data)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was added by the Llama-3.2 PR, but I found that this model doesn't use HF processor in vLLM input processor anymore, so it should be safe to remove.

Copy link
Collaborator

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@Isotr0py
Copy link
Collaborator

Copy link

mergify bot commented Nov 13, 2024

rebase

❌ Base branch update has failed

Git reported the following error:

Rebasing (1/7)
Auto-merging tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py
CONFLICT (content): Merge conflict in tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py
Auto-merging vllm/config.py
Auto-merging vllm/engine/async_llm_engine.py
Auto-merging vllm/engine/llm_engine.py
Auto-merging vllm/entrypoints/openai/serving_chat.py
Auto-merging vllm/entrypoints/openai/serving_completion.py
Auto-merging vllm/inputs/preprocess.py
Auto-merging vllm/model_executor/models/chatglm.py
Auto-merging vllm/model_executor/models/fuyu.py
Auto-merging vllm/model_executor/models/internvl.py
Auto-merging vllm/model_executor/models/minicpmv.py
Auto-merging vllm/model_executor/models/mllama.py
Auto-merging vllm/model_executor/models/molmo.py
Auto-merging vllm/model_executor/models/pixtral.py
Auto-merging vllm/model_executor/models/qwen.py
Auto-merging vllm/model_executor/models/qwen2_audio.py
Auto-merging vllm/model_executor/models/qwen2_vl.py
Auto-merging vllm/model_executor/models/ultravox.py
Auto-merging vllm/multimodal/__init__.py
CONFLICT (content): Merge conflict in vllm/multimodal/__init__.py
Auto-merging vllm/multimodal/base.py
CONFLICT (content): Merge conflict in vllm/multimodal/base.py
Auto-merging vllm/multimodal/registry.py
CONFLICT (content): Merge conflict in vllm/multimodal/registry.py
Auto-merging vllm/multimodal/video.py
Auto-merging vllm/sequence.py
Auto-merging vllm/spec_decode/draft_model_runner.py
Auto-merging vllm/v1/engine/llm_engine.py
CONFLICT (content): Merge conflict in vllm/v1/engine/llm_engine.py
Auto-merging vllm/worker/cpu_enc_dec_model_runner.py
Auto-merging vllm/worker/cpu_model_runner.py
CONFLICT (content): Merge conflict in vllm/worker/cpu_model_runner.py
Auto-merging vllm/worker/enc_dec_model_runner.py
CONFLICT (content): Merge conflict in vllm/worker/enc_dec_model_runner.py
Auto-merging vllm/worker/model_runner.py
CONFLICT (content): Merge conflict in vllm/worker/model_runner.py
Auto-merging vllm/worker/neuron_model_runner.py
CONFLICT (content): Merge conflict in vllm/worker/neuron_model_runner.py
Auto-merging vllm/worker/openvino_model_runner.py
CONFLICT (content): Merge conflict in vllm/worker/openvino_model_runner.py
Auto-merging vllm/worker/xpu_model_runner.py
CONFLICT (content): Merge conflict in vllm/worker/xpu_model_runner.py
error: could not apply 5108119b... Initial prototype for multi-modal processor
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply 5108119b... Initial prototype for multi-modal processor

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 13, 2024
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Copy link

mergify bot commented Nov 13, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @DarkLight1337.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 13, 2024
@mergify mergify bot removed the needs-rebase label Nov 13, 2024
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) November 13, 2024 12:19
@DarkLight1337 DarkLight1337 merged commit 0b8bb86 into main Nov 13, 2024
53 checks passed
@DarkLight1337 DarkLight1337 deleted the mm-processor branch November 13, 2024 12:39
rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 13, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation frontend ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants