
Commit

Fix tests on NVIDIA GPUs after CK recent changes (fairinternal/xformers#1280)

* Fix cusparselt detection with PT upgrade

__original_commit__ = fairinternal/xformers@cc525a3
danthe3rd authored and xFormers Bot committed Jan 15, 2025
1 parent 6440945 commit 08cc74d
Showing 2 changed files with 10 additions and 38 deletions.
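
For context on the sp24.py change below: cusparseLt detection no longer locates and dlopens the libcusparseLt shared object bundled with PyTorch; it now asks PyTorch directly through torch.backends.cusparselt (present in recent PyTorch releases, which is what "PT upgrade" in the commit message refers to). A minimal sketch of that query path; the example value in the comment is an assumption, not taken from this diff:

import torch

# Ask PyTorch for the cuSPARSELt build it links against instead of
# dlopen-ing the bundled .so ourselves.
if torch.backends.cusparselt.is_available():
    # version() returns a packed integer, e.g. 602 for cuSPARSELt 0.6.2,
    # or None when the backend cannot report one.
    print(torch.backends.cusparselt.version())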
6 changes: 3 additions & 3 deletions tests/test_mem_eff_attention.py
@@ -828,7 +828,7 @@ def _get_drop_mask(op, batch_size, q_len, kv_len, p, device):
     return mask


-@cuda_only
+@rocm_only
 @pytest.mark.parametrize("attn_bias", [None, fmha.attn_bias.LowerTriangularMask()])
 @pytest.mark.parametrize("seed", [42, 124])
 @pytest.mark.parametrize("p", [0.3, 0.7])
@@ -888,7 +888,7 @@ def test_dropout_ck(q_len, kv_len, batch_size, k_len, p, seed, attn_bias):
     assert all(p_values > p_val_tol)


-@cuda_only
+@rocm_only
 @pytest.mark.parametrize("p", [0.000001, 0.3, 0.7])
 @pytest.mark.parametrize("k", [16, 64, 128])
 @pytest.mark.parametrize("batch_size", [1, 2])
@@ -2280,7 +2280,7 @@ def test_paged_attention(
 )


-@cuda_only
+@rocm_only
 @pytest.mark.parametrize("B", [1, 5, 128])
 @pytest.mark.parametrize("MAX_T", [64, 128, 2048, 4096, 8192])
 @pytest.mark.parametrize("page_size", [128, 256])
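
The three hunks above swap @cuda_only for @rocm_only on the CK dropout and paged-attention tests. Those markers are defined elsewhere in tests/test_mem_eff_attention.py and are not part of this diff; a hypothetical reconstruction of how such skip markers are commonly written (the bodies here are an assumption, not xformers' actual code):

import pytest
import torch

# Hypothetical stand-ins for the suite's real markers: skip unless the
# torch build targets the matching GPU backend.
cuda_only = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.version.hip is not None,
    reason="requires a CUDA (non-ROCm) GPU",
)
rocm_only = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.version.hip is None,
    reason="requires a ROCm GPU",
)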
42 changes: 7 additions & 35 deletions xformers/ops/sp24.py
@@ -4,13 +4,9 @@
 # LICENSE file in the root directory of this source tree.

 import contextlib
-import ctypes
-import glob
-import os
 import time
 import warnings
 from functools import partial
-from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast

 import torch
@@ -46,32 +42,16 @@ class Sp24Gemm(BaseOperator):
     NAME = "_sparse24_gemm"


-def _get_cusparselt_lib() -> Optional[str]:
-    libs = glob.glob(
-        str(Path(torch._C.__file__).parent / "lib" / "libcusparseLt*.so.0")
-    )
-    if len(libs) != 1:
-        return None
-    return libs[0]
-
-
 def _get_cusparselt_torch_version() -> Tuple[int, int, int]:
     """
-    Returns the version of the cusparselt.so library that ships with pytorch 2.2+
+    Returns the version of the cusparselt.so library used by pytorch
     """
-    lib_path = _get_cusparselt_lib()
-    if lib_path is None:
+    if not torch.backends.cusparselt.is_available():
         return (0, 0, 0)
-    lib = ctypes.CDLL(lib_path)
-
-    def get_version_part(version_part: int) -> int:
-        value = ctypes.c_int()
-        ret = lib.cusparseLtGetProperty(version_part, ctypes.byref(value))
-        if ret != 0:
-            return -1
-        return value.value
-
-    return (get_version_part(0), get_version_part(1), get_version_part(2))
+    version: Optional[int] = torch.backends.cusparselt.version()
+    if version is None:
+        return (0, 0, 0)
+    return ((version // 10000) % 100, (version // 100) % 100, version % 100)


 _cusplt_version = _get_cusparselt_torch_version()
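
The new return statement unpacks cuSPARSELt's packed version integer (major * 10000 + minor * 100 + patch, per the decode in the hunk above). A quick worked check, with assumed example inputs:

# Mirrors the decode in _get_cusparselt_torch_version above.
def decode(version: int):
    return ((version // 10000) % 100, (version // 100) % 100, version % 100)

assert decode(602) == (0, 6, 2)    # cuSPARSELt 0.6.2
assert decode(60000) == (6, 0, 0)  # a hypothetical 6.0.0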
@@ -93,17 +73,9 @@ class Sp24GemmCusplt(BaseOperator):


 def _has_cusparseLt() -> bool:
-    available = _cusplt_version >= (0, 4, 0)
+    available = _cusplt_version >= (0, 5, 0)
     if not available:
         return False
-    if _cusplt_version < (0, 5, 0):
-        # Version 0.5.0 has much better perf because it can fuse the
-        # transpose within the GEMM epilogue
-        warnings.warn(
-            f"You have cusparseLt version {_cusplt_version_str} "
-            f"but you get better performance with v0.5.0+ if "
-            f"you replace the .so file ({_get_cusparselt_lib()})"
-        )

     # Sm90 added in 6.0
     compute_capability = (0, 0)
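
The rest of _has_cusparseLt is truncated in this diff; the "# Sm90 added in 6.0" comment and the compute_capability initializer suggest a device-capability gate follows. A sketch of such a gate using the standard PyTorch call, as an assumption rather than xformers' actual continuation:

import torch

# Sketch only: query the active GPU's SM version; (9, 0) is Hopper/Sm90.
compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability()
is_sm90_or_newer = compute_capability >= (9, 0)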
