cmd/nvidia_gpu/nvidia_gpu.go (38 additions, 0 deletions)
@@ -15,10 +15,12 @@
package main

import (
"context"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"os"
"time"

gpumanager "github.com/GoogleCloudPlatform/container-engine-accelerators/pkg/gpu/nvidia"
@@ -27,6 +29,8 @@ import (
util "github.com/GoogleCloudPlatform/container-engine-accelerators/pkg/gpu/nvidia/util"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/golang/glog"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/watch"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

@@ -35,6 +39,9 @@ const (
kubeletEndpoint = "kubelet.sock"
pluginEndpointPrefix = "nvidiaGPU"
devDirectory = "/dev"
nodeNameEnv = "NODE_NAME"
lockFilePath = "/device-plugin/tpu-device-plugin.lock"

// Proc directory is used to look up the access files for each GPU partition.
procDirectory = "/proc"
)
@@ -47,6 +54,7 @@ var (
pluginMountPath = flag.String("plugin-directory", "/device-plugin", "The directory path to create plugin socket")
enableContainerGPUMetrics = flag.Bool("enable-container-gpu-metrics", false, "If true, the device plugin will expose GPU metrics for containers with allocated GPU")
enableHealthMonitoring = flag.Bool("enable-health-monitoring", false, "If true, the device plugin will detect critical Xid errors and mark the GPUs unallocatable")
enableFlockWait = flag.Bool("enable-flock-wait", false, "If true, the device plugin will wait until the old device plugin releases the lock")
gpuMetricsPort = flag.Int("gpu-metrics-port", 2112, "Port on which GPU metrics for containers are exposed")
gpuMetricsCollectionIntervalMs = flag.Int("gpu-metrics-collection-interval", 30000, "Collection interval (in milliseconds) for container GPU metrics")
gpuConfigFile = flag.String("gpu-config", "/etc/nvidia/gpu_config.json", "File with GPU configurations for device plugin")
@@ -73,6 +81,8 @@ func parseGPUConfig(gpuConfigFile string) (gpumanager.GPUConfig, error) {

func main() {
flag.Parse()
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // Ensure the context is canceled when main exits
glog.Infoln("device-plugin started")
mountPaths := []pluginapi.Mount{
{HostPath: *hostPathPrefix, ContainerPath: *containerPathPrefix, ReadOnly: true},
@@ -151,6 +161,34 @@ func main() {
return
}
defer hc.Stop()

}

if *enableFlockWait {
kubeClient, err := util.BuildKubeClient()
if err != nil {
glog.Infof("Failed to build kube client: %v", err)
return
}
nodeName, err := util.GetEnv(nodeNameEnv)
[Contributor comment] This would require a change in the DaemonSet YAML, right? Should we fall back to getHostname if we fail to get this env var?

if err != nil {
glog.Warningf("Failed to get node name from env %q: %v. Falling back to hostname.", nodeNameEnv, err)
var hostnameErr error
if nodeName, hostnameErr = os.Hostname(); hostnameErr != nil {
glog.Errorf("Failed to get hostname: %v", hostnameErr)
return
}
}

watchFunc := func(options metav1.ListOptions) (watch.Interface, error) {
// Preserve the options supplied by the retry watcher (notably
// ResourceVersion) so a restarted watch resumes where it left off.
options.FieldSelector = "metadata.name=" + nodeName
return kubeClient.CoreV1().Nodes().Watch(ctx, options)
}
if err := util.SafelyUsingFlockWait(ctx, lockFilePath, watchFunc, util.CheckLockFileExists, util.UseRetryWatch); err != nil {
glog.Errorf("Failed to wait for and acquire the device plugin lock: %v", err)
os.Exit(1)
}
}

ngm.Serve(*pluginMountPath, kubeletEndpoint, fmt.Sprintf("%s-%d.sock", pluginEndpointPrefix, time.Now().Unix()))
pkg/gpu/nvidia/util/util.go (106 additions, 2 deletions)
@@ -15,17 +15,27 @@
package util

import (
"context"
"fmt"
"os"
"regexp"

"github.com/fsnotify/fsnotify"
"github.com/golang/glog"
"golang.org/x/sys/unix"
v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
client "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"
retryWatch "k8s.io/client-go/tools/watch"
)

func DeviceNameFromPath(path string) (string, error) {
gpuPathRegex := regexp.MustCompile("/dev/(nvidia[0-9]+)$")
m := gpuPathRegex.FindStringSubmatch(path)
@@ -52,7 +62,7 @@ func Files(files ...string) (*fsnotify.Watcher, error) {
return watcher, nil
}

func BuildKubeClient() (client.Interface, error) {
func BuildKubeClient() (kubernetes.Interface, error) {
config, err := rest.InClusterConfig()
if err != nil {
glog.Errorf("failed to get kube config. Error: %v", err)
@@ -68,3 +78,97 @@ func BuildKubeClient() (client.Interface, error) {

return kubeClient, nil
}

func GetEnv(envName string) (string, error) {
env := os.Getenv(envName)
if len(env) == 0 {
return "", fmt.Errorf("empty %s environment variable", envName)
}
return env, nil
}

func CheckLockFileExists(lockFilePath string) (bool, error) {
if _, err := os.Stat(lockFilePath); err == nil {
return true, nil
} else if os.IsNotExist(err) {
return false, nil
} else {
return false, err
}
}

// WaitForDeviceUnregistered is a watch condition that returns true once the
// node no longer reports any allocatable nvidia.com/gpu devices, i.e. the old
// device plugin has unregistered them.
func WaitForDeviceUnregistered(event watch.Event) (bool, error) {
if event.Type == watch.Modified || event.Type == watch.Added {
node, ok := event.Object.(*v1.Node)
if !ok {
glog.Warningf("unexpected object type: %T", event.Object)
return false, nil
}

gpuQuantity, exists := node.Status.Allocatable["nvidia.com/gpu"]
if !exists || gpuQuantity.Value() == 0 {
glog.Infoln("Allocatable nvidia.com/gpu is 0. Proceeding to critical section.")
return true, nil
}
glog.Infof("Waiting for allocatable nvidia.com/gpu to drop to 0; currently %d", gpuQuantity.Value())
return false, nil
}
if event.Type == watch.Deleted {
return true, fmt.Errorf("node deleted, exit here")
}
if event.Type == watch.Error {
return true, fmt.Errorf("node error received, exit here: %v", apierrors.FromObject(event.Object))
}
return false, nil
}
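
The condition above can be exercised against a synthetic node event without a live cluster. A minimal sketch (hypothetical test code, not part of this PR; the external test package and Example function are assumptions):

package util_test

import (
	"fmt"

	v1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
	"k8s.io/apimachinery/pkg/watch"

	util "github.com/GoogleCloudPlatform/container-engine-accelerators/pkg/gpu/nvidia/util"
)

func Example_waitForDeviceUnregistered() {
	// A node whose allocatable nvidia.com/gpu count has dropped to zero,
	// i.e. the old device plugin has unregistered its devices.
	node := &v1.Node{}
	node.Status.Allocatable = v1.ResourceList{"nvidia.com/gpu": resource.MustParse("0")}

	done, err := util.WaitForDeviceUnregistered(watch.Event{Type: watch.Modified, Object: node})
	fmt.Println(done, err)
	// Output: true <nil>
}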

// Copied from k8s.io/kubernetes/pkg/util/flock.
// Acquire acquires a lock on a file for the duration of the process. This method
// is reentrant.
func Acquire(path string) error {
fd, err := unix.Open(path, unix.O_CREAT|unix.O_RDWR|unix.O_CLOEXEC, 0600)
if err != nil {
return err
}

// We don't need to close the fd since we should hold
// it until the process exits.

return unix.Flock(fd, unix.LOCK_EX)
}
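
Acquire never explicitly unlocks: the kernel releases the flock when the holder's file descriptor is closed, which happens at the latest when the process exits. That is the handover the new plugin relies on while it blocks in Flock. A minimal standalone sketch (hypothetical, Linux-only, with a made-up demo path) showing a second flock(LOCK_EX) on the same file blocking until the first open file description goes away:

package main

import (
	"fmt"
	"time"

	"golang.org/x/sys/unix"
)

func main() {
	const path = "/tmp/flock-demo.lock" // hypothetical demo path

	// First holder, standing in for the old device plugin.
	fd1, err := unix.Open(path, unix.O_CREAT|unix.O_RDWR|unix.O_CLOEXEC, 0600)
	if err != nil {
		panic(err)
	}
	if err := unix.Flock(fd1, unix.LOCK_EX); err != nil {
		panic(err)
	}

	// Release after two seconds; closing the fd drops the flock,
	// just as process exit would.
	go func() {
		time.Sleep(2 * time.Second)
		unix.Close(fd1)
	}()

	// Second opener, standing in for the new device plugin: this
	// Flock call blocks until the first descriptor is closed.
	fd2, err := unix.Open(path, unix.O_CREAT|unix.O_RDWR|unix.O_CLOEXEC, 0600)
	if err != nil {
		panic(err)
	}
	start := time.Now()
	if err := unix.Flock(fd2, unix.LOCK_EX); err != nil {
		panic(err)
	}
	fmt.Printf("second lock acquired after %v\n", time.Since(start))
}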

func UseRetryWatch(ctx context.Context, watchFunc func(metav1.ListOptions) (watch.Interface, error), conditions func(watch.Event) (bool, error)) error {
_, err := retryWatch.Until(ctx, "1", &cache.ListWatch{WatchFunc: watchFunc}, conditions)
if err != nil {
return fmt.Errorf("failed to wait for device unregistered: %v", err)
}
return nil
}

func SafelyUsingFlockWait(
ctx context.Context,
lockFilePath string,
watchFunc func(metav1.ListOptions) (watch.Interface, error),
checkLockFileExists func(lockFilePath string) (bool, error),
useRetryWatch func(
ctx context.Context,
watchFunc func(metav1.ListOptions) (watch.Interface, error),
conditions func(watch.Event) (bool, error),
) error,
) error {
if val, err := checkLockFileExists(lockFilePath); err != nil {
return fmt.Errorf("error checking lock file %q: %q", lockFilePath, err)
} else if !val {
glog.Infof("Lock file %q does not exist\n", lockFilePath)
if err := useRetryWatch(ctx, watchFunc, WaitForDeviceUnregistered); err != nil {
return fmt.Errorf("failed to use retry watch: %v", err)
}
}

glog.Infof("Attempting to acquire lock on %q...\n", lockFilePath)
if err := Acquire(lockFilePath); err != nil {
return fmt.Errorf("error acquiring lock: %v", err)
}
return nil
}
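
Because SafelyUsingFlockWait receives its collaborators as parameters, both the lock-file check and the node watch can be stubbed in a unit test without touching the API server; only the final flock hits the filesystem. A minimal sketch (hypothetical test code; the package name and lock path are assumptions):

package util_test

import (
	"context"
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/watch"

	util "github.com/GoogleCloudPlatform/container-engine-accelerators/pkg/gpu/nvidia/util"
)

func Example_safelyUsingFlockWait() {
	// Pretend the old plugin already created the lock file, so the
	// node watch is skipped and we go straight to the blocking flock.
	lockExists := func(string) (bool, error) { return true, nil }
	noWatch := func(context.Context, func(metav1.ListOptions) (watch.Interface, error), func(watch.Event) (bool, error)) error {
		return nil // never reached on this path
	}

	err := util.SafelyUsingFlockWait(context.Background(), "/tmp/test-plugin.lock", nil, lockExists, noWatch)
	fmt.Println(err)
	// Output: <nil>
}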