Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate PagedAttention Optimization custom kernel into vLLM #22

Merged
merged 13 commits into from
May 30, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ set(CUSTOM_SRC
"csrc/custom/custom_kernels.cu"
"csrc/custom/fused_kernels.cu"
"csrc/custom/custom.cu"
"csrc/custom/paged_attention/attention_ll4mi.cu"
)

define_gpu_extension_target(
Expand Down
6 changes: 6 additions & 0 deletions ROCm_performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ The default attention function on ROCm is using triton attention kernel. To fall
## Tunable ops
Pytorch tunable ops are supported.
Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order to enable both the runtime tuning and the subsequent use of tuned results. To only use the tuned results without tuning any newly encountered shapes, also define `PYTORCH_TUNABLEOP_TUNING=1`

## Custom PagedAttention

On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0.
The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel.
64 changes: 45 additions & 19 deletions benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import torch

from vllm._C import ops
from vllm._custom_C import paged_attention_custom
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
PARTITION_SIZE = 256


@torch.inference_mode()
Expand Down Expand Up @@ -77,6 +78,9 @@ def main(
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
if version == "v2":
if not args.custom_paged_attn:
global PARTITION_SIZE
PARTITION_SIZE = 512
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
tmp_output = torch.empty(
Expand Down Expand Up @@ -118,24 +122,43 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
kv_scale,
)
elif version == "v2":
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
if not args.custom_paged_attn:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
paged_attention_custom(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize()
Expand Down Expand Up @@ -191,6 +214,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument("--custom-paged-attn",
action="store_true",
help="Use custom paged attention")
args = parser.parse_args()
print(args)

Expand Down
25 changes: 25 additions & 0 deletions csrc/custom/custom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,36 @@ void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
at::cuda::getCurrentCUDAStream());
}

void paged_attention_custom(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
#if 0
torch::Tensor& qk_out,
torch::Tensor& softmax_out,
#endif
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);

// declare the extension module with the AddGPU function:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.doc() = "pybind11 example plugin";
m.def("LLMM1", &LLMM1);
m.def("LLMM_Silu", &LLMM_Silu);
m.def("LLZZ", &LLZZ);
m.def(
"paged_attention_custom",
&paged_attention_custom,
"PagedAttention LL4Mi Custom.");
//m.def("MMCustomGPU", &MMCustomGPU);
}
Loading
Loading