diff --git a/internal/edits/edits.go b/internal/edits/edits.go index 4538ac317..9857f3d4a 100644 --- a/internal/edits/edits.go +++ b/internal/edits/edits.go @@ -123,3 +123,9 @@ func (e *edits) Modify(spec *ociSpecs.Spec) error { return e.Apply(spec) } + +func (e *edits) AddDeviceCgroupRules(spec *ociSpecs.Spec) error { + return nil +} + +func (e *edits) WithDeviceResolver(resolver oci.DeviceResolver) {} diff --git a/internal/modifier/cdi/registry.go b/internal/modifier/cdi/registry.go index d1faffad5..7f17df3da 100644 --- a/internal/modifier/cdi/registry.go +++ b/internal/modifier/cdi/registry.go @@ -64,3 +64,9 @@ func (m fromRegistry) Modify(spec *specs.Spec) error { return nil } + +func (m fromRegistry) AddDeviceCgroupRules(spec *specs.Spec) error { + return nil +} + +func (m fromRegistry) WithDeviceResolver(resolver oci.DeviceResolver) {} diff --git a/internal/modifier/cdi/spec.go b/internal/modifier/cdi/spec.go index 24b475ee0..8dafb0ba2 100644 --- a/internal/modifier/cdi/spec.go +++ b/internal/modifier/cdi/spec.go @@ -32,7 +32,7 @@ type fromCDISpec struct { var _ oci.SpecModifier = (*fromCDISpec)(nil) -// Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec. +// Modify applies the modifications defined by the raw CDI spec to the incoming OCI spec. 
func (m fromCDISpec) Modify(spec *specs.Spec) error { for _, device := range m.cdiSpec.Devices { device := device @@ -46,3 +46,9 @@ func (m fromCDISpec) Modify(spec *specs.Spec) error { return m.cdiSpec.ApplyEdits(spec) } + +func (m fromCDISpec) AddDeviceCgroupRules(spec *specs.Spec) error { + return nil +} + +func (m fromCDISpec) WithDeviceResolver(resolver oci.DeviceResolver) {} diff --git a/internal/modifier/discover.go b/internal/modifier/discover.go index b249c5596..ad6139ef6 100644 --- a/internal/modifier/discover.go +++ b/internal/modifier/discover.go @@ -52,3 +52,9 @@ func (m discoverModifier) Modify(spec *specs.Spec) error { return specEdits.Modify(spec) } + +func (m discoverModifier) AddDeviceCgroupRules(spec *specs.Spec) error { + return nil +} + +func (m discoverModifier) WithDeviceResolver(resolver oci.DeviceResolver) {} diff --git a/internal/modifier/hook_remover.go b/internal/modifier/hook_remover.go index 24cf76662..8cafde1fb 100644 --- a/internal/modifier/hook_remover.go +++ b/internal/modifier/hook_remover.go @@ -73,6 +73,12 @@ func (m nvidiaContainerRuntimeHookRemover) Modify(spec *specs.Spec) error { return nil } +func (m nvidiaContainerRuntimeHookRemover) AddDeviceCgroupRules(spec *specs.Spec) error { + return nil +} + +func (m nvidiaContainerRuntimeHookRemover) WithDeviceResolver(resolver oci.DeviceResolver) {} + // isNVIDIAContainerRuntimeHook checks if the provided hook is an nvidia-container-runtime-hook // or nvidia-container-toolkit hook. These are included, for example, by the non-experimental // nvidia-container-runtime or docker when specifying the --gpus flag. 
diff --git a/internal/modifier/list.go b/internal/modifier/list.go index d1ce9d642..e1460cd36 100644 --- a/internal/modifier/list.go +++ b/internal/modifier/list.go @@ -51,3 +51,25 @@ func (m List) Modify(spec *specs.Spec) error { } return nil } + +func (m List) AddDeviceCgroupRules(spec *specs.Spec) error { + for _, mm := range m { + if mm == nil { + continue + } + err := mm.AddDeviceCgroupRules(spec) + if err != nil { + return err + } + } + return nil +} + +func (m List) WithDeviceResolver(resolver oci.DeviceResolver) { + for _, mm := range m { + if mm == nil { + continue + } + mm.WithDeviceResolver(resolver) + } +} diff --git a/internal/modifier/stable.go b/internal/modifier/stable.go index 3a842bffb..086ae040e 100644 --- a/internal/modifier/stable.go +++ b/internal/modifier/stable.go @@ -17,7 +17,10 @@ package modifier import ( + "fmt" "path/filepath" + "strconv" + "strings" "github.com/opencontainers/runtime-spec/specs-go" @@ -25,48 +28,201 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) -// NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI +const ( + visibleDevicesEnvvar = "NVIDIA_VISIBLE_DEVICES" + visibleDevicesVoid = "void" + visibleDevicesNone = "none" + visibleDevicesAll = "all" +) + +// NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI // spec. The specified logger is used to capture log output. 
func NewStableRuntimeModifier(logger logger.Interface, nvidiaContainerRuntimeHookPath string) oci.SpecModifier { m := stableRuntimeModifier{ logger: logger, nvidiaContainerRuntimeHookPath: nvidiaContainerRuntimeHookPath, + deviceResolver: oci.NewRealDeviceResolver("/dev"), } - return &m } +func (m *stableRuntimeModifier) WithDeviceResolver(resolver oci.DeviceResolver) { + m.deviceResolver = resolver +} + // stableRuntimeModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a // prestart hook. If the hook is already present, no modification is made. type stableRuntimeModifier struct { logger logger.Interface nvidiaContainerRuntimeHookPath string + deviceResolver oci.DeviceResolver } // Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook // as a prestart hook. func (m stableRuntimeModifier) Modify(spec *specs.Spec) error { // If an NVIDIA Container Runtime Hook already exists, we don't make any modifications to the spec. 
+ hookExists := false if spec.Hooks != nil { for _, hook := range spec.Hooks.Prestart { hook := hook if isNVIDIAContainerRuntimeHook(&hook) { m.logger.Infof("Existing nvidia prestart hook (%v) found in OCI spec", hook.Path) - return nil + hookExists = true + break + } + } + } + + if !hookExists { + path := m.nvidiaContainerRuntimeHookPath + m.logger.Infof("Using prestart hook path: %v", path) + args := []string{filepath.Base(path)} + if spec.Hooks == nil { + spec.Hooks = &specs.Hooks{} + } + spec.Hooks.Prestart = append(spec.Hooks.Prestart, specs.Hook{ + Path: path, + Args: append(args, "prestart"), + }) + } + + if err := m.AddDeviceCgroupRules(spec); err != nil { + return err + } + + return nil +} + +func (m *stableRuntimeModifier) AddDeviceCgroupRules(spec *specs.Spec) error { + + visibleDevices := getEnvVar(spec, visibleDevicesEnvvar) + + if visibleDevices == "" || visibleDevices == visibleDevicesVoid || visibleDevices == visibleDevicesNone { + m.logger.Warning("NVIDIA_VISIBLE_DEVICES is void/none/empty, skipping cgroup rules") + return nil + } + + if spec.Linux == nil { + spec.Linux = &specs.Linux{} + } + + if spec.Linux.Resources == nil { + spec.Linux.Resources = &specs.LinuxResources{} + } + + if err := addCommonDevices(m, spec); err != nil { + return fmt.Errorf("failed to add common devices: %v", err) + } + + if err := addGPUDevices(m, spec, visibleDevices); err != nil { + return fmt.Errorf("failed to add GPU devices: %v", err) + } + + return nil +} + +func getEnvVar(spec *specs.Spec, key string) string { + if spec.Process == nil { + return "" + } + + prefix := key + "=" + for _, env := range spec.Process.Env { + if strings.HasPrefix(env, prefix) { + return strings.TrimPrefix(env, prefix) + } + } + return "" +} + +func addCommonDevices(m *stableRuntimeModifier, spec *specs.Spec) error { + commonDevices := []string{ + "/dev/nvidiactl", + "/dev/nvidia-uvm", + "/dev/nvidia-uvm-tools", + "/dev/nvidia-modeset", + } + + for _, devicePath := range commonDevices { + 
rule, err := m.deviceResolver.DevicePathToRule(devicePath) + if err != nil { + return fmt.Errorf("failed to add common device %s: %v", devicePath, err) + } + spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule) + m.logger.Debugf("Added cgroup rule for %s (major=%d, minor=%d)", + devicePath, *rule.Major, *rule.Minor) + } + + return nil +} + +func addGPUDevices(m *stableRuntimeModifier, spec *specs.Spec, visibleDevices string) error { + deviceList := strings.Split(visibleDevices, ",") + + for _, device := range deviceList { + device = strings.TrimSpace(device) + if device == "" { + continue + } + + if device == visibleDevicesAll { + return addAllGPUDevices(m, spec) + } + + devicePaths, err := resolveDevicePaths(device) + if err != nil { + return fmt.Errorf("failed to resolve device %s: %v", device, err) + } + + for _, devicePath := range devicePaths { + rule, err := m.deviceResolver.DevicePathToRule(devicePath) + if err != nil { + return fmt.Errorf("failed to add device %s: %v", devicePath, err) } + spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule) + m.logger.Debugf("Added cgroup rule for %s (major=%d, minor=%d)", + devicePath, *rule.Major, *rule.Minor) } } - path := m.nvidiaContainerRuntimeHookPath - m.logger.Infof("Using prestart hook path: %v", path) - args := []string{filepath.Base(path)} - if spec.Hooks == nil { - spec.Hooks = &specs.Hooks{} + return nil +} + +func addAllGPUDevices(m *stableRuntimeModifier, spec *specs.Spec) error { + + matches, err := m.deviceResolver.GlobDevices("nvidia[0-9]*") + if err != nil { + return fmt.Errorf("failed to glob nvidia devices: %v", err) + } + + for _, devicePath := range matches { + rule, err := m.deviceResolver.DevicePathToRule(devicePath) + if err != nil { + return fmt.Errorf("failed to add device %s: %v", devicePath, err) + } + spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule) + } + + capsMatches, _ := 
m.deviceResolver.GlobDevices("nvidia-caps/nvidia-cap[0-9]*") + for _, devicePath := range capsMatches { + rule, err := m.deviceResolver.DevicePathToRule(devicePath) + if err != nil { + return fmt.Errorf("failed to add device %s: %v", devicePath, err) + } + spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, *rule) } - spec.Hooks.Prestart = append(spec.Hooks.Prestart, specs.Hook{ - Path: path, - Args: append(args, "prestart"), - }) return nil } + +func resolveDevicePaths(device string) ([]string, error) { + var paths []string + + if idx, err := strconv.Atoi(device); err == nil { + paths = append(paths, fmt.Sprintf("/dev/nvidia%d", idx)) + return paths, nil + } + + return nil, fmt.Errorf("unknown device format: %s", device) +} diff --git a/internal/modifier/stable_test.go b/internal/modifier/stable_test.go index 994f08c12..343421fdc 100644 --- a/internal/modifier/stable_test.go +++ b/internal/modifier/stable_test.go @@ -26,6 +26,7 @@ import ( testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/test" ) @@ -151,7 +152,6 @@ func TestAddHookModifier(t *testing.T) { } for _, tc := range testCases { - tc := tc logHook.Reset() @@ -171,3 +171,46 @@ func TestAddHookModifier(t *testing.T) { } } + +func TestAddAllGPUDevicesWithMock(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + mockResolver := oci.NewMockDeviceResolver() + + m := NewStableRuntimeModifier(logger, "") + m.WithDeviceResolver(mockResolver) + + testCases := []struct { + description string + spec specs.Spec + expectedError error + expectedNoofDevices int + }{ + { + description: "adds all GPU devices", + spec: specs.Spec{ + Process: &specs.Process{ + Env: []string{"NVIDIA_VISIBLE_DEVICES=all"}, + }, + }, + expectedNoofDevices: 7, // nvidia0, nvidia1, nvidia2, nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools + expectedError: nil, + }, + { + 
description: "only one gpu device", + spec: specs.Spec{ + Process: &specs.Process{ + Env: []string{"NVIDIA_VISIBLE_DEVICES=0"}, + }, + }, + expectedNoofDevices: 5, // nvidia0, nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools + expectedError: nil, + }, + } + for _, tc := range testCases { + err := m.AddDeviceCgroupRules(&tc.spec) + require.NoError(t, err) + require.Equal(t, tc.expectedNoofDevices, len(tc.spec.Linux.Resources.Devices)) + } + +} diff --git a/internal/oci/device_resolver.go b/internal/oci/device_resolver.go new file mode 100644 index 000000000..c05f95ecf --- /dev/null +++ b/internal/oci/device_resolver.go @@ -0,0 +1,64 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+**/ + +package oci + +import ( + "fmt" + "path/filepath" + + "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" +) + +type DeviceResolver interface { + GlobDevices(pattern string) ([]string, error) + DevicePathToRule(path string) (*specs.LinuxDeviceCgroup, error) +} + +type RealDeviceResolver struct { + devRoot string +} + +func NewRealDeviceResolver(devRoot string) *RealDeviceResolver { + return &RealDeviceResolver{devRoot: devRoot} +} + +func (r *RealDeviceResolver) GlobDevices(pattern string) ([]string, error) { + return filepath.Glob(filepath.Join(r.devRoot, pattern)) +} + +func (r *RealDeviceResolver) DevicePathToRule(path string) (*specs.LinuxDeviceCgroup, error) { + var stat unix.Stat_t + if err := unix.Stat(path, &stat); err != nil { + return nil, err + } + + if stat.Mode&unix.S_IFMT != unix.S_IFCHR { // mask with S_IFMT; the bare S_IFCHR bit is also set in S_IFBLK + return nil, fmt.Errorf("%s is not a character device", path) + } + + major := int64(unix.Major(stat.Rdev)) + minor := int64(unix.Minor(stat.Rdev)) + + return &specs.LinuxDeviceCgroup{ + Allow: true, + Type: "c", + Major: &major, + Minor: &minor, + Access: "rwm", + }, nil +} diff --git a/internal/oci/device_resolver_mock.go b/internal/oci/device_resolver_mock.go new file mode 100644 index 000000000..a80c73d8e --- /dev/null +++ b/internal/oci/device_resolver_mock.go @@ -0,0 +1,73 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+*/ + +package oci + +import ( + "fmt" + "path/filepath" + + "github.com/opencontainers/runtime-spec/specs-go" +) + +// MockDeviceResolver for testing +type MockDeviceResolver struct { + devices map[string]struct{ major, minor int64 } +} + +func NewMockDeviceResolver() *MockDeviceResolver { + return &MockDeviceResolver{ + devices: map[string]struct{ major, minor int64 }{ + "nvidia0": {195, 0}, + "nvidia1": {195, 1}, + "nvidia2": {195, 2}, + "nvidiactl": {195, 255}, + "nvidia-modeset": {195, 254}, + "nvidia-uvm": {236, 0}, + "nvidia-uvm-tools": {236, 1}, + }, + } +} + +func (r *MockDeviceResolver) GlobDevices(pattern string) ([]string, error) { + var matches []string + for name := range r.devices { + matched, _ := filepath.Match(pattern, name) + if matched { + matches = append(matches, name) + } + } + return matches, nil +} + +func (r *MockDeviceResolver) DevicePathToRule(path string) (*specs.LinuxDeviceCgroup, error) { + base := filepath.Base(path) + dev, ok := r.devices[base] + if !ok { + return nil, fmt.Errorf("device not found: %s", path) + } + + major := dev.major + minor := dev.minor + + return &specs.LinuxDeviceCgroup{ + Allow: true, + Type: "c", + Major: &major, + Minor: &minor, + Access: "rwm", + }, nil +} diff --git a/internal/oci/runtime_modifier_test.go b/internal/oci/runtime_modifier_test.go index 47aebc8d6..23e62595e 100644 --- a/internal/oci/runtime_modifier_test.go +++ b/internal/oci/runtime_modifier_test.go @@ -161,3 +161,9 @@ type modiferMock struct{} func (m modiferMock) Modify(*specs.Spec) error { return nil } + +func (m modiferMock) AddDeviceCgroupRules(spec *specs.Spec) error { + return nil +} + +func (m modiferMock) WithDeviceResolver(resolver DeviceResolver) {} diff --git a/internal/oci/spec.go b/internal/oci/spec.go index 73030b76d..bb1c93379 100644 --- a/internal/oci/spec.go +++ b/internal/oci/spec.go @@ -27,6 +27,8 @@ type SpecModifier interface { // Modify is a method that accepts a pointer to an OCI Spec and returns an // error. 
The intention is that the function would modify the spec in-place. Modify(*specs.Spec) error + AddDeviceCgroupRules(spec *specs.Spec) error + WithDeviceResolver(DeviceResolver) } // SpecModifiers is a collection of OCI Spec modifiers that can be treated as a @@ -81,3 +83,25 @@ func (ms SpecModifiers) Modify(s *specs.Spec) error { } return nil } + +func (ms SpecModifiers) AddDeviceCgroupRules(spec *specs.Spec) error { + for _, m := range ms { + if m == nil { + continue + } + if err := m.AddDeviceCgroupRules(spec); err != nil { + return err + } + } + return nil +} + +// WithDeviceResolver passes resolver to every non-nil modifier in the collection. +func (ms SpecModifiers) WithDeviceResolver(resolver DeviceResolver) { + for _, m := range ms { + if m == nil { + continue + } + m.WithDeviceResolver(resolver) + } +} diff --git a/internal/oci/spec_memory_test.go b/internal/oci/spec_memory_test.go index ebf92e750..2bc0facfe 100644 --- a/internal/oci/spec_memory_test.go +++ b/internal/oci/spec_memory_test.go @@ -177,3 +177,9 @@ func (m modifier) Modify(spec *specs.Spec) error { } return m.modifierError } + +func (m modifier) AddDeviceCgroupRules(spec *specs.Spec) error { + return nil +} + +func (m modifier) WithDeviceResolver(resolver DeviceResolver) {}