diff --git a/systemd/systemd_test.go b/systemd/systemd_test.go index f770ab1..4d536a3 100644 --- a/systemd/systemd_test.go +++ b/systemd/systemd_test.go @@ -1,8 +1,10 @@ package systemd import ( + "context" "os" "reflect" + "strconv" "testing" systemdDbus "github.com/coreos/go-systemd/v22/dbus" @@ -157,6 +159,12 @@ func TestUnifiedResToSystemdProps(t *testing.T) { newProp("CPUWeight", uint64(1000)), }, }, + { + name: "memory.oom.group handled by Apply method", + res: map[string]string{ + "memory.oom.group": "1", + }, + }, } for _, tc := range testCases { @@ -236,3 +244,98 @@ func TestAddCPUQuota(t *testing.T) { }) } } + +func TestOOMPolicyApply(t *testing.T) { + if !IsRunningSystemd() { + t.Skip("Test requires systemd.") + } + if !cgroups.IsCgroup2UnifiedMode() { + t.Skip("cgroup v2 is required") + } + if os.Geteuid() != 0 { + t.Skip("Test requires root.") + } + + testCases := []struct { + name string + oomGroupValue string + expectedPolicy string + expectError bool + }{ + { + name: "memory.oom.group=0 sets OOMPolicy=continue", + oomGroupValue: "0", + expectedPolicy: "continue", + expectError: false, + }, + { + name: "memory.oom.group=1 sets OOMPolicy=kill", + oomGroupValue: "1", + expectedPolicy: "kill", + expectError: false, + }, + { + name: "invalid memory.oom.group value", + oomGroupValue: "invalid", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + config := &cgroups.Cgroup{ + Name: "test-oom-policy-" + strconv.FormatInt(int64(os.Getpid()), 10), + Resources: &cgroups.Resources{ + Unified: map[string]string{ + "memory.oom.group": tc.oomGroupValue, + }, + }, + } + + manager, err := NewUnifiedManager(config, "") + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer func() { + _ = manager.Destroy() + }() + + err = manager.Apply(-1) + if tc.expectError { + if err == nil { + t.Fatal("Expected error but got none") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + unitName := getUnitName(config) + conn, err := systemdDbus.NewSystemdConnectionContext(context.Background()) + if err != nil { + t.Fatalf("Failed to connect to systemd: %v", err) + } + defer conn.Close() + + properties, err := conn.GetUnitPropertiesContext(context.Background(), unitName) + if err != nil { + t.Fatalf("Failed to get unit properties: %v", err) + } + + oomPolicyValue, exists := properties["OOMPolicy"] + if !exists { + t.Fatal("OOMPolicy property not found") + } + + oomPolicyStr, ok := oomPolicyValue.(string) + if !ok { + t.Fatalf("OOMPolicy value is not a string: %T", oomPolicyValue) + } + + if oomPolicyStr != tc.expectedPolicy { + t.Errorf("Expected OOMPolicy=%s, got %s", tc.expectedPolicy, oomPolicyStr) + } + }) + } +} diff --git a/systemd/v2.go b/systemd/v2.go index c2f2e87..e480f98 100644 --- a/systemd/v2.go +++ b/systemd/v2.go @@ -180,13 +180,8 @@ func unifiedResToSystemdProps(cm *dbusConnManager, res map[string]string) (props newProp("TasksMax", num)) case "memory.oom.group": - // Setting this to 1 is roughly equivalent to OOMPolicy=kill - // (as per systemd.service(5) and - // https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html), - // but it's not clear what to do if it is unset or set - // to 0 in runc update, as there are two other possible - // values for OOMPolicy (continue/stop). - fallthrough + // This was set before the unit started, so no need to + // warn about it here. default: // Ignore the unknown resource here -- will still be @@ -327,6 +322,24 @@ func (m *UnifiedManager) Apply(pid int) error { properties = append(properties, c.SystemdProps...) + if c.Resources != nil && c.Resources.Unified != nil { + if v, ok := c.Resources.Unified["memory.oom.group"]; ok { + value, err := strconv.ParseUint(v, 10, 64) + if err != nil { + return fmt.Errorf("unified resource %q value conversion error: %w", "memory.oom.group", err) + } + + switch value { + case 0: + properties = append(properties, newProp("OOMPolicy", "continue")) + case 1: + properties = append(properties, newProp("OOMPolicy", "kill")) + default: + logrus.Debugf("don't know how to convert memory.oom.group=%d; skipping (will still be applied to cgroupfs)", value) + } + } + } + if err := startUnit(m.dbus, unitName, properties, pid == -1); err != nil { return fmt.Errorf("unable to start unit %q (properties %+v): %w", unitName, properties, err) }