Skip to content

Commit

Permalink
gfx1200 enablement (#80)
Browse files Browse the repository at this point in the history
* partial skips

* Fix some of the test runs with minor changes; tests that would need larger changes are skipped instead (these skips are gfx1201-specific)

* Update skip utility to test navi based changes  (#77)

Added an architecture-based skip feature so tests can be skipped on the machine under test without affecting other architectures.

* updated skip utility and changes accordingly to only affect gfx1201

* gfx1200 machine related changes

---------

Co-authored-by: Cemberk <[email protected]>
  • Loading branch information
gargrahul and Cemberk authored Jan 20, 2025
1 parent d8b4c30 commit 20d2e3e
Show file tree
Hide file tree
Showing 23 changed files with 461 additions and 11 deletions.
192 changes: 187 additions & 5 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,199 @@ def parse_int_from_env(key, default=None):
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
_test_with_rocm = parse_flag_from_env("TEST_WITH_ROCM", default=False)

def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):

import platform

class RocmUtil:
    """
    Probe the host for GPU vendor, GPU architecture, ROCm version and operating system.

    The probes back the `skipIfRocm` decorator. They are implemented without a shell (no `bash`, `grep`, `sed` or
    `cat` dependency) and every probe result is cached on the instance, so decorating many tests does not spawn
    new subprocesses on every test call. Failed probes are not cached.
    """

    _NVIDIA_SMI = "/usr/bin/nvidia-smi"
    _ROCM_SMI = "/opt/rocm/bin/rocm-smi"
    _ROCMINFO = "/opt/rocm/bin/rocminfo"
    _ROCM_VERSION_FILE = "/opt/rocm/.info/version"
    _OS_RELEASE_FILE = "/etc/os-release"
    _UNKNOWN_VENDOR = "Unable to detect GPU vendor"

    def __init__(self):
        # Maps a probe name ("vendor", "arch", ...) to its already-computed result.
        self._cache = {}

    def _cached(self, key, compute):
        """Return the cached value for `key`, calling `compute()` once to fill it. Exceptions are not cached."""
        try:
            return self._cache[key]
        except KeyError:
            value = compute()
            self._cache[key] = value
            return value

    @staticmethod
    def _as_collection(value):
        """Wrap a single string in a tuple so that `in` / iteration never operate on its characters."""
        return (value,) if isinstance(value, str) else value

    @staticmethod
    def _run_capture(argv):
        """
        Run `argv` without a shell and return its decoded stdout.

        The exit status is ignored and an executable that cannot be started yields an empty string; callers decide
        whether empty or unexpected output is an error.
        """
        try:
            completed = subprocess.run(argv, stdout=subprocess.PIPE, check=False)
        except OSError:
            return ""
        return completed.stdout.decode("utf-8")

    def get_gpu_vendor(self):
        """
        Returns the GPU vendor by checking for NVIDIA or ROCm utilities.

        Returns:
            "NVIDIA" if `/usr/bin/nvidia-smi` is a file and runs successfully, otherwise "AMD" if
            `/opt/rocm/bin/rocm-smi` is a file, otherwise "Unable to detect GPU vendor".
        """
        return self._cached("vendor", self._detect_gpu_vendor)

    def _detect_gpu_vendor(self):
        """Uncached implementation of `get_gpu_vendor`."""
        if os.path.isfile(self._NVIDIA_SMI):
            try:
                probe = subprocess.run(
                    [self._NVIDIA_SMI], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False
                )
                if probe.returncode == 0:
                    return "NVIDIA"
            except OSError:
                # The file exists but cannot be executed: treat it like a failing nvidia-smi.
                pass
        if os.path.isfile(self._ROCM_SMI):
            return "AMD"
        return self._UNKNOWN_VENDOR

    def get_system_gpu_architecture(self):
        """
        Returns the GPU architecture string of the system.

        For AMD, returns the first `gfx...` token printed by `/opt/rocm/bin/rocminfo` (from "gfx" to the end of
        that line). For NVIDIA, returns the name of the first GPU listed by `nvidia-smi -L` (informational only).

        Raises:
            RuntimeError: if the vendor is neither AMD nor NVIDIA, or if `rocminfo` reports no `gfx` architecture.
        """
        return self._cached("arch", self._detect_gpu_architecture)

    def _detect_gpu_architecture(self):
        """Uncached implementation of `get_system_gpu_architecture`."""
        vendor = self.get_gpu_vendor()
        if vendor == "AMD":
            for line in self._run_capture([self._ROCMINFO]).splitlines():
                start = line.find("gfx")
                if start != -1:
                    return line[start:].strip()
            raise RuntimeError(f"Unable to find a 'gfx' architecture in the output of {self._ROCMINFO}.")
        if vendor == "NVIDIA":
            lines = self._run_capture(["nvidia-smi", "-L"]).splitlines()
            first_line = lines[0] if lines else ""
            # Drop the "(UUID: ...)" suffix and the "GPU 0: " prefix, keeping only the device name.
            uuid_start = first_line.find("(UUID: ")
            uuid_end = first_line.rfind(")")
            if uuid_start != -1 and uuid_end >= uuid_start + len("(UUID: "):
                first_line = first_line[:uuid_start] + first_line[uuid_end + 1 :]
            return first_line.replace("GPU 0: ", "").strip()
        raise RuntimeError("Unable to determine GPU architecture due to unknown GPU vendor.")

    def get_rocm_version(self):
        """
        Returns the ROCm version as a string by reading the file /opt/rocm/.info/version.
        Expected format (example): "6.4.0-15396"

        Raises:
            OSError: if the version file is missing or unreadable.
        """
        return self._cached("rocm_version", self._read_rocm_version)

    def _read_rocm_version(self):
        """Uncached implementation of `get_rocm_version`."""
        with open(self._ROCM_VERSION_FILE, encoding="utf-8") as version_file:
            return version_file.read().strip()

    def get_current_os(self):
        """
        Attempts to determine the current operating system.

        On POSIX systems, parses /etc/os-release for the OS ID (e.g., "rhel", "sles", "ubuntu").
        Otherwise, or if the file is missing, unreadable or has no `ID=` line, falls back to
        `platform.system()`. The result is always lower-cased.
        """
        return self._cached("os", self._detect_current_os)

    def _detect_current_os(self):
        """Uncached implementation of `get_current_os`."""
        if os.name == "posix":
            try:
                with open(self._OS_RELEASE_FILE, encoding="utf-8") as os_release:
                    for line in os_release:
                        if line.startswith("ID="):
                            # The ID value may be wrapped in double or single quotes.
                            return line.partition("=")[2].strip().strip("\"'").lower()
            except (OSError, UnicodeDecodeError):
                # Best effort only: fall back to platform information below.
                pass
        return platform.system().lower()

    def is_rocm_skippable(self, arch=None, rocm_version=None, os_name=None):
        """
        Determines whether the current system should be considered "skippable" based on ROCm criteria.

        This function returns True **only** if:
          1. The GPU vendor is AMD (i.e. a ROCm system), and
          2. EITHER no specific conditions are provided,
             OR at least one of the provided conditions is met.

        Parameters:
            arch (str or iterable of str, optional): GPU architecture(s) that should cause skipping.
            rocm_version (str or iterable of str, optional): ROCm version(s) (or version prefixes) that should
                cause skipping.
            os_name (str or iterable of str, optional): OS name(s) (e.g., "rhel", "sles", "ubuntu", "windows",
                "darwin") for which the test should be skipped.

        Returns:
            True if the system is AMD (ROCm) and meets any (or no) specified criteria (i.e. it is "skippable"),
            otherwise False.
        """
        if self.get_gpu_vendor() != "AMD":
            # Not a ROCm system: never skip.
            return False

        # No conditions at all means "skip on every AMD system".
        if arch is None and rocm_version is None and os_name is None:
            return True

        # The conditions are OR-ed; each probe only runs when its condition was actually requested.
        if arch is not None and self.get_system_gpu_architecture() in self._as_collection(arch):
            return True

        if rocm_version is not None:
            current_version = self.get_rocm_version()
            # Prefix match, so "6.4.0" matches a full version such as "6.4.0-15396".
            if any(current_version.startswith(prefix) for prefix in self._as_collection(rocm_version)):
                return True

        if os_name is not None and self.get_current_os() in self._as_collection(os_name):
            return True

        return False


rocmUtils = RocmUtil()

def skipIfRocm(
    func=None,
    *,
    msg="test doesn't currently work on the ROCm stack",
    arch=None,
    rocm_version=None,
    os_name=None,
):
    """
    Pytest decorator to skip a test on AMD systems running ROCm, with additional conditions based on
    GPU architecture, ROCm version, and/or operating system.

    The skip decision is delegated to `rocmUtils.is_rocm_skippable` and is evaluated when the test is called,
    not when it is decorated.

    Behavior on an AMD (ROCm) system:
      - If no additional conditions are provided (i.e. arch, rocm_version, and os_name are all None),
        the test is skipped unconditionally.
      - If `arch` is provided (as a string or list), the test is skipped if the detected GPU architecture
        matches one of the provided values.
      - If `rocm_version` is provided (as a string or list), the test is skipped if the ROCm version
        matches (or begins with) one of the provided strings.
      - If `os_name` is provided (as a string or list), the test is skipped if the current OS is among the
        provided names.
      - If more than one condition is provided, the test will be skipped if **any** of those conditions are met.

    On non-AMD systems (e.g. if the GPU vendor is detected as NVIDIA), the test will run normally.

    Can be used bare (`@skipIfRocm`) or with keyword arguments (`@skipIfRocm(arch="gfx1201")`).

    Parameters:
        func (callable, optional): The test function, when the decorator is applied without parentheses.
        msg (str): The skip message.
        arch (str or iterable of str, optional): GPU architecture(s) for which to skip the test.
        rocm_version (str or iterable of str, optional): ROCm version(s) for which to skip the test.
        os_name (str or iterable of str, optional): Operating system ID(s) (e.g. "rhel", "sles", "ubuntu")
            for which to skip the test.
    """

    def dec_fn(fn):
        reason = f"skipIfRocm: {msg}"

        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Single source of truth for the vendor / arch / version / OS checks.
            if rocmUtils.is_rocm_skippable(arch=arch, rocm_version=rocm_version, os_name=os_name):
                pytest.skip(reason)
            # Non-AMD system, or none of the requested conditions matched: run the test normally.
            return fn(*args, **kwargs)

        return wrapper

    if func is not None:
        return dec_fn(func)
    return dec_fn
Expand Down
32 changes: 32 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2531,6 +2531,21 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
"return_tensors": "pt",
}

