Update LoRA function names to align with vllm.lora.punica (#51)
SanjuCSudhakaran authored Dec 10, 2024
1 parent e096d6f commit 4312768
Showing 2 changed files with 17 additions and 33 deletions.
3 changes: 1 addition & 2 deletions vllm_hpu_extension/ops.py
@@ -322,7 +322,6 @@ def dispatch_bgmv_embedding(
     x: torch.Tensor,
     wb_t_all: torch.Tensor,
     layer_idx: int,
-    scale: float,
 ):
     """
     `wb_t_all` contains all LoRA-B weight matrices stacked at dimension 0 into
@@ -343,7 +342,7 @@ def dispatch_bgmv_embedding(
     wb = wb_t_all[:, 0, :, :].transpose(1, 2)
     wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])
     out = x @ wb
-    y += out * scale
+    y += out


 class MoeMatmul(torch.nn.Module):
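
The expand step edited above flattens the stacked LoRA-B tensor so a single matmul covers every adapter slot. Below is a minimal standalone PyTorch sketch of that reshape, not repository code; the [max_loras, 1, hidden, rank] layout and all sizes are illustrative assumptions, only the indexing pattern mirrors the diff:

import torch

max_loras, hidden, rank = 4, 16, 8
# Stacked LoRA-B weights, one [hidden, rank] block per adapter slot (assumed layout).
wb_t_all = torch.randn(max_loras, 1, hidden, rank)
# Token activations assumed to already be laid out as max_loras * rank features.
x = torch.randn(2, max_loras * rank)
y = torch.zeros(2, hidden)

wb = wb_t_all[:, 0, :, :].transpose(1, 2)                # [max_loras, rank, hidden]
wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])  # [max_loras * rank, hidden]
y += x @ wb  # accumulate into y without the scale factor this commit removes
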
47 changes: 16 additions & 31 deletions vllm_hpu_extension/punica_hpu.py
@@ -20,31 +20,16 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                  device: str):
         super().__init__(max_num_batched_tokens, max_batches, device)

-    def add_lora(self,
-                 y: torch.Tensor,
-                 x: torch.Tensor,
-                 wa_t_all: torch.Tensor,
-                 wb_t_all: torch.Tensor,
-                 scale: float,
-                 y_offset: Optional[int] = None,
-                 y_slice_size: Optional[int] = None,
-                 *,
-                 buffer: Optional[torch.Tensor] = None) -> None:
-        y_org = y
-        x = x.view(-1, x.shape[-1])
-        y = y.view(-1, y.shape[-1])
-        dispatch_bgmv_linear(y, x, wa_t_all, wb_t_all, 0, 1.0)
-        y = y.view_as(y_org)
-
-    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
-                               lora_a_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               lora_b_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               scale: float,
-                               output_slices: Tuple[int, ...]) -> None:
+    def add_lora_linear(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: Tuple[torch.Tensor, ...],
+                        lora_b_stacked: Tuple[torch.Tensor, ...],
+                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+                        scale: float,
+                        output_slices: Tuple[int, ...],
+                        *,
+                        buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
         y_org = y
         x = x.view(-1, x.shape[-1])
         y = y.view(-1, y.shape[-1])
@@ -53,29 +38,29 @@ def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
         for slice_idx in range(len(output_slices)):
             dispatch_bgmv_linear(
                 y[:, offset_left:offset_left + output_slices[slice_idx]], x,
-                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, 1.0)
+                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
             offset_left += output_slices[slice_idx]
         y = y.view_as(y_org)

     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
-                        wa_t_all: torch.Tensor,
-                        wb_t_all: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
                         scale,
                         *,
                         buffer: Optional[torch.Tensor] = None) -> None:
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-        dispatch_bgmv_linear(y, x, wa_t_all, wb_t_all, 0, 1.0)
+        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
         y = y.view_as(y_org)

     def add_lora_embedding(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        w_t_all: torch.Tensor,
+        lora_b_stacked: torch.Tensor,
         add_input: bool = True,
     ):
-        dispatch_bgmv_embedding(y, x, w_t_all, 0, 1.0)
+        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
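
For orientation, the per-slice loop that add_lora_linear keeps from add_lora_packed_nslice can be reproduced with plain matmuls. The sketch below is not repository code: the shapes, tuple contents, and the matmul stand-in for dispatch_bgmv_linear are assumptions for illustration only.

import torch

# One fused projection (e.g. QKV) whose output is split into slices; each slice
# gets its own LoRA-A/LoRA-B pair applied into the matching columns of y.
hidden, rank = 16, 4
output_slices = (16, 16, 16)            # assumed per-slice output widths
x = torch.randn(2, hidden)              # flattened token activations
y = torch.zeros(2, sum(output_slices))  # base-layer output, updated in place

lora_a_stacked = tuple(torch.randn(rank, hidden) for _ in output_slices)
lora_b_stacked = tuple(torch.randn(width, rank) for width in output_slices)
scale = 1.0

offset_left = 0
for slice_idx, width in enumerate(output_slices):
    # Stand-in for dispatch_bgmv_linear: shrink with LoRA-A, expand with LoRA-B.
    delta = (x @ lora_a_stacked[slice_idx].t()) @ lora_b_stacked[slice_idx].t()
    y[:, offset_left:offset_left + width] += scale * delta
    offset_left += width
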
