Eliminate graph breaks for torch.compile mode #202

Merged
6 changes: 2 additions & 4 deletions vllm/hpu/cache_ops.py
@@ -30,8 +30,7 @@ def reshape_and_cache(key,
# lots of padding, or are doing warmup.
# This loop is a workaround for this issue. Please remove it
# once key_cache.index_put_((indices, offsets), key) works.
-num_kv_cache_passes = torch.div(num_slots_requested,
-                                num_slots_available).ceil().int().item()
+num_kv_cache_passes = -(-num_slots_requested // num_slots_available)


this is not very readable - if the goal is to remove torch calls, is there any reason not to use math.ceil() here?

Author

Just to avoid the extra import of the math library. For readability, I can change it to math.ceil().

for i in range(num_kv_cache_passes):
start_idx = i * num_slots_available
end_idx = (i + 1) * num_slots_available
@@ -58,8 +57,7 @@ def prepare_to_cache(cache, slot_mapping):
# lots of padding, or are doing warmup.
# This loop is a workaround for this issue. Please remove it
# once key_cache.index_put_((indices, offsets), key) works.
-num_kv_cache_passes = torch.div(num_slots_requested,
-                                num_slots_available).ceil().int().item()
+num_kv_cache_passes = -(-num_slots_requested // num_slots_available)

return num_kv_cache_passes, num_slots_available, indices, offsets
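For readers following the exchange above: the replaced expression ends in .item(), which copies a scalar back to the host and typically forces a graph break under torch.compile, whereas both the negative-floor-division idiom and math.ceil() stay in plain Python. A standalone sketch of the equivalence for positive integers (illustrative only, not vLLM code):

import math

def ceil_div(a: int, b: int) -> int:
    # For positive ints, floor(-a / b) == -ceil(a / b), so negating twice
    # gives a ceiling division without float math or torch ops.
    return -(-a // b)

for a, b in [(1, 4), (10, 4), (12, 4), (13, 4)]:
    assert ceil_div(a, b) == math.ceil(a / b)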

5 changes: 3 additions & 2 deletions vllm/model_executor/models/gpt_bigcode.py
@@ -44,6 +44,7 @@

from .interfaces import SupportsLoRA

+is_hpu = current_platform.is_hpu()

class GPTBigCodeAttention(nn.Module):

@@ -225,13 +226,13 @@ def forward(
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds

-if current_platform.is_hpu():
+if is_hpu:
import habana_frameworks.torch as htorch
htorch.core.mark_step()
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
-if current_platform.is_hpu():
+if is_hpu:
htorch.core.mark_step()

hidden_states = self.ln_f(hidden_states)
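The pattern here, repeated in llama.py below, is to evaluate the platform check once at import time so that the branch inside forward() tests a plain Python bool rather than calling the current_platform helper on every step, which is presumably what was breaking the compiled graph. A minimal sketch of the hoisting, assuming the vllm.platforms import path used upstream; run_layers is a hypothetical stand-in for the decoder loop in the diff:

from vllm.platforms import current_platform  # assumed import path, as in upstream vLLM

# Resolved once at module import; torch.compile then sees a constant bool.
is_hpu = current_platform.is_hpu()

def run_layers(layers, hidden_states, kv_caches, attn_metadata):
    # Hypothetical helper mirroring the per-layer loop shown in the diff above.
    if is_hpu:
        import habana_frameworks.torch as htorch
        htorch.core.mark_step()
    for i, layer in enumerate(layers):
        hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
        if is_hpu:
            htorch.core.mark_step()
    return hidden_states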
5 changes: 3 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -55,6 +55,7 @@
from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

+is_hpu = current_platform.is_hpu()

class LlamaMLP(nn.Module):

@@ -318,7 +319,7 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

-if current_platform.is_hpu():
+if is_hpu:
import habana_frameworks.torch as htorch
htorch.core.mark_step()
for i in range(self.start_layer, self.end_layer):
@@ -330,7 +331,7 @@
attn_metadata,
residual,
)
-if current_platform.is_hpu():
+if is_hpu:
htorch.core.mark_step()

if not get_pp_group().is_last_rank:
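Not part of the diff, but one quick way to check the PR's stated goal locally: compiling with fullgraph=True makes torch.compile raise as soon as any graph break remains. A toy, device-agnostic sketch (TinyBlock is a placeholder, not a vLLM model; the appropriate backend for the target device can be passed via the backend argument):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Any graph break in this body would make the compiled call raise.
        return torch.relu(x) + 1

compiled = torch.compile(TinyBlock(), fullgraph=True)
out = compiled(torch.randn(4, 8))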