@skipIfRocm(arch='gfx1201')
def test_custom_logits_processor(self):
    # Overridden only to attach the gfx1201 skip; the test body is the inherited one.
    super().test_custom_logits_processor()

@skipIfRocm(arch='gfx1201')
def test_max_new_tokens_encoder_decoder(self):
    # Overridden only to attach the gfx1201 skip; the test body is the inherited one.
    super().test_max_new_tokens_encoder_decoder()

@skipIfRocm(arch='gfx1201')
def test_eos_token_id_int_and_list_beam_search(self):
    # Overridden only to attach the gfx1201 skip; the test body is the inherited one.
    super().test_eos_token_id_int_and_list_beam_search()

@slow
def test_diverse_beam_search(self):
# PT-only test: TF doesn't have a diverse beam search implementation
Expand Down Expand Up @@ -2580,6 +2595,7 @@ def test_max_length_if_input_embeds(self):
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

@skipIfRocm(arch='gfx1201')
def test_min_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
article = "Today a dragon flew over Paris."
Expand Down Expand Up @@ -2632,6 +2648,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
)

# TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
@skipIfRocm(arch='gfx1201')
def test_stop_sequence_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
prompt = """Hello I believe in"""
Expand Down Expand Up @@ -3214,6 +3231,7 @@ def test_logits_processor_not_inplace(self):
self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist())
self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist())

