Skip to content

Commit

Permalink
fixes imports for custom paged attn
Browse files Browse the repository at this point in the history
  • Loading branch information
lcskrishna committed May 30, 2024
1 parent afc7dc1 commit c774517
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch

 from vllm._C import ops
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 from vllm._custom_C import paged_attention_custom
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

NUM_BLOCKS = 1024
PARTITION_SIZE = 256
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_attention_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from allclose_default import get_default_atol, get_default_rtol

 from vllm._C import cache_ops, ops
-from vllm.utils import get_max_shared_memory_bytes, is_hip
 from vllm._custom_C import paged_attention_custom
+from vllm.utils import get_max_shared_memory_bytes, is_hip

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
+import os
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

Expand All @@ -6,7 +7,6 @@
 from vllm._C import cache_ops, ops
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.utils import is_hip
-import os

custom_attn_available = is_hip() and \
(os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0")
Expand Down

0 comments on commit c774517

Please sign in to comment.