diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d81ffbc371f6d..48eae03ed703c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -161,14 +161,11 @@ def is_full_nvlink(physical_device_ids: List[int]) -> bool: def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) handle = amdsmi_get_processor_handles()[physical_device_id] - market_name = amdsmi_get_gpu_asic_info(handle)["market_name"] - # Note: this may not be exactly the same as the torch device name - # E.g. `AMD Instinct MI300X OAM` vs `AMD Instinct MI300X` - if "MI308" in market_name: - return "AMD Instinct MI308X" - if "MI300" in market_name: - return "AMD Instinct MI300X" - return market_name + gpu_info = amdsmi_get_gpu_asic_info(handle) + # Using num_cu to distinguish mi300 and mi308 + num_cu = gpu_info["num_compute_units"] + gpu_target_name = gpu_info["target_graphics_version"] + return f"AMD_{gpu_target_name}_{num_cu}CU" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: