Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internal/edits/edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,9 @@ func (e *edits) Modify(spec *ociSpecs.Spec) error {

return e.Apply(spec)
}

func (e *edits) AddDeviceCgroupRules(spec *ociSpecs.Spec) error {
return nil
}

func (e *edits) WithDeviceResolver(resolver oci.DeviceResolver) {}
6 changes: 6 additions & 0 deletions internal/modifier/cdi/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,9 @@ func (m fromRegistry) Modify(spec *specs.Spec) error {

return nil
}

func (m fromRegistry) AddDeviceCgroupRules(spec *specs.Spec) error {
return nil
}

func (m fromRegistry) WithDeviceResolver(resolver oci.DeviceResolver) {}
8 changes: 7 additions & 1 deletion internal/modifier/cdi/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type fromCDISpec struct {

var _ oci.SpecModifier = (*fromCDISpec)(nil)

// Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec.
// Modify applies the modifications defined by the raw CDI spec to the incoming OCI spec.
func (m fromCDISpec) Modify(spec *specs.Spec) error {
for _, device := range m.cdiSpec.Devices {
device := device
Expand All @@ -46,3 +46,9 @@ func (m fromCDISpec) Modify(spec *specs.Spec) error {

return m.cdiSpec.ApplyEdits(spec)
}

func (m fromCDISpec) AddDeviceCgroupRules(spec *specs.Spec) error {
return nil
}

func (m fromCDISpec) WithDeviceResolver(resolver oci.DeviceResolver) {}
6 changes: 6 additions & 0 deletions internal/modifier/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@ func (m discoverModifier) Modify(spec *specs.Spec) error {

return specEdits.Modify(spec)
}

func (m discoverModifier) AddDeviceCgroupRules(spec *specs.Spec) error {
return nil
}

func (m discoverModifier) WithDeviceResolver(resolver oci.DeviceResolver) {}
6 changes: 6 additions & 0 deletions internal/modifier/hook_remover.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ func (m nvidiaContainerRuntimeHookRemover) Modify(spec *specs.Spec) error {
return nil
}

func (m nvidiaContainerRuntimeHookRemover) AddDeviceCgroupRules(spec *specs.Spec) error {
return nil
}

func (m nvidiaContainerRuntimeHookRemover) WithDeviceResolver(resolver oci.DeviceResolver) {}

// isNVIDIAContainerRuntimeHook checks if the provided hook is an nvidia-container-runtime-hook
// or nvidia-container-toolkit hook. These are included, for example, by the non-experimental
// nvidia-container-runtime or docker when specifying the --gpus flag.
Expand Down
22 changes: 22 additions & 0 deletions internal/modifier/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,25 @@ func (m List) Modify(spec *specs.Spec) error {
}
return nil
}

func (m List) AddDeviceCgroupRules(spec *specs.Spec) error {
for _, mm := range m {
if mm == nil {
continue
}
err := mm.AddDeviceCgroupRules(spec)
if err != nil {
return err
}
}
return nil
}

func (m List) WithDeviceResolver(resolver oci.DeviceResolver) {
for _, mm := range m {
if mm == nil {
continue
}
mm.WithDeviceResolver(resolver)
}
}
180 changes: 168 additions & 12 deletions internal/modifier/stable.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,56 +17,212 @@
package modifier

import (
"fmt"
"path/filepath"
"strconv"
"strings"

"github.com/opencontainers/runtime-spec/specs-go"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

// NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI
const (
visibleDevicesEnvvar = "NVIDIA_VISIBLE_DEVICES"
visibleDevicesVoid = "void"
visibleDevicesNone = "none"
visibleDevicesAll = "all"
)

// NewstableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI
// spec. The specified logger is used to capture log output.
func NewStableRuntimeModifier(logger logger.Interface, nvidiaContainerRuntimeHookPath string) oci.SpecModifier {
m := stableRuntimeModifier{
logger: logger,
nvidiaContainerRuntimeHookPath: nvidiaContainerRuntimeHookPath,
deviceResolver: oci.NewRealDeviceResolver("/dev"),
}

return &m
}

func (m *stableRuntimeModifier) WithDeviceResolver(resolver oci.DeviceResolver) {
m.deviceResolver = resolver
}

// stableRuntimeModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
// prestart hook. If the hook is already present, no modification is made.
type stableRuntimeModifier struct {
logger logger.Interface
nvidiaContainerRuntimeHookPath string
deviceResolver oci.DeviceResolver
}

// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
// as a prestart hook.
func (m stableRuntimeModifier) Modify(spec *specs.Spec) error {
// If an NVIDIA Container Runtime Hook already exists, we don't make any modifications to the spec.
hookExists := false
if spec.Hooks != nil {
for _, hook := range spec.Hooks.Prestart {
hook := hook
if isNVIDIAContainerRuntimeHook(&hook) {
m.logger.Infof("Existing nvidia prestart hook (%v) found in OCI spec", hook.Path)
return nil
hookExists = true
break
}
}
}

if !hookExists {
path := m.nvidiaContainerRuntimeHookPath
m.logger.Infof("Using prestart hook path: %v", path)
args := []string{filepath.Base(path)}
if spec.Hooks == nil {
spec.Hooks = &specs.Hooks{}
}
spec.Hooks.Prestart = append(spec.Hooks.Prestart, specs.Hook{
Path: path,
Args: append(args, "prestart"),
})
}

if err := m.AddDeviceCgroupRules(spec); err != nil {
return err
}

return nil
}

func (m *stableRuntimeModifier) AddDeviceCgroupRules(spec *specs.Spec) error {

visibleDevices := getEnvVar(spec, visibleDevicesEnvvar)

if visibleDevices == "" || visibleDevices == visibleDevicesVoid || visibleDevices == visibleDevicesNone {
m.logger.Warning("NVIDIA_VISIBLE_DEVICES is void/none/empty, skipping cgroup rules")
return nil
}

if spec.Linux == nil {
spec.Linux = &specs.Linux{}
}

if spec.Linux.Resources == nil {
spec.Linux.Resources = &specs.LinuxResources{}
}

if err := addCommonDevices(m, spec); err != nil {
return fmt.Errorf("failed to add common devices: %v", err)
}

if err := addGPUDevices(m, spec, visibleDevices); err != nil {
return fmt.Errorf("failed to add GPU devices: %v", err)
}

return nil
}

func getEnvVar(spec *specs.Spec, key string) string {
if spec.Process == nil {
return ""
}

prefix := key + "="
for _, env := range spec.Process.Env {
if strings.HasPrefix(env, prefix) {
return strings.TrimPrefix(env, prefix)
}
}
return ""
}

func addCommonDevices(m *stableRuntimeModifier, spec *specs.Spec) error {
commonDevices := []string{
"/dev/nvidiactl",
"/dev/nvidia-uvm",
"/dev/nvidia-uvm-tools",
"/dev/nvidia-modeset",
}

for _, devicePath := range commonDevices {
rule, err := m.deviceResolver.DevicePathToRule(devicePath)
if err != nil {
return fmt.Errorf("failed to add common device %s: %v", devicePath, err)
}
spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule)
m.logger.Debugf("Added cgroup rule for %s (major=%d, minor=%d)",
devicePath, *rule.Major, *rule.Minor)
}

return nil
}

func addGPUDevices(m *stableRuntimeModifier, spec *specs.Spec, visibleDevices string) error {
deviceList := strings.Split(visibleDevices, ",")

for _, device := range deviceList {
device = strings.TrimSpace(device)
if device == "" {
continue
}

if device == visibleDevicesAll {
return addAllGPUDevices(m, spec)
}

devicePaths, err := resolveDevicePaths(device)
if err != nil {
return fmt.Errorf("failed to resolve device %s: %v", device, err)
}

for _, devicePath := range devicePaths {
rule, err := m.deviceResolver.DevicePathToRule(devicePath)
if err != nil {
return fmt.Errorf("failed to add device %s: %v", devicePath, err)
}
spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule)
m.logger.Debugf("Added cgroup rule for %s (major=%d, minor=%d)",
devicePath, *rule.Major, *rule.Minor)
}
}

path := m.nvidiaContainerRuntimeHookPath
m.logger.Infof("Using prestart hook path: %v", path)
args := []string{filepath.Base(path)}
if spec.Hooks == nil {
spec.Hooks = &specs.Hooks{}
return nil
}

func addAllGPUDevices(m *stableRuntimeModifier, spec *specs.Spec) error {

matches, err := m.deviceResolver.GlobDevices("nvidia[0-9]*")
if err != nil {
return fmt.Errorf("failed to glob nvidia devices: %v", err)
}

for _, devicePath := range matches {
rule, err := m.deviceResolver.DevicePathToRule(devicePath)
if err != nil {
return fmt.Errorf("failed to add device %s: %v", devicePath, err)
}
spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule)
}

capsMatches, _ := m.deviceResolver.GlobDevices("nvidia-caps/nvidia-cap[0-9]*")
for _, devicePath := range capsMatches {
rule, err := m.deviceResolver.DevicePathToRule(devicePath)
if err != nil {
return fmt.Errorf("failed to add device %s: %v", devicePath, err)
}
spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule)
}
spec.Hooks.Prestart = append(spec.Hooks.Prestart, specs.Hook{
Path: path,
Args: append(args, "prestart"),
})

return nil
}

