[Kernel] Flash Attention 3 Support #12093

Open · wants to merge 13 commits into base: main
43 changes: 19 additions & 24 deletions CMakeLists.txt
@@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")

# Prevent installation of dependencies (cutlass) by default.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
@@ -533,7 +530,7 @@ endif()
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# If no component is specified, vllm-flash-attn is still installed.

# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
@@ -545,43 +542,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
FetchContent_Declare(
vllm-flash-attn SOURCE_DIR
${VLLM_FLASH_ATTN_SRC_DIR}
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_TAG ed2f90d84cb8ffee24c324f93265b4890fa793ab
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
endif()

# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)

# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)

# Make sure vllm-flash-attn install rules are nested under vllm/
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)

# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
# Copy over the vllm-flash-attn python files. The rule is duplicated for FA2
# and FA3 so the files are installed even if only one component is built; if
# both are built, the copy is simply done twice.
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)

# Copy over the vllm-flash-attn python files
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT vllm_flash_attn_c
FILES_MATCHING PATTERN "*.py"
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
)

# Nothing after vllm-flash-attn, see comment about macros above
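The _vllm_fa2_C and _vllm_fa3_C components named above correspond one-to-one to the extension modules registered in setup.py further down. A quick, purely illustrative way to check which of the two kernels ended up in an installed build (it assumes vllm itself is importable and is not part of this PR):

import importlib.util

for name in ("vllm.vllm_flash_attn._vllm_fa2_C",
             "vllm.vllm_flash_attn._vllm_fa3_C"):
    spec = importlib.util.find_spec(name)
    print(name, "present" if spec is not None else "missing")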
12 changes: 8 additions & 4 deletions setup.py
@@ -228,8 +228,11 @@ def target_name(s: str) -> str:

# CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it.
# We assume CMake adds only the final component of the extension prefix;
# this holds for the current extensions but may not always be the case.
prefix = outdir
for i in range(ext.name.count('.')):
if '.' in ext.name:
prefix = prefix.parent

# prefix here should actually be the same for all components
@@ -298,7 +301,8 @@ def run(self) -> None:
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",
"vllm/vllm_flash_attn/__init__.py",
# "vllm/_version.py", # not available in nightly wheels yet
@@ -592,8 +596,8 @@ def _read_requirements(filename: str) -> List[str]:
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))

if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
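The prefix handling above is easier to follow with concrete values. The sketch below is illustrative only: the extension name is taken from this PR, but the outdir path is hypothetical and the snippet itself is not part of setup.py.

from pathlib import Path

ext_name = "vllm.vllm_flash_attn._vllm_fa2_C"
outdir = Path("build/lib.linux-x86_64-cpython-312/vllm/vllm_flash_attn")

# Old behaviour: walk up one directory per dot in the extension name
# (two levels here), ending at build/lib.linux-x86_64-cpython-312.
old_prefix = outdir
for _ in range(ext_name.count(".")):
    old_prefix = old_prefix.parent

# New behaviour: CMake is assumed to append only the final component of the
# extension prefix, so a single parent step is enough, ending at
# build/lib.linux-x86_64-cpython-312/vllm.
new_prefix = outdir.parent if "." in ext_name else outdir

print(old_prefix)
print(new_prefix)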
19 changes: 9 additions & 10 deletions tests/kernels/test_cascade_flash_attn.py
@@ -78,6 +78,7 @@ def test_merge_kernel(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 50])
@pytest.mark.parametrize("num_blocks", [2048])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
@@ -87,6 +88,7 @@ def test_cascade(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
@@ -118,9 +120,7 @@ def test_cascade(
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
@@ -140,7 +140,7 @@ def test_cascade(
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
seqused_k=kv_lens_tensor,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
@@ -154,10 +154,8 @@ def test_cascade(
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
dtype=torch.int32)
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
cu_suffix_kv_lens = (
cu_kv_lens -
torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
suffix_kv_lens = kv_lens_tensor - common_prefix_len
output = torch.empty_like(query)
cascade_attention(
output=output,
@@ -167,15 +165,16 @@ def test_cascade(
cu_query_lens=cu_query_lens,
max_query_len=max_query_len,
cu_prefix_query_lens=cu_prefix_query_lens,
cu_prefix_kv_lens=cu_prefix_kv_lens,
cu_suffix_kv_lens=cu_suffix_kv_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
max_kv_len=max_kv_len,
softmax_scale=scale,
alibi_slopes=None,
sliding_window=window_size,
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
fa_version=fa_version,
)

# Compare the results.
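A recurring change in this file (and in the tests and backend below) is that the per-sequence KV lengths are now passed directly as seqused_k instead of being converted into cumulative offsets for cu_seqlens_k. A small self-contained sketch of how the two forms relate; the kv_lens values are made up:

import torch

kv_lens = [35, 7, 120]  # hypothetical per-sequence KV lengths

# Old form: cumulative lengths with a leading zero, passed as cu_seqlens_k.
cu_kv_lens = torch.tensor([0] + kv_lens,
                          dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
# -> tensor([  0,  35,  42, 162], dtype=torch.int32)

# New form: the lengths themselves, passed as seqused_k.
seqused_k = torch.tensor(kv_lens, dtype=torch.int32)

# Each form can be recovered from the other.
assert torch.equal(cu_kv_lens[1:] - cu_kv_lens[:-1], seqused_k)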
12 changes: 8 additions & 4 deletions tests/kernels/test_flash_attn.py
@@ -80,6 +80,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
@@ -91,6 +92,7 @@ def test_flash_attn_with_paged_kv(
soft_cap: Optional[float],
num_blocks: int,
sliding_window: Optional[int],
fa_version: int,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
@@ -131,6 +133,7 @@ def test_flash_attn_with_paged_kv(
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
fa_version=fa_version,
)
output = output if not use_out else out
output = output.squeeze(1)
@@ -159,6 +162,7 @@ def test_varlen_with_paged_kv(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
@@ -170,6 +174,7 @@ def test_varlen_with_paged_kv(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
@@ -198,9 +203,7 @@ def test_varlen_with_paged_kv(
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
@@ -215,14 +218,15 @@ def test_varlen_with_paged_kv(
v=value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
seqused_k=kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
)
output = output if not use_out else out

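With fa_version parametrized over [2, 3], the FA3 cases can only pass on builds and GPUs where the FA3 kernel is available. One possible guard, not part of this PR, using the is_fa_version_supported helper that the backend below imports:

import pytest

from vllm.vllm_flash_attn import is_fa_version_supported


def skip_if_fa_version_unsupported(fa_version: int) -> None:
    # Skip instead of failing when the requested kernel was not built or is
    # not supported on this GPU.
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"flash-attention v{fa_version} is not supported here")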
27 changes: 24 additions & 3 deletions vllm/attention/backends/flash_attn.py
@@ -16,15 +16,18 @@
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
flash_attn_with_kvcache,
is_fa_version_supported)


class FlashAttentionBackend(AttentionBackend):
@@ -632,6 +635,20 @@ def __init__(
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type

# On Hopper, default to FA3; otherwise stick with FA2 for now.
# TODO(lucas): profile FA3 on Ampere to see if it makes sense to
# use FA3 as the default for both.
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2

if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

assert is_fa_version_supported(self.fa_version)

def forward(
self,
query: torch.Tensor,
@@ -751,6 +768,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)
else:
# prefix-enabled attention
@@ -764,7 +782,7 @@ def forward(
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
Expand All @@ -773,6 +791,7 @@ def forward(
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)

if decode_meta := attn_metadata.decode_metadata:
Expand All @@ -792,7 +811,7 @@ def forward(
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
@@ -801,6 +820,7 @@ def forward(
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.fa_version,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
@@ -821,6 +841,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
)
return output

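The default-version choice in __init__ above keys off the device's compute capability: a major version of 9 corresponds to Hopper (SM90), where FA3 is preferred whenever the build supports it. A rough standalone equivalent using torch directly, assuming a CUDA device is present (this mirrors, rather than replaces, the backend code):

import torch

from vllm.vllm_flash_attn import is_fa_version_supported

major, _minor = torch.cuda.get_device_capability()
if major >= 9:  # Hopper (SM90) or newer
    fa_version = 3 if is_fa_version_supported(3) else 2
else:  # Ampere/Ada and older stay on FA2 for now
    fa_version = 2
assert is_fa_version_supported(fa_version)
print(f"defaulting to flash-attention v{fa_version}")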
12 changes: 12 additions & 0 deletions vllm/envs.py
@@ -11,6 +11,7 @@
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
@@ -88,6 +89,12 @@ def get_default_config_root():
)


def maybe_convert_int(value: Optional[str]) -> Optional[int]:
if value is None:
return None
return int(value)


# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.

@@ -201,6 +208,11 @@ def get_default_config_root():
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),

# Force vllm to use a specific flash-attention version (2 or 3); only valid
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION":
lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),

# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
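The new VLLM_FLASH_ATTN_VERSION variable is an escape hatch when the automatic choice in the backend is not what you want. Illustrative usage only: the model name is a placeholder, and the variable has to be set before the flash-attention backend module is imported, since the value is read at import time.

import os

os.environ["VLLM_FLASH_ATTN_VERSION"] = "2"  # force FA2 even on Hopper

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # placeholder model
for out in llm.generate("Hello, my name is"):
    print(out.outputs[0].text)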