diff --git a/pjrt/cuda.go b/pjrt/cuda.go index 889a5ce..922bd03 100644 --- a/pjrt/cuda.go +++ b/pjrt/cuda.go @@ -1,11 +1,13 @@ package pjrt import ( - "k8s.io/klog/v2" "os" + "os/exec" "path" "path/filepath" "strings" + + "k8s.io/klog/v2" ) // This file includes the required hacks to support Nvidia's Cuda based PJRT plugins. @@ -31,8 +33,23 @@ func hasNvidiaGPU() bool { return false } hasGPU := len(matches) > 0 + + if !hasGPU { + // Execute the nvidia-smi command if present + _, lookErr := exec.LookPath("nvidia-smi") + if lookErr == nil { + cmd := exec.Command("nvidia-smi") + output, cmdErr := cmd.CombinedOutput() + if cmdErr == nil { + if strings.Contains(string(output), "NVIDIA-SMI") { + hasGPU = true + } + } + } + } + if !hasGPU { - klog.Infof("No NVidia devices found matching \"/dev/nvidia*\", assuming there are no GPU cards installed in the system. " + + klog.Infof("No NVidia devices found matching \"/dev/nvidia*\", and nvidia-smi command did not succeed, assuming there are no GPU cards installed in the system. " + "To force the attempt to use the \"cuda\" PJRT, use its absolute path.") } hasNvidiaGPUCache = &hasGPU