Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/config/v1/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ type features struct {
// possibly bypassing other checks by an orchestration system such as
// kubernetes.
IgnoreImexChannelRequests *feature `toml:"ignore-imex-channel-requests,omitempty"`
// NoAdditionalGIDsForDeviceNodes disables the injection of additional GIDs
// for a device node when the node is not readable and writeable by the user.
NoAdditionalGIDsForDeviceNodes *feature `toml:"no-additional-gids-for-device-nodes,omitempty"`
}

type feature bool
Expand Down
40 changes: 38 additions & 2 deletions internal/edits/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package edits

import (
"io/fs"
"os"

"tags.cncf.io/container-device-interface/pkg/cdi"
Expand All @@ -26,7 +27,10 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
)

type device discover.Device
type device struct {
discover.Device
noAdditionalGIDs bool
}

// toEdits converts a discovered device to CDI Container Edits.
func (d device) toEdits() (*cdi.ContainerEdits, error) {
Expand All @@ -37,7 +41,8 @@ func (d device) toEdits() (*cdi.ContainerEdits, error) {

e := cdi.ContainerEdits{
ContainerEdits: &specs.ContainerEdits{
DeviceNodes: []*specs.DeviceNode{deviceNode},
DeviceNodes: []*specs.DeviceNode{deviceNode},
AdditionalGIDs: d.getAdditionalGIDs(deviceNode),
},
}
return &e, nil
Expand Down Expand Up @@ -110,3 +115,34 @@ func ptrIfNonZero[T uint32 | os.FileMode](id T) *T {
}
return &id
}

// getAdditionalGIDs returns the group id of the device if the device is not world read/writable.
// If the information cannot be extracted or an error occurs, 0 is returned.
func (d *device) getAdditionalGIDs(dn *specs.DeviceNode) []uint32 {
if d.noAdditionalGIDs {
return nil
}
// Handle the underdefined cases where we do not have enough information to
// extract the GID for the device OR whether the additional GID is required.
if dn == nil || dn.GID == nil || *dn.GID == 0 {
return nil
}
if dn.FileMode == nil {
return nil
}
if dn.FileMode.Type()&os.ModeCharDevice == 0 {
return nil
}
if permission := dn.FileMode.Perm(); isWorldReadable(permission) && isWorldWriteable(permission) {
return nil
}
return []uint32{*dn.GID}
}

func isWorldReadable(m fs.FileMode) bool {
return m&04 != 0
}

func isWorldWriteable(m fs.FileMode) bool {
return m&02 != 0
}
122 changes: 121 additions & 1 deletion internal/edits/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package edits

import (
"fmt"
"os"
"testing"

"github.com/opencontainers/cgroups/devices/config"
Expand All @@ -26,6 +27,7 @@ import (

"github.com/NVIDIA/nvidia-container-toolkit/internal/devices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test/to"
)

func TestDeviceToSpec(t *testing.T) {
Expand Down Expand Up @@ -94,14 +96,132 @@ func TestDeviceToSpec(t *testing.T) {
GID: ptrIfNonZero[uint32](44),
},
},
{
description: "device with additional GIDs",
device: discover.Device{
Path: "/foo",
},
deviceslib: &devices.InterfaceMock{
DeviceFromPathFunc: func(path, permissions string) (*devices.Device, error) {
if path != "/foo" {
return nil, fmt.Errorf("not found %v", path)
}
cd := &config.Device{
Rule: config.Rule{
Major: 100,
Minor: 200,
Permissions: config.Permissions("w"),
},
FileMode: 0660 | os.ModeCharDevice,
Uid: 11,
Gid: 44,
}

return (*devices.Device)(cd), nil
},
},
expected: &specs.DeviceNode{
Path: "/foo",
HostPath: "",
Permissions: "w",
Major: 100,
Minor: 200,
FileMode: to.Ptr(0660 | os.ModeCharDevice),
GID: ptrIfNonZero[uint32](44),
},
},
}

for _, tc := range testCases {
f := factory{}
t.Run(tc.description, func(t *testing.T) {
defer devices.SetInterfaceForTests(tc.deviceslib)()
spec, err := device(tc.device).toSpec()
spec, err := f.device(tc.device).toSpec()
require.NoError(t, err)
require.EqualValues(t, tc.expected, spec)
})
}
}