@skipIfRocm(arch='gfx1201')
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has TF equivalent: this test relies on random sampling
generation_kwargs = {
Expand Down Expand Up @@ -3242,6 +3260,7 @@ def test_eos_token_id_int_and_list_top_k_top_sampling(self):
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))

@skipIfRocm(arch='gfx1201')
def test_model_kwarg_encoder_signature_filtering(self):
# Has TF equivalent: ample use of framework-specific code
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
Expand Down Expand Up @@ -3279,6 +3298,7 @@ def forward(self, input_ids, **kwargs):
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
bart_model.generate(input_ids, foo="bar")

@skipIfRocm(arch='gfx1201')
def test_default_max_length_warning(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
Expand Down Expand Up @@ -3336,6 +3356,7 @@ def test_default_assisted_generation(self):
self.assertEqual(config.assistant_confidence_threshold, 0.4)
self.assertEqual(config.is_assistant, False)

@skipIfRocm(arch='gfx1201')
def test_generated_length_assisted_generation(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
Expand Down Expand Up @@ -3364,6 +3385,7 @@ def test_generated_length_assisted_generation(self):
)
self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)

@skipIfRocm(arch='gfx1201')
def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
Expand Down Expand Up @@ -3398,6 +3420,7 @@ def test_model_kwarg_assisted_decoding_decoder_only(self):
)
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())

