Skip to content

Commit

Permalink
Update testing_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Cemberk authored Jan 13, 2025
1 parent 1d9fdbb commit 01fe200
Showing 1 changed file with 53 additions and 28 deletions.
81 changes: 53 additions & 28 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,31 +230,34 @@ def parse_int_from_env(key, default=None):
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
_test_with_rocm = parse_flag_from_env("TEST_WITH_ROCM", default=False)


def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack", arch=None):
def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack", arch=None, rocm_version=None):
"""
Pytest decorator to skip a test on AMD systems running ROCm.
Pytest decorator to skip a test on AMD systems running ROCm based on GPU architecture and/or ROCm version.
The decorator uses shell commands to:
- Detect the GPU vendor.
- Extract the GPU architecture information:
* For AMD: extracts using `/opt/rocm/bin/rocminfo`.
* For NVIDIA: extracts using `nvidia-smi` (for informational purposes, though skip logic is only applied for AMD).
Behavior:
- If the detected GPU vendor is AMD:
* When `arch` is None, the test is skipped unconditionally.
* When `arch` is provided (as a string or an iterable of strings), the test is skipped only if the
detected GPU architecture matches one of the specified values.
- If the GPU vendor is not AMD (e.g. NVIDIA), the test runs normally.
- Extract the GPU architecture (for AMD, using `/opt/rocm/bin/rocminfo`).
- Read the ROCm version (by reading the file `/opt/rocm/.info/version`).
Behavior on an AMD system:
- If neither `arch` nor `rocm_version` is provided, the test is skipped unconditionally.
- If `arch` is provided (as a string or a list/tuple), the test is skipped if the detected GPU
architecture is in the given list.
- If `rocm_version` is provided (as a string or a list/tuple), the test is skipped if the current
ROCm version matches (or starts with) one of the provided version strings.
- If both are provided, the test is skipped if either condition is met.
On non-AMD systems (e.g. NVIDIA), the test will run normally.
Parameters:
msg (str): The skip message.
arch (str or iterable of str, optional): The GPU architecture(s) to match against.
arch (str or iterable of str, optional): The GPU architecture(s) for which to skip the test.
rocm_version (str or iterable of str, optional): The ROCm version(s) which should cause the test to be skipped.
"""

def get_gpu_vendor():
"""
Returns the GPU vendor as determined by checking for NVIDIA or ROCm utilities.
Returns the GPU vendor as detected by checking for NVIDIA or ROCm utilities.
"""
cmd = (
"bash -c 'if [[ -f /usr/bin/nvidia-smi ]] && "
Expand All @@ -266,10 +269,9 @@ def get_gpu_vendor():

def get_system_gpu_architecture():
"""
Returns the GPU architecture string.
For AMD GPUs, it extracts a line starting with 'gfx' using `/opt/rocm/bin/rocminfo`.
For NVIDIA GPUs, it extracts the GPU name using `nvidia-smi`.
Returns the GPU architecture string if the vendor is AMD.
For AMD, it extracts a line starting with 'gfx' via `/opt/rocm/bin/rocminfo`.
For NVIDIA, it returns the GPU name via `nvidia-smi` (though this is used only for information).
"""
vendor = get_gpu_vendor()
if vendor == "AMD":
Expand All @@ -283,23 +285,46 @@ def get_system_gpu_architecture():
else:
raise RuntimeError("Unable to determine GPU architecture due to unknown GPU vendor.")

def get_rocm_version():
"""
Returns the ROCm version as a string by reading /opt/rocm/.info/version.
Expected format example: "6.4.0-12299"
"""
cmd = "cat /opt/rocm/.info/version"
return subprocess.check_output(cmd, shell=True).decode("utf-8").strip()

def dec_fn(fn):
reason = f"skipIfRocm: {msg}"

@wraps(fn)
def wrapper(*args, **kwargs):
vendor = get_gpu_vendor()
# Apply the ROCm (AMD) skip logic only if the detected vendor is AMD.
if vendor == "AMD":
# If no specific architecture is provided, skip unconditionally.
if arch is None:
pytest.skip(reason)
else:
# Allow `arch` to be provided as a single string or an iterable.
# Determine if the test should be skipped based on GPU architecture or ROCm version.
should_skip = False

# No specific conditions provided; skip unconditionally on AMD.
if arch is None and rocm_version is None:
should_skip = True

# Check the GPU architecture condition if provided.
if arch is not None:
arch_list = (arch,) if isinstance(arch, str) else arch
current_gpu_arch = get_system_gpu_architecture()
if current_gpu_arch in arch_list:
pytest.skip(reason)
should_skip = True

# Check the ROCm version condition if provided.
if rocm_version is not None:
ver_list = (rocm_version,) if isinstance(rocm_version, str) else rocm_version
current_version = get_rocm_version()
# Use startswith so you can specify major.minor.patch without the build suffix.
if any(current_version.startswith(v) for v in ver_list):
should_skip = True

if should_skip:
pytest.skip(reason)

return fn(*args, **kwargs)
return wrapper

Expand Down

0 comments on commit 01fe200

Please sign in to comment.