diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 9f9e994b2..598a40c1d 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -104,10 +104,12 @@ func (m command) build() *cli.Command { Destination: &opts.format, }, &cli.StringFlag{ - Name: "mode", - Aliases: []string{"discovery-mode"}, - Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.", - Value: nvcdi.ModeAuto, + Name: "mode", + Aliases: []string{"discovery-mode"}, + Usage: "The mode to use when discovering the available entities. " + + "One of [" + strings.Join(nvcdi.AllModes[string](), " | ") + "]. " + + "If mode is set to 'auto' the mode will be determined based on the system configuration.", + Value: string(nvcdi.ModeAuto), Destination: &opts.mode, }, &cli.StringFlag{ @@ -184,13 +186,7 @@ func (m command) validateFlags(c *cli.Context, opts *options) error { } opts.mode = strings.ToLower(opts.mode) - switch opts.mode { - case nvcdi.ModeAuto: - case nvcdi.ModeCSV: - case nvcdi.ModeNvml: - case nvcdi.ModeWsl: - case nvcdi.ModeManagement: - default: + if !nvcdi.IsValidMode(opts.mode) { return fmt.Errorf("invalid discovery mode: %v", opts.mode) } diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 1c84fefaa..f1c7b97ac 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -24,24 +24,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" ) -const ( - // ModeAuto configures the CDI spec generator to automatically detect the system configuration - ModeAuto = "auto" - // ModeNvml configures the CDI spec generator to use the NVML library. - ModeNvml = "nvml" - // ModeWsl configures the CDI spec generator to generate a WSL spec. - ModeWsl = "wsl" - // ModeManagement configures the CDI spec generator to generate a management spec. - ModeManagement = "management" - // ModeGds configures the CDI spec generator to generate a GDS spec. - ModeGds = "gds" - // ModeMofed configures the CDI spec generator to generate a MOFED spec. - ModeMofed = "mofed" - // ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV - // mountspec files. - ModeCSV = "csv" -) - // Interface defines the API for the nvcdi package type Interface interface { GetSpec() (spec.Interface, error) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 8ed9e5aa0..ef72efc39 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -46,7 +46,7 @@ type nvcdilib struct { logger logger.Interface nvmllib nvml.Interface nvsandboxutilslib nvsandboxutils.Interface - mode string + mode Mode devicelib device.Interface deviceNamers DeviceNamers driverRoot string @@ -206,28 +206,6 @@ func (m *wrapper) GetCommonEdits() (*cdi.ContainerEdits, error) { return edits, nil } -// resolveMode resolves the mode for CDI spec generation based on the current system. -func (l *nvcdilib) resolveMode() (rmode string) { - if l.mode != ModeAuto { - return l.mode - } - defer func() { - l.logger.Infof("Auto-detected mode as '%v'", rmode) - }() - - platform := l.infolib.ResolvePlatform() - switch platform { - case info.PlatformNVML: - return ModeNvml - case info.PlatformTegra: - return ModeCSV - case info.PlatformWSL: - return ModeWsl - } - l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml) - return ModeNvml -} - // getCudaVersion returns the CUDA version of the current system. func (l *nvcdilib) getCudaVersion() (string, error) { version, err := l.getCudaVersionNvsandboxutils() diff --git a/pkg/nvcdi/mode.go b/pkg/nvcdi/mode.go new file mode 100644 index 000000000..ad08fa576 --- /dev/null +++ b/pkg/nvcdi/mode.go @@ -0,0 +1,117 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "sync" + + "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" +) + +type Mode string + +const ( + // ModeAuto configures the CDI spec generator to automatically detect the system configuration + ModeAuto = Mode("auto") + // ModeNvml configures the CDI spec generator to use the NVML library. + ModeNvml = Mode("nvml") + // ModeWsl configures the CDI spec generator to generate a WSL spec. + ModeWsl = Mode("wsl") + // ModeManagement configures the CDI spec generator to generate a management spec. + ModeManagement = Mode("management") + // ModeGds configures the CDI spec generator to generate a GDS spec. + ModeGds = Mode("gds") + // ModeMofed configures the CDI spec generator to generate a MOFED spec. + ModeMofed = Mode("mofed") + // ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV + // mountspec files. + ModeCSV = Mode("csv") +) + +type modeConstraint interface { + string | Mode +} + +type modes struct { + lookup map[Mode]bool + all []Mode +} + +var validModes modes +var validModesOnce sync.Once + +func getModes() modes { + validModesOnce.Do(func() { + all := []Mode{ + ModeAuto, + ModeNvml, + ModeWsl, + ModeManagement, + ModeGds, + ModeMofed, + ModeCSV, + } + lookup := make(map[Mode]bool) + + for _, m := range all { + lookup[m] = true + } + + validModes = modes{ + lookup: lookup, + all: all, + } + }, + ) + return validModes +} + +// AllModes returns the set of valid modes. +func AllModes[T modeConstraint]() []T { + var output []T + for _, m := range getModes().all { + output = append(output, T(m)) + } + return output +} + +// IsValidMode checks whether a specified mode is valid. +func IsValidMode[T modeConstraint](mode T) bool { + return getModes().lookup[Mode(mode)] +} + +// resolveMode resolves the mode for CDI spec generation based on the current system. +func (l *nvcdilib) resolveMode() (rmode Mode) { + if l.mode != ModeAuto { + return l.mode + } + defer func() { + l.logger.Infof("Auto-detected mode as '%v'", rmode) + }() + + platform := l.infolib.ResolvePlatform() + switch platform { + case info.PlatformNVML: + return ModeNvml + case info.PlatformTegra: + return ModeCSV + case info.PlatformWSL: + return ModeWsl + } + l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml) + return ModeNvml +} diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 417687b96..362545d25 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -99,9 +99,9 @@ func WithNvmlLib(nvmllib nvml.Interface) Option { } // WithMode sets the discovery mode for the library -func WithMode(mode string) Option { +func WithMode[m modeConstraint](mode m) Option { return func(l *nvcdilib) { - l.mode = mode + l.mode = Mode(mode) } }