diff --git a/cgroup.go b/cgroup.go index 93fa99c..6e4cc84 100644 --- a/cgroup.go +++ b/cgroup.go @@ -125,12 +125,20 @@ func currentProcCgroup(fs afero.Fs) (string, error) { return "", xerrors.Errorf("read %v: %w", procSelfCgroup, err) } - parts := strings.Split(strings.TrimSpace(string(data)), ":") - if len(parts) != 3 { - return "", xerrors.Errorf("parse %v: %w", procSelfCgroup, err) + entries := strings.Split(strings.TrimSpace(string(data)), "\n") + + for _, entry := range entries { + parts := strings.Split(strings.TrimSpace(entry), ":") + if len(parts) != 3 { + return "", xerrors.Errorf("parse entry %v: %w", procSelfCgroup, err) + } + + if parts[0] == "0" { + return parts[2], nil + } } - return parts[2], nil + return "", xerrors.Errorf("no cgroup entry for hierarchy 0 found") } // read an int64 value from path diff --git a/cgroup_internal_test.go b/cgroup_internal_test.go new file mode 100644 index 0000000..ff3c0fd --- /dev/null +++ b/cgroup_internal_test.go @@ -0,0 +1,79 @@ +package clistat + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCurrentProcCgroup(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + procFile string + expectError string + expectPath string + }{ + { + name: "RootPath", + procFile: `0::/`, + expectPath: "/", + }, + { + name: "SingleLevelPath", + procFile: `0::/init.slice`, + expectPath: "/init.slice", + }, + { + name: "MultipleLevelSlice", + procFile: `0::/wibble/wobble/init.slice`, + expectPath: "/wibble/wobble/init.slice", + }, + { + name: "MixOfHierachy", + procFile: `1:net_cls:/ +0::/`, + expectPath: "/", + }, + { + name: "MixOfHierarchyPaths", + procFile: `1:net_cls:/init.slice +0::/`, + expectPath: "/", + }, + { + name: "MixOfHierarchyPaths/Order", + procFile: `0::/ +1:net_cls:/init.slice`, + expectPath: "/", + }, + { + name: "MixOfHierarchyPaths/Paths", + procFile: `0::/init.slice +1:net_cls:/`, + expectPath: "/init.slice", + }, + { + name: "MissingHierarchy0", + procFile: `1:net_cls:/`, + expectError: "no cgroup entry for hierarchy 0 found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fs := initFS(t, map[string]string{procSelfCgroup: tt.procFile}) + + path, err := currentProcCgroup(fs) + if tt.expectError != "" { + require.ErrorContains(t, err, tt.expectError) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectPath, path) + } + }) + } +}