From 431276837a1b8fb8878abbcb3e971ca1c43f0b82 Mon Sep 17 00:00:00 2001
From: Sanju C Sudhakaran
Date: Tue, 10 Dec 2024 09:21:06 +0530
Subject: [PATCH] Update LoRA function names to align with vllm.lora.punica
 (#51)

---
 vllm_hpu_extension/ops.py        |  3 +-
 vllm_hpu_extension/punica_hpu.py | 47 +++++++++++---------------------
 2 files changed, 17 insertions(+), 33 deletions(-)

diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index 36c52ea8..eb4cbded 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -322,7 +322,6 @@ def dispatch_bgmv_embedding(
     x: torch.Tensor,
     wb_t_all: torch.Tensor,
     layer_idx: int,
-    scale: float,
 ):
     """
     `wb_t_all` contains all LoRA-B weight matrices stacked at dimension 0 into
@@ -343,7 +342,7 @@ def dispatch_bgmv_embedding(
     wb = wb_t_all[:, 0, :, :].transpose(1, 2)
     wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])
     out = x @ wb
-    y += out * scale
+    y += out
 
 
 class MoeMatmul(torch.nn.Module):
diff --git a/vllm_hpu_extension/punica_hpu.py b/vllm_hpu_extension/punica_hpu.py
index 9b726156..9bf214e1 100644
--- a/vllm_hpu_extension/punica_hpu.py
+++ b/vllm_hpu_extension/punica_hpu.py
@@ -20,31 +20,16 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                  device: str):
         super().__init__(max_num_batched_tokens, max_batches, device)
 
-    def add_lora(self,
-                 y: torch.Tensor,
-                 x: torch.Tensor,
-                 wa_t_all: torch.Tensor,
-                 wb_t_all: torch.Tensor,
-                 scale: float,
-                 y_offset: Optional[int] = None,
-                 y_slice_size: Optional[int] = None,
-                 *,
-                 buffer: Optional[torch.Tensor] = None) -> None:
-        y_org = y
-        x = x.view(-1, x.shape[-1])
-        y = y.view(-1, y.shape[-1])
-        dispatch_bgmv_linear(y, x, wa_t_all, wb_t_all, 0, 1.0)
-        y = y.view_as(y_org)
-
-    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
-                               lora_a_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               lora_b_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               scale: float,
-                               output_slices: Tuple[int, ...]) -> None:
+    def add_lora_linear(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: Tuple[torch.Tensor, ...],
+                        lora_b_stacked: Tuple[torch.Tensor, ...],
+                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+                        scale: float,
+                        output_slices: Tuple[int, ...],
+                        *,
+                        buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
         y_org = y
         x = x.view(-1, x.shape[-1])
         y = y.view(-1, y.shape[-1])
@@ -53,29 +38,29 @@ def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
         for slice_idx in range(len(output_slices)):
             dispatch_bgmv_linear(
                 y[:, offset_left:offset_left + output_slices[slice_idx]], x,
-                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, 1.0)
+                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
             offset_left += output_slices[slice_idx]
         y = y.view_as(y_org)
 
     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
-                        wa_t_all: torch.Tensor,
-                        wb_t_all: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
                         scale,
                         *,
                         buffer: Optional[torch.Tensor] = None) -> None:
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-        dispatch_bgmv_linear(y, x, wa_t_all, wb_t_all, 0, 1.0)
+        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
         y = y.view_as(y_org)
 
     def add_lora_embedding(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        w_t_all: torch.Tensor,
+        lora_b_stacked: torch.Tensor,
         add_input: bool = True,
     ):
-        dispatch_bgmv_embedding(y, x, w_t_all, 0, 1.0)
+        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
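
Note for reviewers (not part of the patch): the snippet below is a plain-PyTorch sketch of the per-slice computation that the renamed add_lora_linear drives through dispatch_bgmv_linear, mainly to show where the now-forwarded scale enters. The helper name add_lora_linear_sketch, the list-of-tensors weight layout, and the example shapes are illustrative assumptions; the extension's real code operates on the stacked HPU tensor layout shown in the diff.

import torch

def add_lora_linear_sketch(y, x, lora_a, lora_b, scale, output_slices):
    # For each packed output slice i: y[:, off:off+n_i] += scale * (x @ A_i^T @ B_i^T)
    # y: (tokens, sum(output_slices)), x: (tokens, hidden)
    # lora_a[i]: (rank, hidden), lora_b[i]: (output_slices[i], rank)  -- illustrative layout
    offset = 0
    for i, n in enumerate(output_slices):
        shrunk = x @ lora_a[i].t()                                    # (tokens, rank)
        y[:, offset:offset + n] += scale * (shrunk @ lora_b[i].t())   # (tokens, n)
        offset += n

tokens, hidden, rank = 4, 16, 8
output_slices = (16, 16, 32)                                  # e.g. packed projection slices
x = torch.randn(tokens, hidden)
y = torch.zeros(tokens, sum(output_slices))
lora_a = [torch.randn(rank, hidden) for _ in output_slices]   # one LoRA-A per slice
lora_b = [torch.randn(n, rank) for n in output_slices]        # one LoRA-B per slice
add_lora_linear_sketch(y, x, lora_a, lora_b, scale=1.0, output_slices=output_slices)
print(y.shape)  # torch.Size([4, 64])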