@skipIfRocm(arch='gfx1201')
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
Expand Down Expand Up @@ -3464,6 +3487,7 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

@skipIfRocm(arch='gfx1201')
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
Expand Down Expand Up @@ -3542,6 +3566,7 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

@skipIfRocm(arch='gfx1201')
def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.

Expand Down Expand Up @@ -3788,6 +3813,7 @@ def test_special_tokens_fall_back_to_model_default(self):
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)

@skipIfRocm(arch='gfx1201')
def test_speculative_decoding_equals_regular_decoding(self):
draft_name = "double7/vicuna-68m"
target_name = "Qwen/Qwen2-0.5B-Instruct"
Expand Down Expand Up @@ -3818,6 +3844,7 @@ def test_speculative_decoding_equals_regular_decoding(self):

@pytest.mark.generate
@require_torch_multi_gpu
@skipIfRocm(arch='gfx1201')
def test_generate_with_static_cache_multi_gpu(self):
"""
Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.
Expand Down Expand Up @@ -3853,6 +3880,7 @@ def test_generate_with_static_cache_multi_gpu(self):

@pytest.mark.generate
@require_torch_multi_gpu
@skipIfRocm(arch='gfx1201')
def test_init_static_cache_multi_gpu(self):
"""
Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.
Expand Down Expand Up @@ -4034,6 +4062,7 @@ def test_padding_input_contrastive_search_t5(self):
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")

@skipIfRocm(arch='gfx1201')
def test_prepare_inputs_for_generation_decoder_llm(self):
"""Tests GenerationMixin.prepare_inputs_for_generation against expected usage with decoder-only llms."""

Expand Down Expand Up @@ -4150,6 +4179,7 @@ def test_prepare_inputs_for_generation_encoder_decoder_llm(self):
self.assertTrue(model_inputs["encoder_outputs"] == "foo")
# See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here.

@skipIfRocm(arch='gfx1201')
def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
Expand All @@ -4173,6 +4203,7 @@ def test_generate_compile_fullgraph_tiny(self):
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated

@skipIfRocm(arch='gfx1201')
def test_assisted_generation_early_exit(self):
"""
Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache
Expand Down Expand Up @@ -4209,6 +4240,7 @@ class TokenHealingTestCase(unittest.TestCase):
("empty_prompt", "", ""),
]
)
@skipIfRocm(arch='gfx1201')
def test_prompts(self, name, input, expected):
model_name_or_path = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
Expand Down
12 changes: 11 additions & 1 deletion tests/models/dbrx/test_modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device, skipIfRocm

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
Expand Down Expand Up @@ -327,6 +327,16 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
test_headmasking = False
test_pruning = False

@skipIfRocm(arch='gfx1201')
def test_generate_with_static_cache(self):
    # Overridden only to attach the gfx1201 skip; the test body is the inherited one.
    super().test_generate_with_static_cache()

@skipIfRocm(arch='gfx1201')
def test_generate_from_inputs_embeds_with_static_cache(self):
    # Overridden only to attach the gfx1201 skip; the test body is the inherited one.
    super().test_generate_from_inputs_embeds_with_static_cache()

def setUp(self):
self.model_tester = DbrxModelTester(self)
self.config_tester = ConfigTester(self, config_class=DbrxConfig, d_model=37)
Expand Down
Loading

0 comments on commit 20d2e3e

Please sign in to comment.