Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor driver discovery #825

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ TRANSLATOR:
const:
- {action: accept, from: "^NVSANDBOXUTILS_"}
- {action: accept, from: "^nvSandboxUtils"}
- {action: replace, from: "^NVSANDBOXUTILS_255_MASK_", to: "MASK255_" }
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this was required.

- {action: replace, from: "^NVSANDBOXUTILS_"}
- {action: replace, from: "^nvSandboxUtils"}
- {action: accept, from: "^NV"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
# limitations under the License.
**/

package nvcdi
package dgpu

import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/NVIDIA/go-nvml/pkg/nvml"
"golang.org/x/sys/unix"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
Expand All @@ -32,80 +31,50 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
)

// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
// The supplied NVML Library is used to query the expected driver version.
func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string, ldconfigPath string, nvmllib nvml.Interface) (discover.Discover, error) {
if r := nvmllib.Init(); r != nvml.SUCCESS {
return nil, fmt.Errorf("failed to initialize NVML: %v", r)
}
defer func() {
if r := nvmllib.Shutdown(); r != nvml.SUCCESS {
logger.Warningf("failed to shutdown NVML: %v", r)
}
}()

version, r := nvmllib.SystemGetDriverVersion()
if r != nvml.SUCCESS {
return nil, fmt.Errorf("failed to determine driver version: %v", r)
}

return newDriverVersionDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
}

func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
// newNvmlDriverDiscoverer constructs a discoverer from the specified NVML library.
func (o *options) newNvmlDriverDiscoverer() (discover.Discover, error) {
libraries, err := o.newNvmlDriverLibraryDiscoverer()
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err)
}

ipcs, err := discover.NewIPCDiscoverer(logger, driver.Root)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err)
}

firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version)
firmwares, err := o.newNvmlDriverFirmwareDiscoverer()
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err)
}

binaries := NewDriverBinariesDiscoverer(logger, driver.Root)
binaries := o.newNvmlDriverBinariesDiscoverer()

d := discover.Merge(
libraries,
ipcs,
firmwares,
binaries,
)

return d, nil
}

// NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version.
func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
libraryPaths, err := getVersionLibs(logger, driver, version)
// newNvmlDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version.
func (o *options) newNvmlDriverLibraryDiscoverer() (discover.Discover, error) {
libraryPaths, err := getVersionLibs(o.logger, o.driver, o.version)
if err != nil {
return nil, fmt.Errorf("failed to get libraries for driver version: %v", err)
}

libraries := discover.NewMounts(
logger,
o.logger,
lookup.NewFileLocator(
lookup.WithLogger(logger),
lookup.WithRoot(driver.Root),
lookup.WithLogger(o.logger),
lookup.WithRoot(o.driver.Root),
),
driver.Root,
o.driver.Root,
libraryPaths,
)

updateLDCache, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath)