func TestGetAdditionalGIDs(t *testing.T) {
testCases := []struct {
description string
device *device
deviceNode *specs.DeviceNode
expectedAdditionalGIDs []uint32
}{
{
description: "feature disabled",
device: &device{noAdditionalGIDs: true},
},
{
description: "device node has no GID",
device: &device{},
},
{
description: "device node has zero GID",
device: &device{},
deviceNode: &specs.DeviceNode{
GID: to.Ptr[uint32](0),
},
},
{
description: "filemode not specified",
device: &device{},
deviceNode: &specs.DeviceNode{
GID: to.Ptr[uint32](1),
},
},
{
description: "device node is not a character device",
device: &device{},
deviceNode: &specs.DeviceNode{
GID: to.Ptr[uint32](1),
FileMode: to.Ptr(0666 | os.ModeSymlink),
},
},
{
description: "character device is world read-writeable",
device: &device{},
deviceNode: &specs.DeviceNode{
GID: to.Ptr[uint32](1),
FileMode: to.Ptr(0666 | os.ModeCharDevice),
},
},
{
description: "character device is only world readable",
device: &device{},
deviceNode: &specs.DeviceNode{
GID: to.Ptr[uint32](1),
FileMode: to.Ptr(0664 | os.ModeCharDevice),
},
expectedAdditionalGIDs: []uint32{1},
},
{
description: "character device is only world writeable",
device: &device{},
deviceNode: &specs.DeviceNode{
GID: to.Ptr[uint32](1),
FileMode: to.Ptr(0662 | os.ModeCharDevice),
},
expectedAdditionalGIDs: []uint32{1},
},
{
description: "character device is not world read-writeable",
device: &device{},
deviceNode: &specs.DeviceNode{
GID: to.Ptr[uint32](1),
FileMode: to.Ptr(0660 | os.ModeCharDevice),
},
expectedAdditionalGIDs: []uint32{1},
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
additionalGIDs := tc.device.getAdditionalGIDs(tc.deviceNode)

require.EqualValues(t, tc.expectedAdditionalGIDs, additionalGIDs)
})
}
}
92 changes: 54 additions & 38 deletions internal/edits/edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,51 @@ package edits
import (
"fmt"

ociSpecs "github.com/opencontainers/runtime-spec/specs-go"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

type edits struct {
cdi.ContainerEdits
logger logger.Interface
const (
// An EmptyFactory is an edits factory that always returns empty CDI
// container edits.
EmptyFactory = empty("empty")
)

type Factory interface {
New() *cdi.ContainerEdits
FromDiscoverer(discover.Discover) (*cdi.ContainerEdits, error)
}

// NewSpecEdits creates a SpecModifier that defines the required OCI spec edits (as CDI ContainerEdits) from the specified
// discoverer.
func NewSpecEdits(logger logger.Interface, d discover.Discover) (oci.SpecModifier, error) {
c, err := FromDiscoverer(d)
if err != nil {
return nil, fmt.Errorf("error constructing container edits: %v", err)
type empty string

type factory struct {
Comment thread
elezar marked this conversation as resolved.
logger logger.Interface
noAdditionalGIDsForDeviceNodes bool
}

var _ Factory = (*empty)(nil)
var _ Factory = (*factory)(nil)

type Option func(*factory)

func NewFactory(opts ...Option) Factory {
f := &factory{
logger: &logger.NullLogger{},
}
e := edits{
ContainerEdits: *c,
logger: logger,
for _, opt := range opts {
opt(f)
}
return f
}

return &e, nil
func (f *factory) New() *cdi.ContainerEdits {
return EmptyFactory.New()
}

// FromDiscoverer creates CDI container edits for the specified discoverer.
func FromDiscoverer(d discover.Discover) (*cdi.ContainerEdits, error) {
func (f *factory) FromDiscoverer(d discover.Discover) (*cdi.ContainerEdits, error) {
devices, err := d.Devices()
if err != nil {
return nil, fmt.Errorf("failed to discover devices: %v", err)
Expand All @@ -70,9 +84,9 @@ func FromDiscoverer(d discover.Discover) (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("failed to discover hooks: %v", err)
}

c := NewContainerEdits()
c := EmptyFactory.New()
for _, d := range devices {
edits, err := device(d).toEdits()
edits, err := f.device(d).toEdits()
if err != nil {
return nil, fmt.Errorf("failed to created container edits for device: %v", err)
}
Expand All @@ -94,32 +108,34 @@ func FromDiscoverer(d discover.Discover) (*cdi.ContainerEdits, error) {
return c, nil
}

// NewContainerEdits is a utility function to create a CDI ContainerEdits struct.
func NewContainerEdits() *cdi.ContainerEdits {
func (f *factory) device(d discover.Device) *device {
return &device{
Device: d,
noAdditionalGIDs: f.noAdditionalGIDsForDeviceNodes,
}
}

// New creates a set of empty CDI container edits for an empty factory.
func (e empty) New() *cdi.ContainerEdits {
c := cdi.ContainerEdits{
ContainerEdits: &specs.ContainerEdits{},
}
return &c
}

// Modify applies the defined edits to the incoming OCI spec
func (e *edits) Modify(spec *ociSpecs.Spec) error {
if e == nil || e.ContainerEdits.ContainerEdits == nil {
return nil
}
// FromDiscoverer creates a set of empty CDI container edits for ANY discoverer.
func (e empty) FromDiscoverer(_ discover.Discover) (*cdi.ContainerEdits, error) {
return e.New(), nil
}

e.logger.Infof("Mounts:")
for _, mount := range e.Mounts {
e.logger.Infof("Mounting %v at %v", mount.HostPath, mount.ContainerPath)
}
e.logger.Infof("Devices:")
for _, device := range e.DeviceNodes {
e.logger.Infof("Injecting %v", device.Path)
}
e.logger.Infof("Hooks:")
for _, hook := range e.Hooks {
e.logger.Infof("Injecting %v %v", hook.Path, hook.Args)
func WithLogger(logger logger.Interface) Option {
return func(f *factory) {
f.logger = logger
}
}

return e.Apply(spec)
func WithNoAdditionalGIDsForDeviceNodes(noAdditionalGIDsForDeviceNodes bool) Option {
return func(f *factory) {
f.noAdditionalGIDsForDeviceNodes = noAdditionalGIDsForDeviceNodes
}
}
4 changes: 3 additions & 1 deletion internal/edits/edits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
)

// TODO: This test doesn't actually do anything.
func TestFromDiscovererAllowsMountsToIterate(t *testing.T) {
edits, err := FromDiscoverer(discover.None{})
t.Skip("This test does not test anything significant")
edits, err := NewFactory().FromDiscoverer(discover.None{})
require.NoError(t, err)

require.Empty(t, edits.Mounts)
Expand Down
Loading