diff --git a/tools/container/container.go b/tools/container/container.go index c2c50c5b..857dc37c 100644 --- a/tools/container/container.go +++ b/tools/container/container.go @@ -36,13 +36,14 @@ const ( // Options defines the shared options for the CLIs to configure containers runtimes. type Options struct { - Config string - Socket string - RuntimeName string - RuntimeDir string - SetAsDefault bool - RestartMode string - HostRootMount string + Config string + Socket string + RuntimeName string + RuntimeDir string + SetAsDefault bool + RestartMode string + HostRootMount string + RuntimeEnableCDI bool } // ParseArgs parses the command line arguments to the CLI @@ -111,6 +112,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error { } } + if o.RuntimeEnableCDI { + cfg.EnableCDI() + } + return nil } diff --git a/tools/container/runtime/containerd/config_v1_test.go b/tools/container/runtime/containerd/config_v1_test.go index 26b673e4..5ec9d439 100644 --- a/tools/container/runtime/containerd/config_v1_test.go +++ b/tools/container/runtime/containerd/config_v1_test.go @@ -415,6 +415,51 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV1EnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + runtimeEnableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + runtimeEnableCDI: false, + expectedEnableCDIValue: nil, + }, + { + runtimeEnableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + o := &container.Options{ + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + RuntimeEnableCDI: tc.runtimeEnableCDI, + } + + cfg, err := toml.Empty.Load() + require.NoError(t, err, "%d: %v", i, tc) + + v1 := &containerd.ConfigV1{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + } + + err = o.UpdateConfig(v1) + require.NoError(t, err, "%d: %v", i, tc) + + enableCDIValue := v1.GetPath([]string{"plugins", "cri", "containerd", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue, "%d: %v", i, tc) + }) + } +} + func TestRevertV1Config(t *testing.T) { testCases := []struct { config map[string]interface { diff --git a/tools/container/runtime/containerd/config_v2_test.go b/tools/container/runtime/containerd/config_v2_test.go index 8c4620eb..cc203cf0 100644 --- a/tools/container/runtime/containerd/config_v2_test.go +++ b/tools/container/runtime/containerd/config_v2_test.go @@ -369,6 +369,52 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV2ConfigEnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + runtimeEnableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + runtimeEnableCDI: false, + expectedEnableCDIValue: nil, + }, + { + runtimeEnableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + o := &container.Options{ + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + SetAsDefault: false, + RuntimeEnableCDI: tc.runtimeEnableCDI, + } + + cfg, err := toml.LoadMap(map[string]interface{}{}) + require.NoError(t, err) + + v2 := &containerd.Config{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + } + + err = o.UpdateConfig(v2) + require.NoError(t, err) + + enableCDIValue := cfg.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue) + }) + } +} + func TestRevertV2Config(t *testing.T) { testCases := []struct { config map[string]interface { diff --git a/tools/container/runtime/docker/docker.go b/tools/container/runtime/docker/docker.go index fd5a2750..809d6807 100644 --- a/tools/container/runtime/docker/docker.go +++ b/tools/container/runtime/docker/docker.go @@ -55,6 +55,10 @@ func Setup(c *cli.Context, o *container.Options) error { return fmt.Errorf("unable to configure docker: %v", err) } + if o.RuntimeEnableCDI { + cfg.Set("features", map[string]bool{"cdi": true}) + } + err = RestartDocker(o) if err != nil { return fmt.Errorf("unable to restart docker: %v", err) diff --git a/tools/container/runtime/runtime.go b/tools/container/runtime/runtime.go index 865f92e8..7ef428ba 100644 --- a/tools/container/runtime/runtime.go +++ b/tools/container/runtime/runtime.go @@ -30,8 +30,9 @@ import ( const ( defaultSetAsDefault = true // defaultRuntimeName specifies the NVIDIA runtime to be use as the default runtime if setting the default runtime is enabled - defaultRuntimeName = "nvidia" - defaultHostRootMount = "/host" + defaultRuntimeName = "nvidia" + defaultHostRootMount = "/host" + defaultRuntimeEnableCDI = false runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT" ) @@ -89,6 +90,13 @@ func Flags(opts *Options) []cli.Flag { EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "CONTAINERD_SET_AS_DEFAULT", "DOCKER_SET_AS_DEFAULT"}, Hidden: true, }, + &cli.BoolFlag{ + Name: "runtime-enable-cdi", + Usage: "Enable CDI in the configured runtime", + Value: defaultRuntimeEnableCDI, + Destination: &opts.RuntimeEnableCDI, + EnvVars: []string{"RUNTIME_ENABLE_CDI"}, + }, } flags = append(flags, containerd.Flags(&opts.containerdOptions)...) @@ -124,6 +132,9 @@ func ValidateOptions(opts *Options, runtime string, toolkitRoot string) error { if opts.RestartMode == runtimeSpecificDefault { opts.RestartMode = crio.DefaultRestartMode } + if opts.RuntimeEnableCDI { + opts.RuntimeEnableCDI = false + } case docker.Name: if opts.Config == runtimeSpecificDefault { opts.Config = docker.DefaultConfig @@ -142,6 +153,7 @@ func ValidateOptions(opts *Options, runtime string, toolkitRoot string) error { } func Setup(c *cli.Context, opts *Options, runtime string) error { + opts. switch runtime { case containerd.Name: return containerd.Setup(c, &opts.Options, &opts.containerdOptions)