d := discover.Merge(
discover.WithDriverDotSoSymlinks(
libraries,
version,
nvidiaCDIHookPath,
),
updateLDCache,
d := discover.WithDriverDotSoSymlinks(
libraries,
o.version,
o.nvidiaCDIHookPath,
)

return d, nil
Expand Down Expand Up @@ -153,31 +122,31 @@ func getCustomFirmwareClassPath(logger logger.Interface) string {
return strings.TrimSpace(string(customFirmwareClassPath))
}

// NewDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version.
func NewDriverFirmwareDiscoverer(logger logger.Interface, driverRoot string, version string) (discover.Discover, error) {
gspFirmwareSearchPaths, err := getFirmwareSearchPaths(logger)
// newNvmlDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version.
func (o *options) newNvmlDriverFirmwareDiscoverer() (discover.Discover, error) {
gspFirmwareSearchPaths, err := getFirmwareSearchPaths(o.logger)
if err != nil {
return nil, fmt.Errorf("failed to get firmware search paths: %v", err)
}
gspFirmwarePaths := filepath.Join("nvidia", version, "gsp*.bin")
gspFirmwarePaths := filepath.Join("nvidia", o.version, "gsp*.bin")
return discover.NewMounts(
logger,
o.logger,
lookup.NewFileLocator(
lookup.WithLogger(logger),
lookup.WithRoot(driverRoot),
lookup.WithLogger(o.logger),
lookup.WithRoot(o.driver.Root),
lookup.WithSearchPaths(gspFirmwareSearchPaths...),
),
driverRoot,
o.driver.Root,
[]string{gspFirmwarePaths},
), nil
}

// NewDriverBinariesDiscoverer creates a discoverer for GSP firmware associated with the GPU driver.
func NewDriverBinariesDiscoverer(logger logger.Interface, driverRoot string) discover.Discover {
// newNvmlDriverBinariesDiscoverer creates a discoverer for binaries associated with the specified driver version.
func (o *options) newNvmlDriverBinariesDiscoverer() discover.Discover {
return discover.NewMounts(
logger,
lookup.NewExecutableLocator(logger, driverRoot),
driverRoot,
o.logger,
lookup.NewExecutableLocator(o.logger, o.driver.Root),
o.driver.Root,
[]string{
"nvidia-smi", /* System management interface */
"nvidia-debugdump", /* GPU coredump utility */
Expand Down
31 changes: 31 additions & 0 deletions internal/platform-support/dgpu/driver-nvsandboxutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package dgpu

import (
"fmt"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
)

// newNvsandboxutilsDriverDiscoverer constructs a discoverer from the specified nvsandboxutils library.
func (o *options) newNvsandboxutilsDriverDiscoverer() (discover.Discover, error) {
if o.nvsandboxutilslib == nil {
return nil, nil
}
return nil, fmt.Errorf("nvsandboxutils driver discovery is not implemented")
}
74 changes: 74 additions & 0 deletions internal/platform-support/dgpu/driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package dgpu

import (
"errors"
"fmt"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
)

// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
func NewDriverDiscoverer(opts ...Option) (discover.Discover, error) {
o := new(opts...)

if o.version == "" {
return nil, fmt.Errorf("a version must be specified")
}

var discoverers []discover.Discover
var errs error

nvsandboxutilsDiscoverer, err := o.newNvsandboxutilsDriverDiscoverer()
if err != nil {
// TODO: Log a warning
errs = errors.Join(errs, err)
} else if nvsandboxutilsDiscoverer != nil {
discoverers = append(discoverers, nvsandboxutilsDiscoverer)
}

nvmlDiscoverer, err := o.newNvmlDriverDiscoverer()
if err != nil {
// TODO: Log a warning
errs = errors.Join(errs, err)
} else if nvmlDiscoverer != nil {
discoverers = append(discoverers, nvmlDiscoverer)
}

if len(discoverers) == 0 {
return nil, errs
}

cached := discover.WithCache(
discover.FirstValid(
discoverers...,
),
)
updateLDCache, _ := discover.NewLDCacheUpdateHook(o.logger, cached, o.nvidiaCDIHookPath, o.ldconfigPath)

ipcs, err := discover.NewIPCDiscoverer(o.logger, o.driver.Root)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err)
}

return discover.Merge(
cached,
updateLDCache,
ipcs,
), nil
}
25 changes: 25 additions & 0 deletions internal/platform-support/dgpu/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ package dgpu

import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
)

type options struct {
logger logger.Interface
driver *root.Driver
devRoot string
ldconfigPath string
nvidiaCDIHookPath string

isMigDevice bool
Expand All @@ -33,6 +36,9 @@ type options struct {
migCaps nvcaps.MigCaps
migCapsError error

// version stores the driver version.
version string

nvsandboxutilslib nvsandboxutils.Interface
}

Expand All @@ -45,6 +51,19 @@ func WithDevRoot(root string) Option {
}
}

func WithDriver(driver *root.Driver) Option {
return func(l *options) {
l.driver = driver
}
}

// WithLdconfigPath sets the path to the ldconfig program
func WithLdconfigPath(path string) Option {
return func(l *options) {
l.ldconfigPath = path
}
}

// WithLogger sets the logger for the library
func WithLogger(logger logger.Interface) Option {
return func(l *options) {
Expand Down Expand Up @@ -72,3 +91,9 @@ func WithNvsandboxuitilsLib(nvsandboxutilslib nvsandboxutils.Interface) Option {
l.nvsandboxutilslib = nvsandboxutilslib
}
}

func WithVersion(version string) Option {
return func(l *options) {
l.version = version
}
}
13 changes: 11 additions & 2 deletions pkg/nvcdi/common-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import (
"fmt"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
)

// newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device.
// This includes driver libraries and meta devices, for example.
func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
func (l *nvmllib) newCommonNVMLDiscoverer(version string) (discover.Discover, error) {
metaDevices := discover.NewCharDeviceDiscoverer(
l.logger,
l.devRoot,
Expand All @@ -41,7 +42,15 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err)
}

driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCDIHookPath, l.ldconfigPath, l.nvmllib)
driverFiles, err := dgpu.NewDriverDiscoverer(
dgpu.WithDevRoot(l.devRoot),
dgpu.WithDriver(l.driver),
dgpu.WithLdconfigPath(l.ldconfigPath),
dgpu.WithLogger(l.logger),
dgpu.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath),
dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib),
dgpu.WithVersion(version),
)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err)
}
Expand Down
22 changes: 20 additions & 2 deletions pkg/nvcdi/lib-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var _ Interface = (*nvmllib)(nil)

// GetSpec should not be called for nvmllib
func (l *nvmllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()")
return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()")
}

// GetAllDeviceSpecs returns the device specs for all available devices.
Expand Down Expand Up @@ -83,7 +83,25 @@ func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {

// GetCommonEdits generates a CDI specification that can be used for ANY devices
func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
common, err := l.newCommonNVMLDiscoverer()
if l.nvsandboxutilslib != nil {
if r := l.nvsandboxutilslib.Init(l.driverRoot); r != nvsandboxutils.SUCCESS {
l.logger.Warningf("Failed to init nvsandboxutils: %v; ignoring", r)
l.nvsandboxutilslib = nil
}
defer func() {
if l.nvsandboxutilslib == nil {
return
}
_ = l.nvsandboxutilslib.Shutdown()
}()
}

version, err := (*nvcdilib)(l).getDriverVersion()
if err != nil {
return nil, fmt.Errorf("failed to get driver version: %v", err)
}

common, err := l.newCommonNVMLDiscoverer(version)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err)
}
Expand Down
Loading
Loading