From 0c1cfc091803c2dd4920a4c3b94d8d89438a29d2 Mon Sep 17 00:00:00 2001
From: Rob Keevil
Date: Mon, 2 Dec 2024 14:49:42 +0100
Subject: [PATCH] Backup check of nvidia-smi to detect nvidia cards e.g. in WSL

---
 pjrt/cuda.go | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/pjrt/cuda.go b/pjrt/cuda.go
index 889a5ce..922bd03 100644
--- a/pjrt/cuda.go
+++ b/pjrt/cuda.go
@@ -1,11 +1,13 @@
 package pjrt
 
 import (
-	"k8s.io/klog/v2"
 	"os"
+	"os/exec"
 	"path"
 	"path/filepath"
 	"strings"
+
+	"k8s.io/klog/v2"
 )
 
 // This file includes the required hacks to support Nvidia's Cuda based PJRT plugins.
@@ -31,8 +33,23 @@ func hasNvidiaGPU() bool {
 		return false
 	}
 	hasGPU := len(matches) > 0
+
+	if !hasGPU {
+		// Execute the nvidia-smi command if present
+		_, lookErr := exec.LookPath("nvidia-smi")
+		if lookErr == nil {
+			cmd := exec.Command("nvidia-smi")
+			output, cmdErr := cmd.CombinedOutput()
+			if cmdErr == nil {
+				if strings.Contains(string(output), "NVIDIA-SMI") {
+					hasGPU = true
+				}
+			}
+		}
+	}
+
 	if !hasGPU {
-		klog.Infof("No NVidia devices found matching \"/dev/nvidia*\", assuming there are no GPU cards installed in the system. " +
+		klog.Infof("No NVidia devices found matching \"/dev/nvidia*\", and nvidia-smi command did not succeed, assuming there are no GPU cards installed in the system. " +
 			"To force the attempt to use the \"cuda\" PJRT, use its absolute path.")
 	}
 	hasNvidiaGPUCache = &hasGPU