Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add workloads to the plan #571

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
60 changes: 48 additions & 12 deletions internals/overlord/servstate/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ type serviceData struct {
manager *ServiceManager
state serviceState
config *plan.Service
workload *Workload
logs *servicelog.RingBuffer
started chan error
stopped chan error
Expand All @@ -120,8 +121,20 @@ func (m *ServiceManager) doStart(task *state.Task, tomb *tomb.Tomb) error {
return fmt.Errorf("cannot find service %q in plan", request.Name)
}

var workload *Workload
if s, ok := currentPlan.Sections[WorkloadsField]; ok {
ws, ok := s.(*WorkloadsSection)
if !ok {
return fmt.Errorf("internal error: invalid section type %T", ws)
}
workload, ok = ws.Entries[config.Workload]
if config.Workload != "" && !ok {
return fmt.Errorf("cannot find workload %q for service %q in plan", config.Workload, request.Name)
}
}

// Create the service object (or reuse the existing one by name).
service, taskLog := m.serviceForStart(config)
service, taskLog := m.serviceForStart(config, workload)
if taskLog != "" {
addTaskLog(task, taskLog)
}
Expand Down Expand Up @@ -167,27 +180,34 @@ func (m *ServiceManager) doStart(task *state.Task, tomb *tomb.Tomb) error {
// and is running.
//
// It also returns a message to add to the task's log, or empty string if none.
func (m *ServiceManager) serviceForStart(config *plan.Service) (service *serviceData, taskLog string) {
func (m *ServiceManager) serviceForStart(config *plan.Service, workload *Workload) (service *serviceData, taskLog string) {
m.servicesLock.Lock()
defer m.servicesLock.Unlock()

var w *Workload
if workload != nil {
w = workload.copy()
}

service = m.services[config.Name]
if service == nil {
// Not already started, create a new service object.
service = &serviceData{
manager: m,
state: stateInitial,
config: config.Copy(),
logs: servicelog.NewRingBuffer(maxLogBytes),
started: make(chan error, 1),
stopped: make(chan error, 2), // enough for killTimeElapsed to send, and exit if it happens after
manager: m,
state: stateInitial,
config: config.Copy(),
workload: w,
logs: servicelog.NewRingBuffer(maxLogBytes),
started: make(chan error, 1),
stopped: make(chan error, 2), // enough for killTimeElapsed to send, and exit if it happens after
}
m.services[config.Name] = service
return service, ""
}

// Ensure config is up-to-date from the plan whenever the user starts a service.
service.config = config.Copy()
service.workload = w

switch service.state {
case stateInitial, stateStarting, stateRunning:
Expand Down Expand Up @@ -338,17 +358,27 @@ func (s *serviceData) startInternal() error {
s.cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

// Copy environment to avoid updating original.
environment := make(map[string]string)
environment := make(map[string]string, len(s.config.Environment))
for k, v := range s.config.Environment {
environment[k] = v
}

s.cmd.Dir = s.config.WorkingDir

// Start as another user if specified in plan.
uid, gid, err := osutil.NormalizeUidGid(s.config.UserID, s.config.GroupID, s.config.User, s.config.Group)
if err != nil {
return err
var uid, gid *int
if s.config.UserID != nil || s.config.GroupID != nil || s.config.User != "" || s.config.Group != "" {
// User/group config from the service takes precedence
uid, gid, err = osutil.NormalizeUidGid(s.config.UserID, s.config.GroupID, s.config.User, s.config.Group)
if err != nil {
return err
}
} else if s.workload != nil {
// Take user/group config from workload
uid, gid, err = osutil.NormalizeUidGid(s.workload.UserID, s.workload.GroupID, s.workload.User, s.workload.Group)
if err != nil {
return err
}
}
if uid != nil && gid != nil {
isCurrent, err := osutil.IsCurrent(*uid, *gid)
Expand Down Expand Up @@ -378,6 +408,12 @@ func (s *serviceData) startInternal() error {
}
}

if s.workload != nil && len(s.workload.Environment) != 0 {
for k, v := range s.workload.Environment {
environment[k] = v
}
}

// Pass service description's environment variables to child process.
s.cmd.Env = os.Environ()
for k, v := range environment {
Expand Down
10 changes: 9 additions & 1 deletion internals/overlord/servstate/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,15 @@ func (m *ServiceManager) Replan() ([][]string, [][]string, error) {
if config.Equal(s.config) {
continue
}
s.config = config.Copy() // update service config from plan
// Update service config and workload from plan
s.config = config.Copy()
if s.config.Workload != "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If workload changed from "foo" to "" with the previous update, I think you do want to update s.workload to not contain a copy of the previous workload ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also if the plan.Service is the same, and the workload changed, replan will not restart the service.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to comment the same thing as Fred: don't we want to also restart the service if the workload details have changed (its user/group), not just the service.workload name?

ws, ok := currentPlan.Sections[WorkloadsField].(*WorkloadsSection)
if !ok {
return nil, nil, fmt.Errorf("internal error: invalid section type %T", ws)
}
s.workload = ws.Entries[s.config.Workload].copy()
}
}
needsRestart[name] = true
stop = append(stop, name)
Expand Down
4 changes: 4 additions & 0 deletions internals/overlord/servstate/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ func (s *S) TearDownTest(c *C) {
s.stopRunningServices(c)
}
}

// General test cleanup
s.BaseTest.TearDownTest(c)

Expand Down Expand Up @@ -1797,12 +1798,15 @@ func (s *S) planAddLayer(c *C, layerYAML string) {
layers := append(s.plan.Layers, layer)
combined, err := plan.CombineLayers(layers...)
c.Assert(err, IsNil)
c.Assert(combined.Validate(), IsNil)
s.plan = &plan.Plan{
Layers: layers,
Services: combined.Services,
Checks: combined.Checks,
LogTargets: combined.LogTargets,
Sections: combined.Sections,
}
c.Assert(s.plan.Validate(), IsNil)
}

// Make sure services are all stopped before the next test starts.
Expand Down
209 changes: 209 additions & 0 deletions internals/overlord/servstate/workloads.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// Copyright (c) 2025 Canonical Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package servstate

import (
"bytes"
"errors"
"fmt"

yaml "gopkg.in/yaml.v3"

"github.com/canonical/pebble/internals/plan"
)

var _ plan.SectionExtension = (*WorkloadsSectionExtension)(nil)

type WorkloadsSectionExtension struct{}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would argue this shouldn't be in servstate -- that just happens to be the only "consumer" of workloads for now. I think should should go in the plan package (internals/plan/workloads.go). Or maybe it's own workloads package, but I think plan is fine.

Naming suggestions:

  • If it's in plan package, let's call it plan.WorkloadsExtension (section seems obvious/redundant there)
  • If it's in its own package, let's call it workloads.PlanExtension or similar


func (ext *WorkloadsSectionExtension) CombineSections(sections ...plan.Section) (plan.Section, error) {
ws := &WorkloadsSection{}
for _, section := range sections {
layer, ok := section.(*WorkloadsSection)
if !ok {
return nil, fmt.Errorf("internal error: invalid section type %T", layer)
}
if err := ws.combine(layer); err != nil {
return nil, err
}
}
return ws, nil
}

func (ext *WorkloadsSectionExtension) ParseSection(data yaml.Node) (plan.Section, error) {
ws := &WorkloadsSection{}
// The following issue prevents us from using the yaml.Node decoder
// with KnownFields = true behavior. Once one of the proposals get
// merged, we can remove the intermediate Marshal step.
if len(data.Content) != 0 {
yml, err := yaml.Marshal(data)
if err != nil {
return nil, fmt.Errorf(`internal error: cannot marshal "workloads" section: %w`, err)
Copy link
Contributor

@flotter flotter Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've seen people quote with %q which is a single quote. Would it make sense to follow that convention using 'workloads' and just use normal double quotes on the outside ? Maybe I am missing something. It also applies to the next line.

Copy link
Collaborator Author

@anpep anpep Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I did not understand correctly, but %q uses double quotation marks. I personally don't like mixing up single/double quotes, and also found this in existing code:

case UnknownOverride:
return nil, &FormatError{
Message: fmt.Sprintf(`layer %q must define "override" for service %q`,
layer.Label, service.Name),
}

}
dec := yaml.NewDecoder(bytes.NewReader(yml))
dec.KnownFields(true)
if err = dec.Decode(ws); err != nil {
return nil, &plan.FormatError{
Message: fmt.Sprintf(`cannot parse the "workloads" section: %v`, err),
}
}
}
for name, workload := range ws.Entries {
if workload != nil {
workload.Name = name
}
}
return ws, nil
}

func (ext *WorkloadsSectionExtension) ValidatePlan(p *plan.Plan) error {
ws, ok := p.Sections[WorkloadsField].(*WorkloadsSection)
if !ok {
return fmt.Errorf("internal error: invalid section type %T", ws)
}
for name, service := range p.Services {
_, ok := ws.Entries[service.Workload]
if service.Workload != "" && !ok {
return &plan.FormatError{
Message: fmt.Sprintf(`plan service %q cannot run in unknown workload %q`, name, service.Workload),
}
}
}
return nil
}

const WorkloadsField = "workloads"

var _ plan.Section = (*WorkloadsSection)(nil)

type WorkloadsSection struct {
Entries map[string]*Workload `yaml:",inline"`
}

func (ws *WorkloadsSection) IsZero() bool {
return len(ws.Entries) == 0
}

func (ws *WorkloadsSection) Validate() error {
for name, workload := range ws.Entries {
if workload == nil {
return &plan.FormatError{
Message: fmt.Sprintf("workload %q has a null value", name),
}
}
if err := workload.validate(); err != nil {
return &plan.FormatError{
Message: fmt.Sprintf("workload %q %v", name, err),
}
}
}
return nil
}

func (ws *WorkloadsSection) combine(other *WorkloadsSection) error {
if len(other.Entries) != 0 && ws.Entries == nil {
ws.Entries = make(map[string]*Workload, len(other.Entries))
}
for name, workload := range other.Entries {
switch workload.Override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there's same way we can factor this logic out into a helper function (likely in the plan package), generic/parametrized on the section type, that does this tricky switch/merge dance and provides error messages (you'd pass it the field name and anything else needed). This is already done three times in plan.go and now again here, and it's quite tricky.

case plan.MergeOverride:
if current, ok := ws.Entries[name]; ok {
copied := current.copy()
copied.merge(workload)
ws.Entries[name] = copied
break
}
fallthrough
case plan.ReplaceOverride:
ws.Entries[name] = workload.copy()
case plan.UnknownOverride:
return &plan.FormatError{
Message: fmt.Sprintf(`workload %q must define an "override" policy`, name),
}
default:
return &plan.FormatError{
Message: fmt.Sprintf(`workload %q has an invalid "override" policy: %q`, name, workload.Override),
}
}
}
return nil
}

type Workload struct {
// Basic details
Name string `yaml:"-"`
Override plan.Override `yaml:"override,omitempty"`

// Options for command execution
Environment map[string]string `yaml:"environment,omitempty"`
UserID *int `yaml:"user-id,omitempty"`
User string `yaml:"user,omitempty"`
GroupID *int `yaml:"group-id,omitempty"`
Group string `yaml:"group,omitempty"`
}

func (w *Workload) validate() error {
if w.Name == "" {
return errors.New("cannot have an empty name")
}
// Value of Override is checked in the (*WorkloadSection).combine() method
return nil
}

func (w *Workload) copy() *Workload {
copied := *w
if w.Environment != nil {
copied.Environment = make(map[string]string, len(w.Environment))
for k, v := range w.Environment {
copied.Environment[k] = v
}
}
if w.UserID != nil {
copied.UserID = copyIntPtr(w.UserID)
}
if w.GroupID != nil {
copied.GroupID = copyIntPtr(w.GroupID)
}
return &copied
}

func (w *Workload) merge(other *Workload) {
if len(other.Environment) != 0 && w.Environment == nil {
w.Environment = make(map[string]string, len(other.Environment))
}
for k, v := range other.Environment {
w.Environment[k] = v
}
if other.UserID != nil {
w.UserID = copyIntPtr(other.UserID)
}
if other.User != "" {
w.User = other.User
}
if other.GroupID != nil {
w.GroupID = copyIntPtr(other.GroupID)
}
if other.Group != "" {
w.Group = other.Group
}
}

func copyIntPtr(p *int) *int {
if p == nil {
return nil
}
copied := *p
return &copied
}
Loading
Loading