From c774517080728ff5061d846f1f83fdf0441f19c5 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 30 May 2024 10:06:31 +0000 Subject: [PATCH] fixes imports for custom paged attn --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- tests/kernels/test_attention_custom.py | 2 +- vllm/attention/ops/paged_attn.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 2161f4021f0a5..24f734ce8cce4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,8 +6,8 @@ import torch from vllm._C import ops -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._custom_C import paged_attention_custom +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random NUM_BLOCKS = 1024 PARTITION_SIZE = 256 diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py index cf4f5ea6f5eeb..5bdbf126c22fa 100644 --- a/tests/kernels/test_attention_custom.py +++ b/tests/kernels/test_attention_custom.py @@ -6,8 +6,8 @@ from allclose_default import get_default_atol, get_default_rtol from vllm._C import cache_ops, ops -from vllm.utils import get_max_shared_memory_bytes, is_hip from vllm._custom_C import paged_attention_custom +from vllm.utils import get_max_shared_memory_bytes, is_hip FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 18f9bb67212e1..72811e1468ab6 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -6,7 +7,6 @@ from vllm._C import cache_ops, ops from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.utils import is_hip -import os custom_attn_available = is_hip() and \ (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0")