diff --git a/pjrt/cuda.go b/pjrt/cuda.go index 889a5ce..c70ef7d 100644 --- a/pjrt/cuda.go +++ b/pjrt/cuda.go @@ -1,11 +1,13 @@ package pjrt import ( - "k8s.io/klog/v2" "os" + "os/exec" "path" "path/filepath" "strings" + + "k8s.io/klog/v2" ) // This file includes the required hacks to support Nvidia's Cuda based PJRT plugins. @@ -28,15 +30,32 @@ func hasNvidiaGPU() bool { matches, err := filepath.Glob("/dev/nvidia*") if err != nil { klog.Errorf("Failed to figure out if there is an Nvidia GPU installed while searching for files matching \"/dev/nvidia*\": %v", err) - return false } - hasGPU := len(matches) > 0 - if !hasGPU { - klog.Infof("No NVidia devices found matching \"/dev/nvidia*\", assuming there are no GPU cards installed in the system. " + - "To force the attempt to use the \"cuda\" PJRT, use its absolute path.") + if len(matches) > 0 { + hasGPU := true + hasNvidiaGPUCache = &hasGPU + return hasGPU + } else { + klog.Infof("No NVidia devices found matching \"/dev/nvidia*\", checking nvidia-smi command instead.") } - hasNvidiaGPUCache = &hasGPU - return hasGPU + + // Execute the nvidia-smi command if present + _, lookErr := exec.LookPath("nvidia-smi") + if lookErr == nil { + cmd := exec.Command("nvidia-smi") + output, cmdErr := cmd.CombinedOutput() + if cmdErr == nil { + if strings.Contains(string(output), "NVIDIA-SMI") { + hasGPU := true + hasNvidiaGPUCache = &hasGPU + return hasGPU + } + } + } + + klog.Infof("nvidia-smi command did not succeed, assuming there are no GPU cards installed in the system. " + + "To force the attempt to use the \"cuda\" PJRT, use its absolute path.") + return false } // cudaPluginCheckDrivers issues warning on cuda plugins if it cannot find the corresponding nvidia library files.