vLLM-Base: Resolved ALiBi bias regression

Changes:
- Optimized ALiBi memory usage.
  - Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow
    large models to run with restricted prompt lengths.
- Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve
  accuracy issue on long sequences.
- Updated jais, mpt, falcon, baichuan, and bloom to work with ALiBi.
  - Due to bloom's 176B-parameter size, I was unable to test this model;
    its changes are the simplest, though.
- Works in both lazy and eager modes.
- ALiBi currently requires "VLLM_PROMPT_USE_FUSEDSDPA=false",
  "VLLM_CONTIGUOUS_PA=false", and "VLLM_PA_SOFTMAX_IMPL=wsum_head_amax";
  see the example configuration sketch below.

Remaining TODO:
- Resolve the quality issue when running prompts of significantly
  different lengths.
- Resolve the issue with contiguous PA.
- Integrate support for GQA alongside MHA.

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
tannervoas742 and xhaihao committed Nov 27, 2024
1 parent b7d75b8 commit 3c3e18a
Showing 9 changed files with 248 additions and 68 deletions.
100 changes: 82 additions & 18 deletions vllm/attention/backends/hpu_attn.py
@@ -119,11 +119,25 @@ def __init__(
self.v_cache = VLLMKVCache()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
self.alibi_slopes = None
self.prompt_position_bias = None
self.max_seq_len = max_seq_len
if alibi_slopes is not None:
self.max_seq_len = int(os.getenv('VLLM_PROMPT_ALIBI_MAX_SEQ_LEN',
self.max_seq_len))
slope_tensor_dtype = {
True: torch.float32,
False: torch.bfloat16,
}[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '0').lower() in ['1', 'true']]
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
dtype=slope_tensor_dtype)
self.alibi_slopes = alibi_slopes_tensor
self.prompt_position_bias = _make_prompt_alibi_bias(
self.alibi_slopes,
self.num_kv_heads,
self.alibi_slopes.dtype,
self.max_seq_len,
)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

@@ -134,6 +148,16 @@ def __init__(
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
if self.use_contiguous_pa:
assert alibi_slopes is None, \
'Contiguous PA not supported with alibi slopes!'

if ops.ACTUAL_PA_SOFTMAX_IMPL != 'wsum_head_amax':
assert alibi_slopes is None, \
'Alibi slopes supports only "wsum_head_amax" softmax implementation!'

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
@@ -201,21 +225,19 @@ def forward(
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
position_bias = self.prompt_position_bias
if position_bias is not None:
position_bias = position_bias[:, :, -attn_bias.size(-2):, -attn_bias.size(-1):]
else:
attn_bias = None
position_bias = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
position_bias=position_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
@@ -242,6 +264,17 @@ def forward(
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
position_bias = None
attn_bias = attn_metadata.attn_bias
alibi_blocks = attn_metadata.alibi_blocks
if self.alibi_slopes is not None and alibi_blocks is not None:
position_bias = _make_decode_alibi_bias(
alibi_blocks,
self.alibi_slopes,
self.num_kv_heads,
self.alibi_slopes.dtype,
)

output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
@@ -252,17 +285,21 @@ def forward(
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
position_bias=position_bias,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
values_fetch_func=self.v_cache.fetch_from_cache,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
output = output.view(batch_size, seq_len, hidden_size)
return output


def _make_alibi_bias(
def _make_prompt_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
@@ -280,15 +317,42 @@ def _make_alibi_bias(

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
per_head_bias = torch.empty(
1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
)[:, :, :, :seq_len]
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias to all 32 heads in Eager mode.
per_head_bias[:, :] = bias
per_head_bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
per_head_bias = per_head_bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))

return per_head_bias


def _make_decode_alibi_bias(
alibi_blocks: torch.Tensor,
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
) -> torch.Tensor:
num_heads = alibi_slopes.shape[0]
per_head_bias = torch.empty(
alibi_blocks.size(0), # num blocks
num_heads,
alibi_blocks.size(-1),
device=alibi_slopes.device,
dtype=dtype,
)
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias to all 32 heads in Eager mode.
per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
per_head_bias.mul_(alibi_slopes[None, :, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
per_head_bias = per_head_bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return per_head_bias
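
For background on what _make_prompt_alibi_bias constructs, below is a minimal,
standalone illustration of the standard ALiBi scheme (Press et al.) in plain
PyTorch. It is a sketch of the general technique under simplified assumptions,
not the HPU implementation above: padding, batch/block dimensions, and the
num_kv_heads unflatten step are omitted.

```python
# Standalone ALiBi sketch: per-head slopes form a geometric sequence and scale
# the (negative) key-query distance, producing an additive attention bias.
import torch


def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Slopes 2^(-8/n), 2^(-16/n), ... for n heads (power-of-two head counts;
    # the general case interleaves a second series).
    start = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])


def prompt_alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    slopes = alibi_slopes(num_heads)                   # [H]
    pos = torch.arange(seq_len)
    distance = (pos[None, :] - pos[:, None]).float()   # [S, S]: key_pos - query_pos
    bias = distance.clamp(max=0.0)                     # keep only the causal (<= 0) part
    return slopes[:, None, None] * bias                # [H, S, S], added to Q @ K^T logits


if __name__ == "__main__":
    print(prompt_alibi_bias(num_heads=8, seq_len=5)[0])
```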
2 changes: 1 addition & 1 deletion vllm/attention/layer.py
@@ -34,7 +34,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
logits_soft_cap: Optional[int] = 4096,
prefix: str = "",
) -> None:
super().__init__()
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
alibi_blocks: Optional[torch.Tensor]


class HPUPagedAttention:
11 changes: 7 additions & 4 deletions vllm/model_executor/models/baichuan.py
@@ -126,7 +126,7 @@ def __init__(
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
self.position_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

@@ -146,7 +146,7 @@ def __init__(
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
if self.position_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
@@ -158,7 +158,9 @@ def __init__(
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
quant_config=quant_config)
quant_config=quant_config,
logits_soft_cap=self.max_position_embeddings,
)
else:
self.rotary_emb = get_rope(
self.head_dim,
@@ -182,7 +184,7 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
if self.position_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
@@ -357,6 +359,7 @@ def __init__(
self.lora_config = lora_config

self.quant_config = quant_config
self.use_alibi = position_embedding == "ALIBI"
self.model = BaiChuanModel(vllm_config=vllm_config,
prefix=prefix,
position_embedding=position_embedding)
1 change: 1 addition & 0 deletions vllm/model_executor/models/bloom.py
@@ -293,6 +293,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.use_alibi = True
self.transformer = BloomModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
6 changes: 4 additions & 2 deletions vllm/model_executor/models/falcon.py
@@ -158,7 +158,9 @@ def __init__(
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
quant_config=quant_config,
logits_soft_cap=max_position_embeddings,
)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
@@ -346,7 +348,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_alibi = config.alibi

# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
Expand Down Expand Up @@ -417,6 +418,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.use_alibi = config.alibi
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
5 changes: 4 additions & 1 deletion vllm/model_executor/models/jais.py
@@ -114,7 +114,9 @@ def __init__(
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
logits_soft_cap=config.max_position_embeddings,
)

def forward(
self,
@@ -297,6 +299,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.use_alibi = config.position_embedding_type == "alibi"
self.transformer = JAISModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
7 changes: 5 additions & 2 deletions vllm/model_executor/models/mpt.py
@@ -58,6 +58,7 @@ def __init__(
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
self.max_seq_len = config.max_seq_len
if "kv_n_heads" in config.attn_config:
self.total_num_kv_heads = config.attn_config['kv_n_heads']
else:
@@ -115,7 +116,9 @@ def __init__(
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
logits_soft_cap=self.max_seq_len,
)

def forward(
self,
@@ -281,7 +284,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
assert config.tie_word_embeddings
self.quant_config = quant_config

self.use_alibi = config.attn_config['alibi']
self.transformer = MPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
self.lm_head = self.transformer.wte