From 787d66c4b9c6684e2b2a996ed5cc8f062f0a563b Mon Sep 17 00:00:00 2001 From: Tanner Voas Date: Thu, 14 Nov 2024 08:21:50 +0000 Subject: [PATCH] vLLM-Base: Full enabling of ALiBi Changes: - Added back alibi biases to decode stage. - Optimized ALiBI memory usage. - Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow large models to run with restricted prompt lengths. - Prompt biases instantiated once rather than each forward. - Prompt and decode biases are shared across encoder/decoder layers. - Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve accuracy issue on long sequences. - Works in lazy and eager mode. - ALiBI is restricted to "VLLM_PROMPT_USE_FUSEDSDPA=false", and "VLLM_CONTIGUOUS_PA=true". - NTT patch for GQA Co-authored-by: Tanner Voas Co-authored-by: Haihao Xiang Signed-off-by: Tanner Voas --- requirements-hpu.txt | 2 +- vllm/attention/backends/hpu_attn.py | 177 ++++++++++++++++++--- vllm/attention/ops/hpu_paged_attn.py | 1 + vllm/worker/hpu_enc_dec_model_runner.py | 3 +- vllm/worker/hpu_model_runner.py | 197 +++++++++++++++++++++--- 5 files changed, 337 insertions(+), 43 deletions(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index f4fb89ef42834..3f4cf33f105d6 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0766759 diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 96dafe8c2fcb1..0c50a559784e8 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -10,7 +10,7 @@ import vllm_hpu_extension.ops as ops from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, VLLMKVCache) - +from vllm.distributed import get_tensor_model_parallel_rank from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import CommonAttentionState @@ -124,7 +124,7 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, - max_seq_len: int = 4096, + logits_soft_cap: Optional[float] = None, ) -> None: super(AttentionImpl, self).__init__() self.kv_cache_dtype = kv_cache_dtype @@ -142,11 +142,20 @@ def __init__( else ModuleFusedSDPA(HPUFusedSDPA) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window - self.alibi_slopes = alibi_slopes + self.prompt_position_bias = None + self.tp_rank = get_tensor_model_parallel_rank() + self.prev_attn = None + self.alibi_slopes = None if alibi_slopes is not None: + slope_tensor_dtype = { + True: torch.float32, + False: torch.bfloat16, + }[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower() + in ['1', 'true']] alibi_slopes_tensor = torch.tensor(alibi_slopes, - dtype=torch.bfloat16) + dtype=slope_tensor_dtype) self.alibi_slopes = alibi_slopes_tensor + assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -157,12 +166,49 @@ def __init__( assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' + self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', + 'true').lower() == 'true' + if not self.use_contiguous_pa: + assert alibi_slopes is None, \ + 'Non-contiguous PA not supported with alibi slopes!' + suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") + def _maybe_init_alibi_biases( + self, + max_seq_len: int = 4096, + prev_attn: Optional[torch.nn.Module] = None, + ) -> None: + # Set upper bound on sequence length + max_seq_len_upper = int( + os.getenv( + 'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN', + max_seq_len, + )) + # Set lower bound on sequence length + self.max_seq_len = max([ + max_seq_len_upper, + int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')), + ]) + self.prev_attn = None if prev_attn is None else prev_attn.impl + if self.alibi_slopes is not None: + if (self.prev_attn is not None + and self.prev_attn.tp_rank == self.tp_rank): + self.alibi_slopes = self.prev_attn.alibi_slopes + self.prompt_position_bias = self.prev_attn.prompt_position_bias + else: + # Creating the prompt_position_bias once and reusing it + # if seq_len permits. + self.prompt_position_bias = _make_prompt_alibi_bias( + alibi_slopes=self.alibi_slopes, + seq_len=self.max_seq_len, + dtype=self.alibi_slopes.dtype, + ) + def forward( self, query: torch.Tensor, @@ -230,27 +276,42 @@ def forward( query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) + if attn_metadata is None or attn_metadata.block_list is None: if not self.prefill_use_fusedsdpa: # TODO: move this outside of model assert attn_metadata.attn_bias is not None, \ 'attn_bias must be set before calling model.forward' + # If we have alibi_slopes, incorporate them with attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None: - position_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, - attn_bias.dtype, attn_bias.shape[-1]) - attn_bias = attn_bias.tile( - (1, self.num_kv_heads, 1, 1)) - attn_bias.add_(position_bias) + position_bias = None + if (self.prompt_position_bias is not None + and self.alibi_slopes is not None): + if self.max_seq_len >= max(attn_bias.size(-2), + attn_bias.size(-1)): + # Using pre-computed prompt_position_bias subset. + position_bias = self.prompt_position_bias[:, :, + -attn_bias.size(-2):, + -attn_bias.size(-1):] + else: + # For longer sequences than precomputed, + # recreate the bias. This is memory inefficient. + position_bias = _make_prompt_alibi_bias( + alibi_slopes=self.alibi_slopes, + seq_len=max(attn_bias.size(-2), + attn_bias.size(-1)), + dtype=self.alibi_slopes.dtype, + ) else: attn_bias = None + position_bias = None out = ops.prompt_attention( query.view(query_shape), key.view(kv_shape), value.view(kv_shape), attn_bias=attn_bias, + position_bias=position_bias, p=0.0, scale=self.scale, matmul_qk_op=self.matmul_qk, @@ -278,6 +339,20 @@ def forward( output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. + self.position_bias = None + alibi_blocks = attn_metadata.alibi_blocks + if self.alibi_slopes is not None and alibi_blocks is not None: + if (self.prev_attn is not None + and self.prev_attn.tp_rank == self.tp_rank): + self.position_bias = self.prev_attn.position_bias + else: + # For decoding, compute position bias using alibi_blocks. + self.position_bias = _make_decode_alibi_bias( + alibi_blocks=alibi_blocks, + alibi_slopes=self.alibi_slopes, + dtype=self.alibi_slopes.dtype, + ) + output = HPUPagedAttention.forward_decode( query=query, key_cache=key_cache, @@ -288,14 +363,18 @@ def forward( block_scales=attn_metadata.block_scales, block_groups=attn_metadata.block_groups, scale=self.scale, + position_bias=self.position_bias, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, batch2block_matmul_op=self.batch2block_matmul, block2batch_matmul_op=self.block2batch_matmul, keys_fetch_func=self.k_cache.fetch_from_cache, - values_fetch_func=self.v_cache.fetch_from_cache) + values_fetch_func=self.v_cache.fetch_from_cache, + ) + # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + output = output.view(batch_size, seq_len, hidden_size) + return output def forward_encoder_decoder( self, @@ -409,12 +488,25 @@ def forward_encoder_decoder( return output.view(batch_size, -1, hidden_size) -def _make_alibi_bias( +def _make_prompt_alibi_bias( alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, seq_len: int, + dtype: torch.dtype, ) -> torch.Tensor: + """ + Create the ALiBi position bias tensor for prompt stage. + This tensor is reused or tiled as needed for each forward pass. + Does not scale with batch size or number of blocks. + + Args: + alibi_slopes: shape = [num_heads] + seq_len: int + dtype: torch.dtype + + Returns: + A per-head bias tensor of shape [1, num_heads, seq_len, seq_len]. + This bias encodes positional information via ALiBi slopes. + """ bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(seq_len, 1)` @@ -427,15 +519,54 @@ def _make_alibi_bias( padded_len = (seq_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] - bias = torch.empty( - 1, # batch size + per_head_bias = torch.empty( + 1, num_heads, seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - return bias + )[:, :, :, :seq_len] + # NOTE(Tanner): + # .copy_ was not performing broadcasting of bias + # to all 32 heads in Eager mode. + per_head_bias[:, :] = bias + per_head_bias.mul_(alibi_slopes[:, None, None]) + + return per_head_bias + + +def _make_decode_alibi_bias( + alibi_blocks: torch.Tensor, + alibi_slopes: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Create the ALiBi position bias tensor for decode stage. + Uses stored alibi_blocks and slopes for final scaling. + Scales with number of blocks, not with batch size. + + Args: + alibi_blocks: shape = [num_blocks, block_size] + alibi_slopes: shape = [num_heads] + dtype: torch.dtype + + Returns: + A per-head bias tensor of shape [num_blocks, num_heads, block_size]. + Each row encodes position-dependent ALiBi slopes for decoding steps. + """ + num_heads = alibi_slopes.shape[0] + per_head_bias = torch.empty( + alibi_blocks.size(0), + num_heads, + alibi_blocks.size(-1), + device=alibi_slopes.device, + dtype=dtype, + ) + # NOTE(Tanner): + # .copy_ was not performing broadcasting of bias + # to all 32 heads in Eager mode. + per_head_bias[:, :] = alibi_blocks.unsqueeze(-2) + per_head_bias.mul_(alibi_slopes[None, :, None]) + + return per_head_bias diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index e55a4de11fd6c..d1235e6ec7aa7 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata: block_offsets: Optional[torch.Tensor] block_scales: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] + alibi_blocks: Optional[torch.Tensor] class HPUPagedAttention: diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py index 2951a4db2e478..9d24327d57ecc 100644 --- a/vllm/worker/hpu_enc_dec_model_runner.py +++ b/vllm/worker/hpu_enc_dec_model_runner.py @@ -411,7 +411,8 @@ def create_dummy_seq_group_metadata(self, seq_len, is_prompt, lora_request=None, - temperature=0): + temperature=0, + last_block_assigned=0): sampling_params = SamplingParams(temperature=0) num_blocks = math.ceil(seq_len / self.block_size) cross_block_table: Optional[List[int]] = None diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 0658d17edb0bc..3c0e4d8dfc263 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -685,6 +685,63 @@ def _set_gc_threshold(self) -> None: self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true' + def _maybe_init_alibi_biases(self) -> None: + layers = None + layer_alibi_config = None + if (not hasattr(self.model, "config") + or not hasattr(self.model.config, "architectures")): + pass + elif "BaichuanForCausalLM" in self.model.config.architectures: + if self.model.config.hidden_size != 4096: + layers = self.model.model.layers + layer_alibi_config = lambda layer: ( + layer.self_attn.attn, + layer.self_attn.max_position_embeddings, + ) + elif "JAISLMHeadModel" in self.model.config.architectures: + if self.model.config.position_embedding_type == "alibi": + layers = self.model.transformer.h + layer_alibi_config = lambda layer: ( + layer.attn.attn, + self.model.config.max_position_embeddings, + ) + elif "FalconForCausalLM" in self.model.config.architectures: + if self.model.config.alibi: + layers = self.model.transformer.h + layer_alibi_config = lambda layer: ( + layer.self_attention.attn, + getattr(self.model.config, + "max_position_embeddings", 8192), + ) + elif "MPTForCausalLM" in self.model.config.architectures: + if self.model.config.attn_config['alibi']: + layers = self.model.transformer.blocks + layer_alibi_config = lambda layer: ( + layer.attn.attn, + self.model.config.max_seq_len, + ) + elif "BloomForCausalLM" in self.model.config.architectures: + layers = self.model.transformer.h + layer_alibi_config = lambda layer: ( + layer.self_attention.attn, + None, + ) + + if (layers is not None + and layer_alibi_config is not None): + self.use_alibi = True + prev_attn = None + for layer in layers: + attn, max_seq_len = layer_alibi_config(layer) + if (hasattr(attn.impl, "_maybe_init_alibi_biases")): + attn.impl._maybe_init_alibi_biases( + max_seq_len=max_seq_len, + prev_attn=prev_attn, + ) + prev_attn = attn + else: + self.use_alibi = False + def load_model(self) -> None: import habana_frameworks.torch.core as htcore if self.model_config.quantization == 'inc' or \ @@ -756,6 +813,7 @@ def load_model(self) -> None: self.model = self.model.to("hpu") htcore.mark_step() + self._maybe_init_alibi_biases() hidden_layer_markstep_interval = int( os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) model_config = getattr(self.model, "config", None) @@ -999,10 +1057,14 @@ def _prepare_prompt( block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, + # Set by later "precompute_indices_and_offsets" function call block_indices=None, + # Set by later "precompute_indices_and_offsets" function call block_offsets=None, + # Set by later "_set_block_scales" function call block_scales=None, block_groups=None, + # Set by later "_set_attn_bias" function call attn_bias=None, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, @@ -1011,8 +1073,9 @@ def _prepare_prompt( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps= - None # FIXME(kzawora): mutli-modality will not work here + alibi_blocks=None, + # FIXME(kzawora): mutli-modality will not work here + multi_modal_placeholder_index_maps=None, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) for t in multi_modal_kwargs: @@ -1151,6 +1214,13 @@ def _prepare_decode( block_groups = padding_fn(block_groups, -1) block_usage = padding_fn(block_usage, 1) + alibi_blocks = None + if self.use_alibi: + alibi_blocks = self._compute_alibi_block(block_tables, seq_lens, + len(block_groups)) + alibi_blocks = alibi_blocks.to( # type: ignore + self.device, non_blocking=True) + block_list = torch.tensor(block_list, dtype=torch.int, device='cpu') block_groups = torch.tensor(block_groups, dtype=torch.int, @@ -1178,12 +1248,17 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, block_list=block_list, + # Set by later "_set_block_mapping" function call block_mapping=None, block_usage=block_usage, + # Set by later "precompute_indices_and_offsets" function call block_indices=None, + # Set by later "precompute_indices_and_offsets" function call block_offsets=None, + # Set by later "_set_block_scales" function call block_scales=None, block_groups=block_groups, + # Set by later "_set_block_mapping" function call attn_bias=None, seq_lens_tensor=None, context_lens_tensor=None, @@ -1191,7 +1266,9 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None) + alibi_blocks=alibi_blocks, + multi_modal_placeholder_index_maps=None, + ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, @@ -1201,6 +1278,63 @@ def _prepare_decode( slot_mapping=slot_mapping, lora_ids=lora_ids) + def _compute_alibi_block(self, block_tables, seq_lens, num_blocks): + """ + Compute the ALiBi offsets for each block during decoding. + + For each block in each sequence, this function assigns position-based + offsets according to ALiBi logic. It returns a tensor that captures + these offsets for all sequences and blocks, which is then used for + decode-time ALiBi bias creation. + + Args: + block_tables: + A list of lists, where each inner list contains block indices + assigned to a particular sequence. + seq_lens: + A list of sequence lengths corresponding to each sequence. + num_blocks: + The total number of blocks across all sequences for which + ALiBi offsets need to be computed. + + Returns: + A torch.Tensor of shape [num_blocks, block_size], containing ALiBi + offsets for each block. + """ + # Create intermediary and output structures + max_block_table_len = max( + len(block_table) for block_table in block_tables) + alibi_offsets = torch.arange(-max_block_table_len * self.block_size + + 1, + 1, + dtype=torch.long, + device='cpu') + alibi_blocks = torch.zeros((num_blocks, self.block_size), + dtype=torch.long, + device='cpu') + + # Assign biases per token + for batch_idx in range(len(block_tables)): + seq_len = seq_lens[batch_idx] + for seq_idx in range(len(block_tables[batch_idx])): + block_idx = block_tables[batch_idx][seq_idx] + + # Calculate the number of valid positions in the current block + valid_length = seq_len - seq_idx * self.block_size + if valid_length > 0: + current_block_length = min(valid_length, self.block_size) + offset_end = current_block_length - valid_length + if offset_end == 0: + alibi_blocks[ + block_idx][:current_block_length] = alibi_offsets[ + -valid_length:] + else: + alibi_blocks[ + block_idx][:current_block_length] = alibi_offsets[ + -valid_length:offset_end] + + return alibi_blocks + def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -1405,6 +1539,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'block_offsets', 'block_scales', 'block_groups', + 'alibi_blocks', ]) return attention_metadata @@ -1413,18 +1548,31 @@ def create_dummy_seq_group_metadata(self, seq_len, is_prompt, lora_request=None, - temperature=0): + temperature=0, + last_block_assigned=0): sampling_params = SamplingParams(temperature=temperature) num_blocks = math.ceil(seq_len / self.block_size) - seq_len = max(seq_len, 1) + # FIXME(Tanner): + # When num_scheduler_steps>1 an additional + # token gets appended to dummy groups at some point + # This causes an RTE during warmup. Hence, subtracting 1 from seq_len. + seq_len = max(seq_len - 1, 1) + block_tables: Optional[dict[Any, Any]] = None if is_prompt: input_len = seq_len output_len = 0 - block_tables = None else: input_len = seq_len - 1 output_len = 1 - block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks} + # NOTE(Tanner): + # ALiBI biases fail if block_tables for + # dummy sequences are all zeros. + # By default "_PAD_BLOCK_ID" is "0" and this + # is not a realistic value for block tables. + block_tables = {group_id: []} + for block_idx in range(num_blocks): + last_block_assigned += 1 + block_tables[group_id] += [last_block_assigned] prompt_token_ids = [0] * input_len output_token_ids = [1] * output_len prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 @@ -1498,18 +1646,31 @@ def warmup_scenario(self, temperature=temperature) for i in range(batch_size) ] else: - # FIXME: seq_len is actually number of blocks - blocks = [seq_len // batch_size for _ in range(batch_size)] - blocks[0] += seq_len % batch_size - seqs = [ - self.create_dummy_seq_group_metadata( - i, - b * self.block_size - 1, - is_prompt, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None, - temperature=temperature) for i, b in enumerate(blocks) + # NOTE(Tanner): + # seq_len is num blocks + # Here we assign as many blocks to each sequence as we can + blocks_per_seq = (seq_len - 1) // batch_size + extra_blocks = (seq_len - 1) % batch_size + blocks = [ + blocks_per_seq + (1 if i < extra_blocks else 0) + for i in range(batch_size) ] + seqs = [] + last_block_assigned = 0 + for i, b in enumerate(blocks): + seqs += [ + self.create_dummy_seq_group_metadata( + i, + b * self.block_size, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None, + temperature=temperature, + last_block_assigned=last_block_assigned, + ) + ] + if len(seqs[-1].block_tables[i]) > 0: + last_block_assigned = seqs[-1].block_tables[i][-1] torch.hpu.synchronize() profiler = None if is_pt_profiler_run and self.is_driver_worker: