diff --git a/cmd/nvidia-ctk/runtime/configure/configure.go b/cmd/nvidia-ctk/runtime/configure/configure.go index d2528853..5a7c16a1 100644 --- a/cmd/nvidia-ctk/runtime/configure/configure.go +++ b/cmd/nvidia-ctk/runtime/configure/configure.go @@ -292,9 +292,8 @@ func (m command) configureConfigFile(c *cli.Context, config *config) error { return fmt.Errorf("unable to update config: %v", err) } - err = enableCDI(config, cfg) - if err != nil { - return fmt.Errorf("failed to enable CDI in %s: %w", config.runtime, err) + if config.cdi.enabled { + cfg.EnableCDI() } outputPath := config.getOutputConfigPath() @@ -354,19 +353,3 @@ func (m *command) configureOCIHook(c *cli.Context, config *config) error { } return nil } - -// enableCDI enables the use of CDI in the corresponding container engine -func enableCDI(config *config, cfg engine.Interface) error { - if !config.cdi.enabled { - return nil - } - switch config.runtime { - case "containerd": - cfg.Set("enable_cdi", true) - case "docker": - cfg.Set("features", map[string]bool{"cdi": true}) - default: - return fmt.Errorf("enabling CDI in %s is not supported", config.runtime) - } - return nil -} diff --git a/pkg/config/engine/api.go b/pkg/config/engine/api.go index bc6c4a68..addaad0b 100644 --- a/pkg/config/engine/api.go +++ b/pkg/config/engine/api.go @@ -24,6 +24,7 @@ type Interface interface { RemoveRuntime(string) error Save(string) (int64, error) GetRuntimeConfig(string) (RuntimeConfig, error) + EnableCDI() } // RuntimeConfig defines the interface to query container runtime handler configuration diff --git a/pkg/config/engine/containerd/config_v1.go b/pkg/config/engine/containerd/config_v1.go index db6cf2dc..a6626839 100644 --- a/pkg/config/engine/containerd/config_v1.go +++ b/pkg/config/engine/containerd/config_v1.go @@ -163,3 +163,7 @@ func (c *ConfigV1) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) { tree: runtimeData, }, nil } + +func (c *ConfigV1) EnableCDI() { + c.Set("enable_cdi", true) +} diff 
--git a/pkg/config/engine/containerd/containerd.go b/pkg/config/engine/containerd/containerd.go index a5b08810..fa1f708f 100644 --- a/pkg/config/engine/containerd/containerd.go +++ b/pkg/config/engine/containerd/containerd.go @@ -126,6 +126,10 @@ func (c *Config) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) { }, nil } +func (c *Config) EnableCDI() { + c.Set("enable_cdi", true) +} + // CommandLineSource returns the CLI-based containerd config loader func CommandLineSource(hostRoot string) toml.Loader { return toml.FromCommandLine(chrootIfRequired(hostRoot, "containerd", "config", "dump")...) diff --git a/pkg/config/engine/crio/crio.go b/pkg/config/engine/crio/crio.go index 3d5629d7..bba7c40d 100644 --- a/pkg/config/engine/crio/crio.go +++ b/pkg/config/engine/crio/crio.go @@ -153,6 +153,9 @@ func (c *Config) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) { }, nil } +// EnableCDI is a no-op for CRI-O since CDI is always enabled in versions where CDI is supported +func (c *Config) EnableCDI() {} + // CommandLineSource returns the CLI-based crio config loader func CommandLineSource(hostRoot string) toml.Loader { return toml.LoadFirst( diff --git a/pkg/config/engine/docker/docker.go b/pkg/config/engine/docker/docker.go index 6ea64f06..4b092fda 100644 --- a/pkg/config/engine/docker/docker.go +++ b/pkg/config/engine/docker/docker.go @@ -166,3 +166,7 @@ func (c *Config) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) { } return &dockerRuntime{}, nil } + +func (c *Config) EnableCDI() { + c.Set("features", map[string]bool{"cdi": true}) +} diff --git a/tools/container/container.go b/tools/container/container.go index c2c50c5b..857dc37c 100644 --- a/tools/container/container.go +++ b/tools/container/container.go @@ -36,13 +36,14 @@ const ( // Options defines the shared options for the CLIs to configure containers runtimes. 
type Options struct { - Config string - Socket string - RuntimeName string - RuntimeDir string - SetAsDefault bool - RestartMode string - HostRootMount string + Config string + Socket string + RuntimeName string + RuntimeDir string + SetAsDefault bool + RestartMode string + HostRootMount string + RuntimeEnableCDI bool } // ParseArgs parses the command line arguments to the CLI @@ -111,6 +112,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error { } } + if o.RuntimeEnableCDI { + cfg.EnableCDI() + } + return nil } diff --git a/tools/container/runtime/containerd/config_v1_test.go b/tools/container/runtime/containerd/config_v1_test.go index 26b673e4..5ec9d439 100644 --- a/tools/container/runtime/containerd/config_v1_test.go +++ b/tools/container/runtime/containerd/config_v1_test.go @@ -415,6 +415,51 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV1EnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + runtimeEnableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + runtimeEnableCDI: false, + expectedEnableCDIValue: nil, + }, + { + runtimeEnableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + o := &container.Options{ + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + RuntimeEnableCDI: tc.runtimeEnableCDI, + } + + cfg, err := toml.Empty.Load() + require.NoError(t, err, "%d: %v", i, tc) + + v1 := &containerd.ConfigV1{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + } + + err = o.UpdateConfig(v1) + require.NoError(t, err, "%d: %v", i, tc) + + enableCDIValue := v1.GetPath([]string{"plugins", "cri", "containerd", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue, "%d: %v", i, tc) + }) + } +} + func TestRevertV1Config(t *testing.T) { testCases := []struct { config map[string]interface { diff --git 
a/tools/container/runtime/containerd/config_v2_test.go b/tools/container/runtime/containerd/config_v2_test.go index 8c4620eb..cc203cf0 100644 --- a/tools/container/runtime/containerd/config_v2_test.go +++ b/tools/container/runtime/containerd/config_v2_test.go @@ -369,6 +369,52 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV2ConfigEnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + runtimeEnableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + runtimeEnableCDI: false, + expectedEnableCDIValue: nil, + }, + { + runtimeEnableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + o := &container.Options{ + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + SetAsDefault: false, + RuntimeEnableCDI: tc.runtimeEnableCDI, + } + + cfg, err := toml.LoadMap(map[string]interface{}{}) + require.NoError(t, err) + + v2 := &containerd.Config{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + } + + err = o.UpdateConfig(v2) + require.NoError(t, err) + + enableCDIValue := cfg.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue) + }) + } +} + func TestRevertV2Config(t *testing.T) { testCases := []struct { config map[string]interface { diff --git a/tools/container/runtime/docker/docker.go b/tools/container/runtime/docker/docker.go index fd5a2750..809d6807 100644 --- a/tools/container/runtime/docker/docker.go +++ b/tools/container/runtime/docker/docker.go @@ -55,6 +55,10 @@ func Setup(c *cli.Context, o *container.Options) error { return fmt.Errorf("unable to configure docker: %v", err) } + if o.RuntimeEnableCDI { + cfg.Set("features", map[string]bool{"cdi": true}) + } + err = RestartDocker(o) if err != nil { return fmt.Errorf("unable to restart docker: %v", err) diff --git 
a/tools/container/runtime/runtime.go b/tools/container/runtime/runtime.go index 865f92e8..13232064 100644 --- a/tools/container/runtime/runtime.go +++ b/tools/container/runtime/runtime.go @@ -30,8 +30,9 @@ import ( const ( defaultSetAsDefault = true // defaultRuntimeName specifies the NVIDIA runtime to be use as the default runtime if setting the default runtime is enabled - defaultRuntimeName = "nvidia" - defaultHostRootMount = "/host" + defaultRuntimeName = "nvidia" + defaultHostRootMount = "/host" + defaultRuntimeEnableCDI = false runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT" ) @@ -89,6 +90,13 @@ func Flags(opts *Options) []cli.Flag { EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "CONTAINERD_SET_AS_DEFAULT", "DOCKER_SET_AS_DEFAULT"}, Hidden: true, }, + &cli.BoolFlag{ + Name: "runtime-enable-cdi", + Usage: "Enable CDI in the configured runtime", + Value: defaultRuntimeEnableCDI, + Destination: &opts.RuntimeEnableCDI, + EnvVars: []string{"RUNTIME_ENABLE_CDI"}, + }, } flags = append(flags, containerd.Flags(&opts.containerdOptions)...) @@ -124,6 +132,9 @@ func ValidateOptions(opts *Options, runtime string, toolkitRoot string) error { if opts.RestartMode == runtimeSpecificDefault { opts.RestartMode = crio.DefaultRestartMode } + if opts.RuntimeEnableCDI { + opts.RuntimeEnableCDI = false + } case docker.Name: if opts.Config == runtimeSpecificDefault { opts.Config = docker.DefaultConfig