diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 38e7f311f55..15b9c304ecf 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -230,35 +230,45 @@ def parse_int_from_env(key, default=None): _run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False) _test_with_rocm = parse_flag_from_env("TEST_WITH_ROCM", default=False) -def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack", arch=None, rocm_version=None): + +import platform + +def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack", arch=None, rocm_version=None, os_name=None): """ - Pytest decorator to skip a test on AMD systems running ROCm based on GPU architecture and/or ROCm version. + Pytest decorator to skip a test on AMD systems running ROCm, with additional conditions based on + GPU architecture, ROCm version, and/or operating system. The decorator uses shell commands to: - Detect the GPU vendor. - - Extract the GPU architecture (for AMD, using `/opt/rocm/bin/rocminfo`). - - Read the ROCm version (by reading the file `/opt/rocm/.info/version`). + - Extract the GPU architecture for AMD via `/opt/rocm/bin/rocminfo`. + - Read the ROCm version from `/opt/rocm/.info/version`. + + In addition, it can detect the current operating system: + - On Linux, it attempts to parse `/etc/os-release` for the OS "ID" (e.g. "rhel", "sles", "ubuntu"). + - If `/etc/os-release` is not available, it falls back to `platform.system()`. - Behavior on an AMD system: - - If neither `arch` nor `rocm_version` is provided, the test is skipped unconditionally. - - If `arch` is provided (as a string or a list/tuple), the test is skipped if the detected GPU - architecture is in the given list. - - If `rocm_version` is provided (as a string or a list/tuple), the test is skipped if the current - ROCm version matches (or starts with) one of the provided version strings. - - If both are provided, the test is skipped if either condition is met. + Behavior on an AMD (ROCm) system: + - If no additional conditions are provided (i.e. arch, rocm_version, and os_name are all None), + the test is skipped unconditionally. + - If `arch` is provided (as a string or list), the test is skipped if the detected GPU architecture + matches one of the provided values. + - If `rocm_version` is provided (as a string or list), the test is skipped if the ROCm version (from + `/opt/rocm/.info/version`) matches (or begins with) one of the provided strings. + - If `os_name` is provided (as a string or list), the test is skipped if the current OS is among the provided names. + - If more than one condition is provided, the test will be skipped if **any** of those conditions are met. - On non-AMD systems (e.g. NVIDIA), the test will run normally. + On non-AMD systems (e.g. if the GPU vendor is detected as NVIDIA), the test will run normally. Parameters: msg (str): The skip message. - arch (str or iterable of str, optional): The GPU architecture(s) for which to skip the test. - rocm_version (str or iterable of str, optional): The ROCm version(s) which should cause the test to be skipped. + arch (str or iterable of str, optional): GPU architecture(s) for which to skip the test. + rocm_version (str or iterable of str, optional): ROCm version(s) for which to skip the test. + os_name (str or iterable of str, optional): Operating system ID(s) (e.g. "rhel", "sles", "ubuntu") + for which to skip the test. """ def get_gpu_vendor(): - """ - Returns the GPU vendor as detected by checking for NVIDIA or ROCm utilities. - """ + """Returns the GPU vendor by checking for NVIDIA or ROCm utilities.""" cmd = ( "bash -c 'if [[ -f /usr/bin/nvidia-smi ]] && " "$(/usr/bin/nvidia-smi > /dev/null 2>&1); then echo \"NVIDIA\"; " @@ -270,8 +280,8 @@ def get_gpu_vendor(): def get_system_gpu_architecture(): """ Returns the GPU architecture string if the vendor is AMD. - For AMD, it extracts a line starting with 'gfx' via `/opt/rocm/bin/rocminfo`. - For NVIDIA, it returns the GPU name via `nvidia-smi` (though this is used only for information). + For AMD, extracts a line starting with 'gfx' via `/opt/rocm/bin/rocminfo`. + For NVIDIA, returns the GPU name using `nvidia-smi` (informational only). """ vendor = get_gpu_vendor() if vendor == "AMD": @@ -287,44 +297,70 @@ def get_system_gpu_architecture(): def get_rocm_version(): """ - Returns the ROCm version as a string by reading /opt/rocm/.info/version. - Expected format example: "6.4.0-12299" + Returns the ROCm version as a string by reading the file /opt/rocm/.info/version. + Expected format (example): "6.4.0-15396" """ cmd = "cat /opt/rocm/.info/version" return subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + def get_current_os(): + """ + Attempts to determine the current operating system. + On Linux, parses /etc/os-release for the OS ID (e.g., "rhel", "sles", "ubuntu"). + Otherwise, falls back to platform.system(). + """ + if os.name == "posix" and os.path.exists("/etc/os-release"): + try: + with open("/etc/os-release") as f: + for line in f: + if line.startswith("ID="): + # ID value may be quoted. + return line.split("=")[1].strip().strip('"').lower() + except Exception: + # Fallback to platform information + pass + # For non-Linux systems or if /etc/os-release is not available. + return platform.system().lower() + def dec_fn(fn): reason = f"skipIfRocm: {msg}" @wraps(fn) def wrapper(*args, **kwargs): vendor = get_gpu_vendor() + # Only consider the ROCm skip logic for AMD systems. if vendor == "AMD": - # Determine if the test should be skipped based on GPU architecture or ROCm version. should_skip = False - # No specific conditions provided; skip unconditionally on AMD. - if arch is None and rocm_version is None: + # If no specific conditions are provided, skip unconditionally. + if arch is None and rocm_version is None and os_name is None: should_skip = True - # Check the GPU architecture condition if provided. + # Check GPU architecture if provided. if arch is not None: arch_list = (arch,) if isinstance(arch, str) else arch current_gpu_arch = get_system_gpu_architecture() if current_gpu_arch in arch_list: should_skip = True - # Check the ROCm version condition if provided. + # Check the ROCm version if provided. if rocm_version is not None: ver_list = (rocm_version,) if isinstance(rocm_version, str) else rocm_version current_version = get_rocm_version() - # Use startswith so you can specify major.minor.patch without the build suffix. + # Using startswith allows matching "6.4.0" even if the full version is "6.4.0-15396" if any(current_version.startswith(v) for v in ver_list): should_skip = True + # Check the operating system if provided. + if os_name is not None: + os_list = (os_name,) if isinstance(os_name, str) else os_name + current_os = get_current_os() + if current_os in os_list: + should_skip = True + if should_skip: pytest.skip(reason) - + # For non-AMD systems the test runs normally. return fn(*args, **kwargs) return wrapper