From 9ba2fab2748c06f34816185118bb79b71544e790 Mon Sep 17 00:00:00 2001
From: Matthew Wong
Date: Thu, 19 Dec 2024 01:07:21 +0000
Subject: [PATCH] Ingest FP8 attn scales and use them in Triton FA, if present

---
 vllm/attention/backends/abstract.py           |  2 +-
 vllm/attention/backends/blocksparse_attn.py   |  2 +-
 vllm/attention/backends/flash_attn.py         |  2 +-
 vllm/attention/backends/flashinfer.py         |  2 +-
 vllm/attention/backends/hpu_attn.py           |  2 +-
 vllm/attention/backends/ipex_attn.py          |  2 +-
 vllm/attention/backends/pallas.py             |  2 +-
 vllm/attention/backends/rocm_flash_attn.py    | 10 ++++-
 vllm/attention/backends/torch_sdpa.py         |  2 +-
 vllm/attention/backends/xformers.py           |  2 +-
 vllm/attention/layer.py                       | 14 ++++---
 vllm/envs.py                                  | 11 +++---
 .../quantization/compressed_tensors/utils.py  |  4 ++
 .../layers/quantization/kv_cache.py           | 39 +++++++++++++++++--
 vllm/model_executor/models/aria.py            |  4 +-
 vllm/model_executor/models/exaone.py          |  4 +-
 vllm/model_executor/models/granite.py         |  4 +-
 vllm/model_executor/models/llama.py           | 24 ++++++++----
 vllm/model_executor/models/solar.py           |  4 +-
 19 files changed, 99 insertions(+), 37 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 3127285df6c63..de34c0da55b4f 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -252,6 +252,6 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         raise NotImplementedError

diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index 1869dbab0cbf5..e6f63a1ff0fc9 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -363,7 +363,7 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index c640de998149e..94a7478f8eb51 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -642,7 +642,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

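Note: every backend forward() above and below swaps the single fp8_out_scale tensor
for an fp8_comp_scales tuple. Judging from the rocm_flash_attn.py and llama.py hunks,
the tuple is ordered (q_scale, prob_scale, out_scale) and is None whenever fp8
attention is disabled; backends that do not do fp8 computation simply accept and
ignore it. A minimal sketch of that calling convention follows; the helper name and
the numeric values are illustrative only, not part of the patch.

    from typing import Optional, Tuple
    import torch

    def unpack_fp8_comp_scales(
        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]],
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # Same unpacking the ROCm backend performs: fall back to all-None
        # when the caller did not supply fp8 scales.
        q_scale, prob_scale, out_scale = fp8_comp_scales or (None, None, None)
        return q_scale, prob_scale, out_scale

    # Caller side (cf. the llama.py hunk): the scales are 0-d float32 tensors.
    scales = (torch.tensor(0.02), torch.tensor(1.0), torch.tensor(0.5))
    print(unpack_fp8_comp_scales(scales))
    print(unpack_fp8_comp_scales(None))
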
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 2285f20c1c8af..5301474833163 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -777,7 +777,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:

         # TODO: directly write to output tensor
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index b3065495ab396..efb3ac980aaba 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -154,7 +154,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index d02fbcb6ca0ae..0225554b0ceb1 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -174,7 +174,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.

diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index a2612c97ca23b..7c702291513d2 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -152,7 +152,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index efaa74f67bafd..335bae56f6ab4 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -551,7 +551,7 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: torch.Tensor = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

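Note: the two rocm_flash_attn.py hunks that follow are where the new scales are
actually consumed. The tuple is unpacked at the top of forward(), and on the Triton
flash-attention path it is converted into a single full_scales tuple of Python
floats: reciprocals of the q/k/v/prob scales plus the raw output scale, presumably
the de-scale factors the Triton kernel expects. A standalone sketch of that
conversion is below; the function name and numeric values are illustrative only.

    from typing import Optional, Tuple
    import torch

    def build_full_scales(
        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> Optional[Tuple[float, float, float, float, float]]:
        # Mirrors the hunk below: when no fp8 scales were passed in,
        # the Triton kernel receives None and runs as before.
        if not fp8_comp_scales:
            return None
        q_scale, prob_scale, fp8_out_scale = fp8_comp_scales
        return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
                1.0 / v_scale.item(), 1.0 / prob_scale.item(),
                fp8_out_scale.item())

    # Illustrative values only:
    print(build_full_scales((torch.tensor(0.02), torch.tensor(1.0),
                             torch.tensor(0.5)),
                            k_scale=torch.tensor(0.04),
                            v_scale=torch.tensor(0.03)))
    print(build_full_scales(None, torch.tensor(1.0), torch.tensor(1.0)))
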
@@ -601,6 +601,8 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None, + None) query = query.view(-1, self.num_heads, self.head_size) if key is not None: @@ -681,6 +683,10 @@ def forward( query.dtype, seq_lens, make_attn_mask=False) # type: ignore + full_scales = ( + 1.0 / q_scale.item(), 1.0 / k_scale.item(), + 1.0 / v_scale.item(), 1.0 / prob_scale.item(), + fp8_out_scale.item()) if fp8_comp_scales else None out, _ = self.attn_func( query, key, @@ -694,7 +700,7 @@ def forward( self.scale, attn_masks[0][None] if attn_masks is not None else None, - None, + full_scales, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8e586c00024d5..660408f3ea477 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -434,7 +434,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dcd60da43d520..d0f1034dc966f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -420,7 +420,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 800a1ff2f4f65..d0d981e04db15 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,5 +1,5 @@ """Attention layer.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -42,6 +42,7 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, prefix: str = "", + use_fp8: bool = False, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -73,8 +74,11 @@ def __init__( # with the model weights. self.kv_cache_dtype = kv_cache_dtype self.calculate_kv_scales = calculate_kv_scales + self.use_fp8 = use_fp8 self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: @@ -106,11 +110,11 @@ def __init__( self.num_kv_heads = num_kv_heads self.backend = backend_name_to_enum(attn_backend.get_name()) - # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # For cuda and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. 
- self.use_direct_call = not current_platform.is_cuda_alike( + self.use_direct_call = not current_platform.is_cuda( ) and not current_platform.is_cpu() # For some attention backends, we allocate an output tensor before @@ -135,7 +139,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: str = AttentionType.DECODER, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: if self.calculate_kv_scales and \ attn_metadata.enable_kv_scales_calculation: @@ -150,7 +154,7 @@ def forward( self._k_scale, self._v_scale, attn_type=attn_type, - fp8_out_scale=fp8_out_scale) + fp8_comp_scales=fp8_comp_scales) elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) diff --git a/vllm/envs.py b/vllm/envs.py index 19e520691e436..03646192e12d7 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -16,7 +16,7 @@ VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_USE_ROCM_SKINNY_GEMM: bool = True VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True - VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True + VLLM_USE_ROCM_FP8_ATTN: bool = True RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -242,13 +242,12 @@ def get_default_config_root(): # custom paged attention implemented for MI3* cards "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN": lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in - ("true", "1") != "0"), + ("true", "1")), # have custom paged attention implemented for MI3* cards write out fp8 - "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": - lambda: - (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in - ("true", "1") != "0"), + "VLLM_USE_ROCM_FP8_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in + ("true", "1")), # rank of the process in the distributed setting, used to determine # the driver worker diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index a74eaef5efdee..c474dcd0c5246 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]: return name.replace(".k_proj.output_scale", ".attn.k_scale") if name.endswith(".output_scale") and ".v_proj" in name: return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index b386a9f309639..4502c8bea47c4 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -31,17 +31,20 @@ def create_weights(self, layer: torch.nn.Module): requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + # Initialize Q and P = softmax(QK^T) scales + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError( 
f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. # No need to process kv scales after loading if we are going to # calculate them on the fly. - if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if not layer.calculate_kv_scales: if layer.k_scale > 0.0 and layer.v_scale > 0.0: # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() @@ -75,11 +78,41 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) if (k_scale == 1.0 and v_scale == 1.0 + and (layer.kv_cache_dtype != "auto" or layer.use_fp8) and "e5m2" not in layer.kv_cache_dtype): print_warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " "may cause accuracy issues. Please make sure k/v_scale " "scaling factors are available in the fp8 checkpoint.") + if layer.q_scale > 0.0 and layer.prob_scale > 0.0: + q_scale = layer.q_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + q_scale *= 2 + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + if not isinstance(q_scale, float) or not isinstance(prob_scale, float): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if (q_scale == 1.0 or prob_scale == 1.0) and layer.use_fp8: + print_warning_once( + f"Using Q scale {q_scale} and prob scale {prob_scale} " + "with fp8 attention. This may cause accuracy issues. 
" + "Please make sure Q/prob scaling factors are " + "available in the fp8 checkpoint.") + del layer.k_scale del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index dd4b0c75cb84d..684e7f5382277 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -395,7 +395,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 6926de007bc84..ca85f418d5762 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -543,7 +543,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 6353a7703d6cb..3c5dfc2e7f45d 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -485,7 +485,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 514ba946af6f7..05b237c955369 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -197,6 +197,14 @@ def __init__( else: sliding_window = None + # For CUDA devices and Navi4x, attn_fp8 will be set to false. + self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \ + and current_platform.is_rocm() \ + and not is_navi() \ + and isinstance(quant_config, Fp8Config) \ + and hasattr(self.o_proj, "input_scale") \ + and self.o_proj.input_scale is not None + self.attn = Attention( self.num_heads, self.head_dim, @@ -206,12 +214,8 @@ def __init__( quant_config=quant_config, per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", + use_fp8=self.attn_fp8, ) - # For CUDA devices and Navi4x, attn_fp8_out will be set to false. 
- self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ - and current_platform.is_rocm() \ - and not is_navi() \ - and isinstance(quant_config, Fp8Config) def forward( self, @@ -228,8 +232,10 @@ def forward( v, kv_cache, attn_metadata, - fp8_out_scale=self.o_proj.input_scale - if self.attn_fp8_out else None) + fp8_comp_scales=(self.attn._q_scale, + self.attn._prob_scale, + self.o_proj.input_scale) + if self.attn_fp8 else None) output, _ = self.o_proj(attn_output) return output @@ -428,7 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 087697bc45e61..4f3cdbbcee9f4 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -502,7 +502,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue
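Note: on the model side, llama.py only forwards fp8_comp_scales=(attn._q_scale,
attn._prob_scale, o_proj.input_scale) when attn_fp8 is true, i.e. on ROCm, not on
Navi, with an Fp8Config and an o_proj input scale present; on CUDA and Navi4x the
tuple stays None and behavior is unchanged. The other change repeated across the
model files is in load_weights, which now tolerates checkpoint scales stored either
as shape [1] tensors or as 0-d scalars. A self-contained sketch of that
normalization follows; the helper name is illustrative, not part of the patch.

    import torch

    def normalize_scale(loaded_weight: torch.Tensor) -> torch.Tensor:
        # Mirrors the loader change in aria/exaone/granite/llama/solar:
        # only index when the tensor actually has a dimension, since
        # indexing a 0-d tensor with [0] raises an IndexError.
        if loaded_weight.shape:
            # scalar shape is torch.Size([1]), not torch.Size([])
            loaded_weight = loaded_weight[0]
        return loaded_weight

    print(normalize_scale(torch.tensor([0.5])))  # scale stored as shape [1]
    print(normalize_scale(torch.tensor(0.5)))    # scale stored as a 0-d scalar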