Skip to content

Commit

Permalink
[Model] [BUG] Fix code path logic to load mllama model (#234)
Browse files Browse the repository at this point in the history
* fix code path logic to load mllama model

* fix lint error

* fix lint error

---------

Co-authored-by: tjtanaa <[email protected]>
  • Loading branch information
tjtanaa and tjtanaa authored Oct 16, 2024
1 parent 82cfa5a commit 1658370
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.utils import async_tensor_h2d, is_hip, make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
Expand Down Expand Up @@ -334,11 +334,19 @@ def graph_capture_get_metadata_for_batch(
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
if is_hip():
assert (
self.runner.attn_backend.get_name() == "rocm-flash-attn"
), (f"Expected attn_backend name to be 'rocm-flash-attn', but "
f" got '{self.runner.attn_backend.get_name()}'")
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
else:
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)

return attn_metadata

Expand All @@ -354,11 +362,19 @@ def get_graph_input_buffers(
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
if is_hip():
assert (
self.runner.attn_backend.get_name() == "rocm-flash-attn"
), (f"Expected attn_backend name to be 'rocm-flash-attn', but "
f" got '{self.runner.attn_backend.get_name()}'")
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
else:
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers

def prepare_graph_input_buffers(
Expand All @@ -373,11 +389,20 @@ def prepare_graph_input_buffers(
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)

if is_hip():
assert (
self.runner.attn_backend.get_name() == "rocm-flash-attn"
), (f"Expected attn_backend name to be 'rocm-flash-attn', but "
f" got '{self.runner.attn_backend.get_name()}'")
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
else:
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)

def begin_forward(self, model_input) -> None:
return
Expand Down

0 comments on commit 1658370

Please sign in to comment.