vLLM-Base: Full enabling of ALiBi
Changes:
- Added ALiBi biases back to the decode stage.
- Optimized ALiBi memory usage.
  - Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow
    large models to run with restricted prompt lengths.
  - Prompt biases are instantiated once rather than on each forward pass.
  - Prompt and decode biases are shared across encoder/decoder layers.
- Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve an
  accuracy issue on long sequences.
- Works in both lazy and eager mode.
- ALiBi is restricted to "VLLM_PROMPT_USE_FUSEDSDPA=false" and
  "VLLM_CONTIGUOUS_PA=true" (see the configuration sketch below).
- NTT patch for GQA.

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
tannervoas742 and xhaihao committed Jan 14, 2025
1 parent eb0d42f commit 787d66c
Showing 5 changed files with 337 additions and 43 deletions.
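
The new environment variables work together with two existing switches. As a rough usage illustration (not part of this commit), the sketch below shows one way to configure them before constructing an engine; the flag names come from the diff, while the model name and the 4096 cap are placeholder values.

import os

# The ALiBi path in this commit requires non-fused prompt attention and
# contiguous paged attention.
os.environ["VLLM_PROMPT_USE_FUSEDSDPA"] = "false"
os.environ["VLLM_CONTIGUOUS_PA"] = "true"

# Keep ALiBi biases in float32 to avoid accuracy drift on long sequences
# (the diff below defaults this to enabled).
os.environ["VLLM_ALIBI_USE_FLOAT32_BIASES"] = "1"

# Cap the precomputed prompt bias so a large ALiBi model still fits in
# memory when prompts are restricted to a shorter length (placeholder value).
os.environ["VLLM_PROMPT_ALIBI_MAX_SEQ_LEN"] = "4096"

from vllm import LLM  # assumes the HabanaAI vLLM fork from this repository

llm = LLM(model="bigscience/bloom-7b1")  # example ALiBi checkpoint; any ALiBi model applies
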
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0766759
177 changes: 154 additions & 23 deletions vllm/attention/backends/hpu_attn.py
@@ -10,7 +10,7 @@
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)

from vllm.distributed import get_tensor_model_parallel_rank
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
@@ -124,7 +124,7 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
logits_soft_cap: Optional[float] = None,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
@@ -142,11 +142,20 @@ def __init__(
else ModuleFusedSDPA(HPUFusedSDPA)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
self.prompt_position_bias = None
self.tp_rank = get_tensor_model_parallel_rank()
self.prev_attn = None
self.alibi_slopes = None
if alibi_slopes is not None:
slope_tensor_dtype = {
True: torch.float32,
False: torch.bfloat16,
}[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower()
in ['1', 'true']]
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
dtype=slope_tensor_dtype)
self.alibi_slopes = alibi_slopes_tensor

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

@@ -157,12 +166,49 @@ def __init__(
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
if not self.use_contiguous_pa:
assert alibi_slopes is None, \
'Non-contiguous PA not supported with alibi slopes!'

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

def _maybe_init_alibi_biases(
self,
max_seq_len: int = 4096,
prev_attn: Optional[torch.nn.Module] = None,
) -> None:
# Set upper bound on sequence length
max_seq_len_upper = int(
os.getenv(
'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN',
max_seq_len,
))
# Set lower bound on sequence length
self.max_seq_len = max([
max_seq_len_upper,
int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')),
])
self.prev_attn = None if prev_attn is None else prev_attn.impl
if self.alibi_slopes is not None:
if (self.prev_attn is not None
and self.prev_attn.tp_rank == self.tp_rank):
self.alibi_slopes = self.prev_attn.alibi_slopes
self.prompt_position_bias = self.prev_attn.prompt_position_bias
else:
# Creating the prompt_position_bias once and reusing it
# if seq_len permits.
self.prompt_position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=self.max_seq_len,
dtype=self.alibi_slopes.dtype,
)

def forward(
self,
query: torch.Tensor,
@@ -230,27 +276,42 @@ def forward(
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)

if attn_metadata is None or attn_metadata.block_list is None:
if not self.prefill_use_fusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
# If we have alibi_slopes, incorporate them via the position bias.
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
position_bias = None
if (self.prompt_position_bias is not None
and self.alibi_slopes is not None):
if self.max_seq_len >= max(attn_bias.size(-2),
attn_bias.size(-1)):
# Using pre-computed prompt_position_bias subset.
position_bias = self.prompt_position_bias[:, :,
-attn_bias.size(-2):,
-attn_bias.size(-1):]
else:
# For longer sequences than precomputed,
# recreate the bias. This is memory inefficient.
position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=max(attn_bias.size(-2),
attn_bias.size(-1)),
dtype=self.alibi_slopes.dtype,
)
else:
attn_bias = None
position_bias = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
position_bias=position_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
@@ -278,6 +339,20 @@ def forward(
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
self.position_bias = None
alibi_blocks = attn_metadata.alibi_blocks
if self.alibi_slopes is not None and alibi_blocks is not None:
if (self.prev_attn is not None
and self.prev_attn.tp_rank == self.tp_rank):
self.position_bias = self.prev_attn.position_bias
else:
# For decoding, compute position bias using alibi_blocks.
self.position_bias = _make_decode_alibi_bias(
alibi_blocks=alibi_blocks,
alibi_slopes=self.alibi_slopes,
dtype=self.alibi_slopes.dtype,
)

output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
@@ -288,14 +363,18 @@ def forward(
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
position_bias=self.position_bias,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
values_fetch_func=self.v_cache.fetch_from_cache,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
output = output.view(batch_size, seq_len, hidden_size)
return output

def forward_encoder_decoder(
self,
@@ -409,12 +488,25 @@ def forward_encoder_decoder(
return output.view(batch_size, -1, hidden_size)


def _make_alibi_bias(
def _make_prompt_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_len: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create the ALiBi position bias tensor for the prompt stage.
This tensor is reused or tiled as needed for each forward pass.
Does not scale with batch size or number of blocks.
Args:
alibi_slopes: shape = [num_heads]
seq_len: int
dtype: torch.dtype
Returns:
A per-head bias tensor of shape [1, num_heads, seq_len, seq_len].
This bias encodes positional information via ALiBi slopes.
"""
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
@@ -427,15 +519,54 @@ def _make_alibi_bias(

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
per_head_bias = torch.empty(
1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
)[:, :, :, :seq_len]
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias
# to all 32 heads in Eager mode.
per_head_bias[:, :] = bias
per_head_bias.mul_(alibi_slopes[:, None, None])

return per_head_bias


def _make_decode_alibi_bias(
alibi_blocks: torch.Tensor,
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create the ALiBi position bias tensor for the decode stage.
Uses stored alibi_blocks and slopes for final scaling.
Scales with number of blocks, not with batch size.
Args:
alibi_blocks: shape = [num_blocks, block_size]
alibi_slopes: shape = [num_heads]
dtype: torch.dtype
Returns:
A per-head bias tensor of shape [num_blocks, num_heads, block_size].
Each row encodes position-dependent ALiBi slopes for decoding steps.
"""
num_heads = alibi_slopes.shape[0]
per_head_bias = torch.empty(
alibi_blocks.size(0),
num_heads,
alibi_blocks.size(-1),
device=alibi_slopes.device,
dtype=dtype,
)
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias
# to all 32 heads in Eager mode.
per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
per_head_bias.mul_(alibi_slopes[None, :, None])

return per_head_bias
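
The two helper functions above are easiest to follow by their output shapes. The following standalone sketch (not part of the commit) mimics their broadcasting with toy sizes; the slope values are arbitrary, and the alibi_blocks contents are placeholders, since the real per-block offsets are filled in by the model runner.

import torch

num_heads, seq_len = 4, 6
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(num_heads)])

# Prompt stage: a relative-position matrix (j - i), one (seq_len, seq_len)
# map per head, scaled by that head's slope.
rel = torch.arange(seq_len, dtype=torch.float32)
rel = rel[None, :] - rel[:, None]
prompt_bias = rel[None, None, :, :] * slopes[:, None, None]
print(prompt_bias.shape)                 # torch.Size([1, 4, 6, 6])

# forward() reuses a slice of the precomputed bias when the current prompt
# is shorter than max_seq_len.
print(prompt_bias[:, :, -4:, -4:].shape) # torch.Size([1, 4, 4, 4])

# Decode stage: per-block position offsets broadcast across heads and
# scaled by the slopes, giving [num_blocks, num_heads, block_size].
num_blocks, block_size = 3, 4
alibi_blocks = torch.arange(block_size, dtype=torch.float32).repeat(num_blocks, 1)
decode_bias = alibi_blocks[:, None, :] * slopes[None, :, None]
print(decode_bias.shape)                 # torch.Size([3, 4, 4])
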
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
alibi_blocks: Optional[torch.Tensor]


class HPUPagedAttention:
3 changes: 2 additions & 1 deletion vllm/worker/hpu_enc_dec_model_runner.py
@@ -411,7 +411,8 @@ def create_dummy_seq_group_metadata(self,
seq_len,
is_prompt,
lora_request=None,
temperature=0):
temperature=0,
last_block_assigned=0):
sampling_params = SamplingParams(temperature=0)
num_blocks = math.ceil(seq_len / self.block_size)
cross_block_table: Optional[List[int]] = None
(Diff for the remaining changed file is not shown here.)
