Skip to content

Commit

Permalink
Update skip utility to test navi based changes (#77)
Browse files Browse the repository at this point in the history
Added an architecture-based skip feature so tests can be skipped on specific GPU architectures (e.g. Navi) without affecting test runs on other architectures.
  • Loading branch information
Cemberk authored and gargrahul committed Jan 20, 2025
1 parent 5e44593 commit 3bb06c0
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 11 deletions.
133 changes: 128 additions & 5 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,140 @@ def parse_int_from_env(key, default=None):
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
_test_with_rocm = parse_flag_from_env("TEST_WITH_ROCM", default=False)

import platform


def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack", arch=None, rocm_version=None, os_name=None):
    """
    Pytest decorator to skip a test on AMD systems running ROCm, with optional extra
    conditions on GPU architecture, ROCm version, and/or operating system.

    Detection is shell-based:
      - GPU vendor via presence of `nvidia-smi` / `rocm-smi`.
      - AMD GPU architecture via `/opt/rocm/bin/rocminfo` (first line matching 'gfx...').
      - ROCm version from `/opt/rocm/.info/version`.
      - OS ID parsed from `/etc/os-release` on Linux, falling back to `platform.system()`.

    Behavior on an AMD (ROCm) system:
      - If `arch`, `rocm_version`, and `os_name` are all None, the test is skipped
        unconditionally.
      - If `arch` is given (str or iterable of str), skip when the detected GPU
        architecture is among the given values.
      - If `rocm_version` is given (str or iterable of str), skip when the installed
        ROCm version starts with any given value (so "6.4.0" matches "6.4.0-15396").
      - If `os_name` is given (str or iterable of str), skip when the current OS ID
        (e.g. "rhel", "sles", "ubuntu") is among the given values.
      - With multiple conditions, the test is skipped if **any** one of them matches.

    On non-AMD systems (e.g. NVIDIA, or no detectable GPU) the test runs normally.

    Parameters:
        func (callable, optional): Present when used as a bare decorator (`@skipIfRocm`).
        msg (str): The skip message.
        arch (str or iterable of str, optional): GPU architecture(s) to skip on.
        rocm_version (str or iterable of str, optional): ROCm version prefix(es) to skip on.
        os_name (str or iterable of str, optional): OS ID(s) to skip on.
    """

    def get_gpu_vendor():
        """Return "NVIDIA", "AMD", or a fallback string by probing vendor utilities."""
        cmd = (
            "bash -c 'if [[ -f /usr/bin/nvidia-smi ]] && "
            "$(/usr/bin/nvidia-smi > /dev/null 2>&1); then echo \"NVIDIA\"; "
            "elif [[ -f /opt/rocm/bin/rocm-smi ]]; then echo \"AMD\"; "
            "else echo \"Unable to detect GPU vendor\"; fi || true'"
        )
        return subprocess.check_output(cmd, shell=True).decode("utf-8").strip()

    def get_system_gpu_architecture():
        """
        Return the GPU architecture string.

        For AMD, extracts the first 'gfx...' token from `/opt/rocm/bin/rocminfo`.
        For NVIDIA, returns the GPU name via `nvidia-smi` (informational only).
        Raises RuntimeError when the vendor cannot be determined.
        """
        vendor = get_gpu_vendor()
        if vendor == "AMD":
            cmd = "/opt/rocm/bin/rocminfo | grep -o -m 1 'gfx.*'"
            return subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
        elif vendor == "NVIDIA":
            cmd = (
                "nvidia-smi -L | head -n1 | sed 's/(UUID: .*)//g' | sed 's/GPU 0: //g'"
            )
            return subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
        else:
            raise RuntimeError("Unable to determine GPU architecture due to unknown GPU vendor.")

    def get_rocm_version():
        """
        Return the ROCm version string read from /opt/rocm/.info/version.
        Expected format (example): "6.4.0-15396"
        """
        cmd = "cat /opt/rocm/.info/version"
        return subprocess.check_output(cmd, shell=True).decode("utf-8").strip()

    def get_current_os():
        """
        Return the current operating-system identifier, lowercased.

        On Linux, parses /etc/os-release for the ID field (e.g. "rhel", "sles",
        "ubuntu"); otherwise falls back to platform.system().
        """
        if os.name == "posix" and os.path.exists("/etc/os-release"):
            try:
                with open("/etc/os-release") as f:
                    for line in f:
                        if line.startswith("ID="):
                            # ID value may be quoted.
                            return line.split("=")[1].strip().strip('"').lower()
            except Exception:
                # Fall back to platform information below.
                pass
        # Non-Linux systems, or /etc/os-release unavailable/unreadable.
        return platform.system().lower()

    def dec_fn(fn):
        reason = f"skipIfRocm: {msg}"

        @wraps(fn)
        def wrapper(*args, **kwargs):
            vendor = get_gpu_vendor()
            # Only consider the ROCm skip logic for AMD systems.
            if vendor == "AMD":
                should_skip = False

                # If no specific conditions are provided, skip unconditionally.
                if arch is None and rocm_version is None and os_name is None:
                    should_skip = True

                # Check GPU architecture if provided.
                if arch is not None:
                    arch_list = (arch,) if isinstance(arch, str) else arch
                    current_gpu_arch = get_system_gpu_architecture()
                    if current_gpu_arch in arch_list:
                        should_skip = True

                # Check the ROCm version if provided.
                if rocm_version is not None:
                    ver_list = (rocm_version,) if isinstance(rocm_version, str) else rocm_version
                    current_version = get_rocm_version()
                    # startswith allows matching "6.4.0" even if the full version is "6.4.0-15396".
                    if any(current_version.startswith(v) for v in ver_list):
                        should_skip = True

                # Check the operating system if provided.
                if os_name is not None:
                    os_list = (os_name,) if isinstance(os_name, str) else os_name
                    current_os = get_current_os()
                    if current_os in os_list:
                        should_skip = True

                if should_skip:
                    pytest.skip(reason)
            # For non-AMD systems the test runs normally.
            return fn(*args, **kwargs)
        return wrapper

    if func:
        return dec_fn(func)
    return dec_fn
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2485,7 +2485,7 @@ def _inner_training_loop(
else:
self.accelerator.gradient_state._set_sync_gradients(True)

if (self.state.global_step == 10):
if (self.state.global_step == args.stable_train_warmup_steps):
start_train_stable_time = time.time()

if self.args.include_num_input_tokens_seen:
Expand Down Expand Up @@ -2659,9 +2659,8 @@ def _inner_training_loop(

metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps,num_tokens=num_train_tokens,)

total_samples = self.state.global_step*total_train_batch_size if args.max_steps > 0 else num_examples*num_train_epochs
perf_samples = total_samples - self.args.warmup_steps*total_train_batch_size
stable_train_metrics = speed_metrics("stable_train", start_train_stable_time, perf_samples)
stable_train_samples = num_train_samples - args.stable_train_warmup_steps*total_train_batch_size
stable_train_metrics = speed_metrics("stable_train", start_train_stable_time, stable_train_samples)

self.store_flos()
metrics["total_flos"] = self.state.total_flos
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ class TrainingArguments:
Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
warmup_steps (`int`, *optional*, defaults to 0):
Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
stable_train_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to skip before collecting performance numbers for stable_train_samples_per_second.
log_level (`str`, *optional*, defaults to `passive`):
Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the
Expand Down Expand Up @@ -604,8 +606,7 @@ class TrainingArguments:
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
with hyperparameter tuning.
ortmodule (:obj:`bool`, `optional`):
ort (:obj:`bool`, `optional`):
Use `ORTModule <https://github.com/microsoft/onnxruntime>`__.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
Expand Down Expand Up @@ -922,6 +923,7 @@ class TrainingArguments:
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
)
warmup_steps: int = field(default=10, metadata={"help": "Linear warmup over warmup_steps."})
stable_train_warmup_steps: int = field(default=0, metadata={"help": "warmup steps to skip before collecting training performance."})

log_level: Optional[str] = field(
default="passive",
Expand Down

0 comments on commit 3bb06c0

Please sign in to comment.