Mask based BGMV implementation for LoRA Embedding (#247)
This PR replaces the index-select of LoRA-B weights in the LoRA embedding path
with a mask-based BGMV implementation.

It also removes the special handling of the no-LoRA case.
vivekgoe authored Sep 9, 2024
2 parents 00f1333 + 016f343 commit b764610
Showing 4 changed files with 38 additions and 57 deletions.
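
Conceptually, the embedding path stops gathering a per-token LoRA-B matrix via index-select and instead multiplies a repeated input by a 0/1 mask before a single flat matmul. A minimal sketch of the two approaches and their equivalence (shapes and variable names are illustrative, not the exact code in this diff):

```python
import torch

num_tokens, max_loras, rank, h_out = 4, 2, 8, 6
x = torch.randn(num_tokens, rank)                   # output of the LoRA-A embedding lookup
wb_t_all = torch.randn(max_loras, 1, h_out, rank)   # stacked LoRA-B weights (assumed layout)
lora_idx = torch.tensor([0, 1, 0, 1])               # active LoRA slot per token

# Before: index-select the LoRA-B matrix for every token, then a batched matmul.
wb_sel = wb_t_all[lora_idx, 0].transpose(1, 2)      # (num_tokens, rank, h_out)
out_gather = (x.unsqueeze(1) @ wb_sel).squeeze(1)   # (num_tokens, h_out)

# After: one flattened weight plus a 0/1 mask that silences inactive LoRA slots.
wb = wb_t_all[:, 0].transpose(1, 2).reshape(max_loras * rank, h_out)
mask = torch.zeros(num_tokens, max_loras * rank)
for i, j in enumerate(lora_idx.tolist()):
    mask[i, j * rank:(j + 1) * rank] = 1
out_masked = (x.repeat(1, max_loras) * mask) @ wb   # (num_tokens, h_out)

assert torch.allclose(out_gather, out_masked, atol=1e-5)
```

The masked form trades a little extra compute (every token is multiplied against all `max_loras` slots) for a gather-free computation, which tends to be friendlier to HPU graph compilation.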
2 changes: 1 addition & 1 deletion tests/lora/test_multilora_hpu.py
@@ -96,7 +96,7 @@ def _test_llama_multilora(sql_lora_files, tp_size):
enable_lora=True,
max_loras=2,
max_lora_rank=8,
max_num_seqs=16,
max_num_seqs=256,
dtype='float32',
tensor_parallel_size=tp_size)
engine = LLMEngine.from_engine_args(engine_args)
60 changes: 28 additions & 32 deletions vllm/hpu/ops.py
@@ -215,22 +215,24 @@ def dispatch_bgmv_linear(
):
"""
`wa_t_all` and `wb_t_all` contain all LoRA A and LoRA B weight matrices
stacked into single tensors, assuming same rank. HPU handles no-LoRA
requests using zero valued A and B tensors. These zero valued tensors are
appended at the end of `wa_t_all` and `wb_t_all` during initialization.
We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
have a loraMask of shape [batch_size, num_layers * lora_rank]
stacked at dimension 0 into single tensors, assuming same rank. `wa` is the
reshaped and transposed version of `wa_t_all` of shape
(h_in, max_loras * lora_rank) and `wb` is the transposed and reshaped
version of `wb_t_all` of shape (max_loras * lora_rank, h_out).
Matmul input `x` with `wa`, then multiply the result with a mask to zero out
the entries belonging to inactive LoRA indices. Matmul the masked output with
`wb` and scale it to get the final output.
"""

assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
mask = LoraMask.getLoraMask()

wa = wa_t_all[:, 0, :, :]
wb = wb_t_all[:, 0, :, :].transpose(1, 2)
wa_shape = wa.shape
wb_shape = wb.shape
wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
wa = wa.reshape(wa.shape[0] * wa.shape[1], wa.shape[2]).transpose(0, 1)
wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])

out = x @ wa
assert (out.shape == mask.shape)
out = out * mask
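
For reference, the flow the updated docstring describes can be written as a standalone sketch; the `LoraMask` singleton is replaced here by an explicit `mask` argument, and the shape comments assume the stacked-weight layout implied by the indexing above (an illustration, not the module's API):

```python
import torch

def masked_bgmv_linear(y, x, wa_t_all, wb_t_all, mask, scale):
    # Assumed layouts: wa_t_all (max_loras, 1, rank, h_in), wb_t_all (max_loras, 1, h_out, rank)
    wa = wa_t_all[:, 0, :, :]                   # (max_loras, rank, h_in)
    wb = wb_t_all[:, 0, :, :].transpose(1, 2)   # (max_loras, rank, h_out)
    wa = wa.reshape(wa.shape[0] * wa.shape[1], wa.shape[2]).transpose(0, 1)  # (h_in, max_loras * rank)
    wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])                  # (max_loras * rank, h_out)
    out = (x @ wa) * mask        # the mask zeroes the slots of inactive adapters
    y += (out @ wb) * scale
    return y

# Block mask: ones in the rank-wide block of each token's active adapter.
max_loras, rank, h_in, h_out, n = 2, 4, 16, 8, 3
idx = torch.tensor([1, 0, 1])
mask = torch.nn.functional.one_hot(idx, max_loras).repeat_interleave(rank, dim=1).float()
y = masked_bgmv_linear(torch.zeros(n, h_out), torch.randn(n, h_in),
                       torch.randn(max_loras, 1, rank, h_in),
                       torch.randn(max_loras, 1, h_out, rank), mask, 1.0)
```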
@@ -241,34 +243,28 @@ def dispatch_bgmv_linear(
def dispatch_bgmv_embedding(
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
`wa_t_all` contains all LoRA A weight matrices stacked into a single tensor
assuming same rank. HPU handles no-LoRA requests using zero valued A
tensor. This zero valued tensor is appended at the end of `wa_t_all` during
initialization. For custom BGMV, the corresponding wa for each batch is
created based on the lora_index of the sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The wa tensor for a batch of size batch_Size will have a
shape of (batch_size, num_layers, lora_rank, hidden_dim)
This method avoids for-loop as well as graph breaks.
`wb_t_all` contains all LoRA-B weight matrices stacked at dimension 0 into
a single tensor, assuming same rank. `wb` is the transposed and reshaped
version of `wb_t_all` of shape (num_loras * lora_rank, embedding_dim).
The output of the LoRA-A embedding lookup (tensor `x`) is repeated max_loras
times so that its width matches the first dimension of `wb`. Multiply `x`
with a mask to zero out the entries belonging to inactive LoRA indices.
Matmul the masked output with `wb` and scale it to get the final output.
"""

assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)
max_loras = wb_t_all.size(0)

x = x.unsqueeze(1)
out = x @ wa
out = out.squeeze(1)
x = x.repeat(1, max_loras)
x = x * LoraMask.getLoraMask()
wb = wb_t_all[:, 0, :, :].transpose(1, 2)
wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])
out = x @ wb
y += out * scale
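
The no-LoRA special case removed in the two files below relies on a simple property of this masking: a request with no active adapter gets an all-zero mask row, so its LoRA delta is exactly zero without reserving an extra zero-valued adapter slot. A quick illustration (shapes are made up, not taken from this diff):

```python
import torch

num_tokens, max_loras, rank, h_out = 3, 2, 4, 6
wb = torch.randn(max_loras * rank, h_out)     # flattened LoRA-B, as in dispatch_bgmv_embedding
x = torch.randn(num_tokens, rank).repeat(1, max_loras)

mask = torch.zeros(num_tokens, max_loras * rank)
mask[0, :rank] = 1            # token 0 uses adapter slot 0
mask[1, rank:2 * rank] = 1    # token 1 uses adapter slot 1
                              # token 2 has no LoRA: its row stays all zeros

out = (x * mask) @ wb
assert torch.all(out[2] == 0)  # the no-LoRA token contributes nothing; no sentinel adapter needed
```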
20 changes: 3 additions & 17 deletions vllm/lora/models.py
@@ -24,7 +24,7 @@
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import get_device, is_hpu, is_pin_memory_available
from vllm.utils import get_device, is_pin_memory_available

logger = init_logger(__name__)

@@ -465,25 +465,11 @@ def __init__(

@property
def capacity(self) -> int:
if is_hpu():
# HPU handles no LoRA requests using zero valued A and B tensors.
# These zero valued tensors are appended at the end of A and B,
# making total number of loras to be lora_config.max_cpu_loras + 1.
# This demands the total number of max_cpu_loras to be
# lora_config.max_cpu_loras + 1
return self.lora_config.max_cpu_loras + 1
else:
return self.lora_config.max_cpu_loras
return self.lora_config.max_cpu_loras

@property
def lora_slots(self) -> int:
if is_hpu():
# HPU handles no LoRA requests using zero valued A and B tensors.
# These zero valued tensors are appended at the end of A and B,
# making total number of loras to be lora_config.max_cpu_loras + 1.
return self.lora_config.max_loras + 1
else:
return self.lora_config.max_loras
return self.lora_config.max_loras

@property
def adapter_slots(self) -> int:
13 changes: 6 additions & 7 deletions vllm/worker/habana_model_runner.py
@@ -757,13 +757,12 @@ def _prepare_prompt(
lora_logits_mask: torch.Tensor = None
counter = 0
if self.lora_config:
lora_mask = torch.zeros(len(seq_group_metadata_list) *
max_prompt_len,
(self.lora_config.max_loras + 1) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_mask = torch.zeros(
len(seq_group_metadata_list) * max_prompt_len,
(self.lora_config.max_loras) * self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_logits_mask = torch.zeros(len(seq_group_metadata_list),
(self.lora_config.max_loras + 1) *
(self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)

@@ -887,7 +886,7 @@ def _prepare_decode(

if self.lora_config:
lora_mask = torch.zeros(len(seq_group_metadata_list),
(self.lora_config.max_loras + 1) *
(self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
ones = torch.ones(1,
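
With the sentinel adapter gone, the mask width allocated in this file drops from `(max_loras + 1) * max_lora_rank` to `max_loras * max_lora_rank`: one row per prompt-token position during prefill and one row per sequence during decode, with no-LoRA rows simply left at zero. A rough sketch of those shapes (illustrative values, not the runner's actual variables):

```python
import torch

batch_size, max_prompt_len = 2, 6
max_loras, max_lora_rank = 4, 8
lora_dtype = torch.bfloat16

# Prefill: one mask row per (sequence, token) position.
prompt_lora_mask = torch.zeros(batch_size * max_prompt_len,
                               max_loras * max_lora_rank,
                               dtype=lora_dtype)

# Decode: one mask row per sequence.
decode_lora_mask = torch.zeros(batch_size,
                               max_loras * max_lora_rank,
                               dtype=lora_dtype)

# A sequence assigned LoRA slot `slot` gets ones in its rank-wide block;
# a sequence with no LoRA leaves its row untouched (all zeros).
slot = 1
decode_lora_mask[0, slot * max_lora_rank:(slot + 1) * max_lora_rank] = 1.0
```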
