Skip to content

Commit

Permalink
Update testing_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Cemberk authored Jan 13, 2025
1 parent 01fe200 commit a608d6b
Showing 1 changed file with 64 additions and 28 deletions.
92 changes: 64 additions & 28 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,35 +230,45 @@ def parse_int_from_env(key, default=None):
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
_test_with_rocm = parse_flag_from_env("TEST_WITH_ROCM", default=False)

def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack", arch=None, rocm_version=None):

import platform

def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack", arch=None, rocm_version=None, os_name=None):
"""
Pytest decorator to skip a test on AMD systems running ROCm based on GPU architecture and/or ROCm version.
Pytest decorator to skip a test on AMD systems running ROCm, with additional conditions based on
GPU architecture, ROCm version, and/or operating system.
The decorator uses shell commands to:
- Detect the GPU vendor.
- Extract the GPU architecture (for AMD, using `/opt/rocm/bin/rocminfo`).
- Read the ROCm version (by reading the file `/opt/rocm/.info/version`).
- Extract the GPU architecture for AMD via `/opt/rocm/bin/rocminfo`.
- Read the ROCm version from `/opt/rocm/.info/version`.
In addition, it can detect the current operating system:
- On Linux, it attempts to parse `/etc/os-release` for the OS "ID" (e.g. "rhel", "sles", "ubuntu").
- If `/etc/os-release` is not available, it falls back to `platform.system()`.
Behavior on an AMD system:
- If neither `arch` nor `rocm_version` is provided, the test is skipped unconditionally.
- If `arch` is provided (as a string or a list/tuple), the test is skipped if the detected GPU
architecture is in the given list.
- If `rocm_version` is provided (as a string or a list/tuple), the test is skipped if the current
ROCm version matches (or starts with) one of the provided version strings.
- If both are provided, the test is skipped if either condition is met.
Behavior on an AMD (ROCm) system:
- If no additional conditions are provided (i.e. arch, rocm_version, and os_name are all None),
the test is skipped unconditionally.
- If `arch` is provided (as a string or list), the test is skipped if the detected GPU architecture
matches one of the provided values.
- If `rocm_version` is provided (as a string or list), the test is skipped if the ROCm version (from
`/opt/rocm/.info/version`) matches (or begins with) one of the provided strings.
- If `os_name` is provided (as a string or list), the test is skipped if the current OS is among the provided names.
- If more than one condition is provided, the test will be skipped if **any** of those conditions are met.
On non-AMD systems (e.g. NVIDIA), the test will run normally.
On non-AMD systems (e.g. if the GPU vendor is detected as NVIDIA), the test will run normally.
Parameters:
msg (str): The skip message.
arch (str or iterable of str, optional): The GPU architecture(s) for which to skip the test.
rocm_version (str or iterable of str, optional): The ROCm version(s) which should cause the test to be skipped.
arch (str or iterable of str, optional): GPU architecture(s) for which to skip the test.
rocm_version (str or iterable of str, optional): ROCm version(s) for which to skip the test.
os_name (str or iterable of str, optional): Operating system ID(s) (e.g. "rhel", "sles", "ubuntu")
for which to skip the test.
"""

def get_gpu_vendor():
"""
Returns the GPU vendor as detected by checking for NVIDIA or ROCm utilities.
"""
"""Returns the GPU vendor by checking for NVIDIA or ROCm utilities."""
cmd = (
"bash -c 'if [[ -f /usr/bin/nvidia-smi ]] && "
"$(/usr/bin/nvidia-smi > /dev/null 2>&1); then echo \"NVIDIA\"; "
Expand All @@ -270,8 +280,8 @@ def get_gpu_vendor():
def get_system_gpu_architecture():
"""
Returns the GPU architecture string if the vendor is AMD.
For AMD, it extracts a line starting with 'gfx' via `/opt/rocm/bin/rocminfo`.
For NVIDIA, it returns the GPU name via `nvidia-smi` (though this is used only for information).
For AMD, extracts a line starting with 'gfx' via `/opt/rocm/bin/rocminfo`.
For NVIDIA, returns the GPU name using `nvidia-smi` (informational only).
"""
vendor = get_gpu_vendor()
if vendor == "AMD":
Expand All @@ -287,44 +297,70 @@ def get_system_gpu_architecture():

def get_rocm_version():
"""
Returns the ROCm version as a string by reading /opt/rocm/.info/version.
Expected format example: "6.4.0-12299"
Returns the ROCm version as a string by reading the file /opt/rocm/.info/version.
Expected format (example): "6.4.0-15396"
"""
cmd = "cat /opt/rocm/.info/version"
return subprocess.check_output(cmd, shell=True).decode("utf-8").strip()

def get_current_os():
"""
Attempts to determine the current operating system.
On Linux, parses /etc/os-release for the OS ID (e.g., "rhel", "sles", "ubuntu").
Otherwise, falls back to platform.system().
"""
if os.name == "posix" and os.path.exists("/etc/os-release"):
try:
with open("/etc/os-release") as f:
for line in f:
if line.startswith("ID="):
# ID value may be quoted.
return line.split("=")[1].strip().strip('"').lower()
except Exception:
# Fallback to platform information
pass
# For non-Linux systems or if /etc/os-release is not available.
return platform.system().lower()

def dec_fn(fn):
reason = f"skipIfRocm: {msg}"

@wraps(fn)
def wrapper(*args, **kwargs):
vendor = get_gpu_vendor()
# Only consider the ROCm skip logic for AMD systems.
if vendor == "AMD":
# Determine if the test should be skipped based on GPU architecture or ROCm version.
should_skip = False

# No specific conditions provided; skip unconditionally on AMD.
if arch is None and rocm_version is None:
# If no specific conditions are provided, skip unconditionally.
if arch is None and rocm_version is None and os_name is None:
should_skip = True

# Check the GPU architecture condition if provided.
# Check GPU architecture if provided.
if arch is not None:
arch_list = (arch,) if isinstance(arch, str) else arch
current_gpu_arch = get_system_gpu_architecture()
if current_gpu_arch in arch_list:
should_skip = True

# Check the ROCm version condition if provided.
# Check the ROCm version if provided.
if rocm_version is not None:
ver_list = (rocm_version,) if isinstance(rocm_version, str) else rocm_version
current_version = get_rocm_version()
# Use startswith so you can specify major.minor.patch without the build suffix.
# Using startswith allows matching "6.4.0" even if the full version is "6.4.0-15396"
if any(current_version.startswith(v) for v in ver_list):
should_skip = True

# Check the operating system if provided.
if os_name is not None:
os_list = (os_name,) if isinstance(os_name, str) else os_name
current_os = get_current_os()
if current_os in os_list:
should_skip = True

if should_skip:
pytest.skip(reason)

# For non-AMD systems the test runs normally.
return fn(*args, **kwargs)
return wrapper

Expand Down

0 comments on commit a608d6b

Please sign in to comment.