func resolveDevicePaths(device string) ([]string, error) {
var paths []string

if idx, err := strconv.Atoi(device); err == nil {
paths = append(paths, fmt.Sprintf("/dev/nvidia%d", idx))
return paths, nil
}

return nil, fmt.Errorf("unknown device format: %s", device)
}
45 changes: 44 additions & 1 deletion internal/modifier/stable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
)

Expand Down Expand Up @@ -151,7 +152,6 @@ func TestAddHookModifier(t *testing.T) {
}

for _, tc := range testCases {
tc := tc

logHook.Reset()

Expand All @@ -171,3 +171,46 @@ func TestAddHookModifier(t *testing.T) {
}

}

func TestAddAllGPUDevicesWithMock(t *testing.T) {
logger, _ := testlog.NewNullLogger()

mockResolver := oci.NewMockDeviceResolver()

m := NewStableRuntimeModifier(logger, "")
m.WithDeviceResolver(mockResolver)

testCases := []struct {
description string
spec specs.Spec
expectedError error
expectedNoofDevices int
}{
{
description: "adds all GPU devices",
spec: specs.Spec{
Process: &specs.Process{
Env: []string{"NVIDIA_VISIBLE_DEVICES=all"},
},
},
expectedNoofDevices: 7, // nvidia0, nvidia1, nvidia2, nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools
expectedError: nil,
},
{
description: "only one gpu device",
spec: specs.Spec{
Process: &specs.Process{
Env: []string{"NVIDIA_VISIBLE_DEVICES=0"},
},
},
expectedNoofDevices: 5, // nvidia0, nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools
expectedError: nil,
},
}
for _, tc := range testCases {
err := m.AddDeviceCgroupRules(&tc.spec)
require.NoError(t, err)
require.Equal(t, tc.expectedNoofDevices, len(tc.spec.Linux.Resources.Devices))
}

}
Loading