diff --git a/internals/plan/export_test.go b/internals/plan/export_test.go new file mode 100644 index 000000000..3de7628a7 --- /dev/null +++ b/internals/plan/export_test.go @@ -0,0 +1,20 @@ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package plan + +// ResetLayerExtensions resets the global state between tests. +func ResetLayerExtensions() { + layerExtensions = map[string]LayerSectionExtension{} +} diff --git a/internals/plan/package_test.go b/internals/plan/package_test.go new file mode 100644 index 000000000..9fb81895c --- /dev/null +++ b/internals/plan/package_test.go @@ -0,0 +1,34 @@ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this progral. If not, see . + +package plan_test + +import ( + "testing" + + . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/plan" +) + +// Hook up check.v1 into the "go test" runner. +func Test(t *testing.T) { TestingT(t) } + +type planSuite struct{} + +var _ = Suite(&planSuite{}) + +func (ps *planSuite) SetUpTest(c *C) { + plan.ResetLayerExtensions() +} diff --git a/internals/plan/plan.go b/internals/plan/plan.go index 2d5f3bcdc..839f1420a 100644 --- a/internals/plan/plan.go +++ b/internals/plan/plan.go @@ -32,6 +32,32 @@ import ( "github.com/canonical/pebble/internals/osutil" ) +// LayerSectionExtension allows the plan layer schema to be extended without +// adding centralised schema knowledge to the plan library. +type LayerSectionExtension interface { + // ParseSection creates a new layer section containing the unmarshalled + // yaml.Node, and any additional section specifics it wishes to apply + // to the backing type. A nil LayerSection returned ensures that this + // section will be omitted by the caller. + ParseSection(data *yaml.Node) (LayerSection, error) + + // CombineSections creates a new layer section containing the result of + // combining the layer sections in the supplied order. A nil LayerSection + // returned ensures the combined section will be completely omitted by + // the caller. + CombineSections(sections ...LayerSection) (LayerSection, error) + + // ValidatePlan takes the complete plan as input, and allows the + // extension to validate the plan. This can be used for cross section + // dependency validation. + ValidatePlan(plan *Plan) error +} + +type LayerSection interface { + // Validate expects the section to validate itself. + Validate() error +} + const ( defaultBackoffDelay = 500 * time.Millisecond defaultBackoffFactor = 2.0 @@ -42,11 +68,33 @@ const ( defaultCheckThreshold = 3 ) +// layerExtensions keeps a map of registered extensions. +var layerExtensions = map[string]LayerSectionExtension{} + +// RegisterExtension must be called by the plan extension owners to +// extend the plan schema before the plan is loaded. +func RegisterExtension(field string, ext LayerSectionExtension) error { + if _, ok := layerExtensions[field]; ok { + return fmt.Errorf("internal error: extension %q already registered", field) + } + layerExtensions[field] = ext + return nil +} + type Plan struct { Layers []*Layer `yaml:"-"` Services map[string]*Service `yaml:"services,omitempty"` Checks map[string]*Check `yaml:"checks,omitempty"` LogTargets map[string]*LogTarget `yaml:"log-targets,omitempty"` + + // Extended schema sections. + Sections map[string]LayerSection `yaml:",inline,omitempty"` +} + +// Section retrieves a section from the plan. Returns nil if +// the field does not exist. +func (p *Plan) Section(field string) LayerSection { + return p.Sections[field] } type Layer struct { @@ -57,6 +105,19 @@ type Layer struct { Services map[string]*Service `yaml:"services,omitempty"` Checks map[string]*Check `yaml:"checks,omitempty"` LogTargets map[string]*LogTarget `yaml:"log-targets,omitempty"` + + Sections map[string]LayerSection `yaml:",inline,omitempty"` +} + +// addSection adds a new section to the layer. +func (layer *Layer) addSection(field string, section LayerSection) { + layer.Sections[field] = section +} + +// Section retrieves a layer section from a layer. Returns nil if +// the field does not exist. +func (layer *Layer) Section(field string) LayerSection { + return layer.Sections[field] } type Service struct { @@ -559,6 +620,7 @@ func CombineLayers(layers ...*Layer) (*Layer, error) { Services: make(map[string]*Service), Checks: make(map[string]*Check), LogTargets: make(map[string]*LogTarget), + Sections: make(map[string]LayerSection), } if len(layers) == 0 { return combined, nil @@ -643,6 +705,30 @@ func CombineLayers(layers ...*Layer) (*Layer, error) { } } + // Combine the same sections from each layer. + for field, extension := range layerExtensions { + var sections []LayerSection + for _, layer := range layers { + if section := layer.Section(field); section != nil { + sections = append(sections, section) + } + } + // Deliberately do not expose the zero section condition to the extension. + // For now, the result of combining nothing must result in an omitted section. + if len(sections) > 0 { + combinedSection, err := extension.CombineSections(sections...) + if err != nil { + return nil, &FormatError{ + Message: fmt.Sprintf(`cannot combine section %q: %v`, field, err), + } + } + // We support the ability for a valid combine to result in an omitted section. + if combinedSection != nil { + combined.addSection(field, combinedSection) + } + } + } + // Set defaults where required. for _, service := range combined.Services { if !service.BackoffDelay.IsSet { @@ -825,11 +911,18 @@ func (layer *Layer) Validate() error { } } + for field, section := range layer.Sections { + err := section.Validate() + if err != nil { + return fmt.Errorf("cannot validate layer section %q: %w", field, err) + } + } + return nil } -// Validate checks that the combined layers form a valid plan. -// See also Layer.Validate, which checks that the individual layers are valid. +// Validate checks that the combined layers form a valid plan. See also +// Layer.Validate, which checks that the individual layers are valid. func (p *Plan) Validate() error { for name, service := range p.Services { if service.Command == "" { @@ -917,6 +1010,15 @@ func (p *Plan) Validate() error { if err != nil { return err } + + // Each section extension must validate the combined plan. + for field, extension := range layerExtensions { + err = extension.ValidatePlan(p) + if err != nil { + return fmt.Errorf("cannot validate plan section %q: %w", field, err) + } + } + return nil } @@ -1020,19 +1122,80 @@ func (p *Plan) checkCycles() error { } func ParseLayer(order int, label string, data []byte) (*Layer, error) { - layer := Layer{ - Services: map[string]*Service{}, - Checks: map[string]*Check{}, - LogTargets: map[string]*LogTarget{}, - } - dec := yaml.NewDecoder(bytes.NewBuffer(data)) - dec.KnownFields(true) - err := dec.Decode(&layer) + layer := &Layer{ + Services: make(map[string]*Service), + Checks: make(map[string]*Check), + LogTargets: make(map[string]*LogTarget), + Sections: make(map[string]LayerSection), + } + + // The following manual approach is required because: + // + // 1. Extended sections are YAML inlined, and also do not have a + // concrete type at this level, we cannot simply unmarshal the layer + // at once. + // + // 2. We honor KnownFields = true behaviour for non extended schema + // sections, and at the top field level, which includes Section field + // names. + + builtinSections := map[string]interface{}{ + "summary": &layer.Summary, + "description": &layer.Description, + "services": &layer.Services, + "checks": &layer.Checks, + "log-targets": &layer.LogTargets, + } + + var layerSections map[string]yaml.Node + err := yaml.Unmarshal(data, &layerSections) if err != nil { return nil, &FormatError{ Message: fmt.Sprintf("cannot parse layer %q: %v", label, err), } } + + for field, section := range layerSections { + switch field { + case "summary", "description", "services", "checks", "log-targets": + // The following issue prevents us from using the yaml.Node decoder + // with KnownFields = true behaviour. Once one of the proposals get + // merged, we can remove the intermediate Marshall step. + // https://github.com/go-yaml/yaml/issues/460 + data, err := yaml.Marshal(§ion) + if err != nil { + return nil, fmt.Errorf("internal error: cannot marshal %v section: %w", field, err) + } + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + if err = dec.Decode(builtinSections[field]); err != nil { + return nil, &FormatError{ + Message: fmt.Sprintf("cannot parse layer %q section %q: %v", label, field, err), + } + } + default: + if extension, ok := layerExtensions[field]; ok { + // Section unmarshal rules are defined by the extension itself. + extendedSection, err := extension.ParseSection(§ion) + if err != nil { + return nil, &FormatError{ + Message: fmt.Sprintf("cannot parse layer %q section %q: %v", label, field, err), + } + } + if extendedSection != nil { + layer.addSection(field, extendedSection) + } + } else { + // At the top level we do not ignore keys we do not understand. + // This preserves the current Pebble behaviour of decoding with + // KnownFields = true. + return nil, &FormatError{ + Message: fmt.Sprintf("cannot parse layer %q: unknown section %q", label, field), + } + } + } + } + layer.Order = order layer.Label = label @@ -1060,7 +1223,7 @@ func ParseLayer(order int, label string, data []byte) (*Layer, error) { return nil, err } - return &layer, err + return layer, err } func validServiceAction(action ServiceAction, additionalValid ...ServiceAction) bool { @@ -1164,6 +1327,7 @@ func ReadDir(dir string) (*Plan, error) { Services: combined.Services, Checks: combined.Checks, LogTargets: combined.LogTargets, + Sections: combined.Sections, } err = plan.Validate() if err != nil { diff --git a/internals/plan/plan_test.go b/internals/plan/plan_test.go index 3a99fcdf4..3a11596d6 100644 --- a/internals/plan/plan_test.go +++ b/internals/plan/plan_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Canonical Ltd +// Copyright (c) 2024 Canonical Ltd // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -203,6 +203,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, { Order: 1, Label: "layer-1", @@ -253,6 +254,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }}, result: &plan.Layer{ Summary: "Simple override layer.", @@ -332,6 +334,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, start: map[string][]string{ "srv1": {"srv2", "srv1", "srv3"}, @@ -394,6 +397,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }}, }, { summary: "Unknown keys are not accepted", @@ -477,7 +481,7 @@ var planTests = []planTest{{ `}, }, { summary: `Invalid backoff-delay duration`, - error: `cannot parse layer "layer-0": invalid duration "foo"`, + error: `cannot parse layer "layer-0" section \"services\": invalid duration "foo"`, input: []string{` services: "svc1": @@ -507,7 +511,7 @@ var planTests = []planTest{{ `}, }, { summary: `Invalid backoff-factor`, - error: `cannot parse layer "layer-0": invalid floating-point number "foo"`, + error: `cannot parse layer "layer-0" section \"services\": invalid floating-point number "foo"`, input: []string{` services: "svc1": @@ -544,6 +548,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }}, }, { summary: `Invalid service command: cannot have any arguments after [ ... ] group`, @@ -652,6 +657,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Checks override replace works correctly", @@ -729,6 +735,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Checks override merge works correctly", @@ -812,6 +819,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Timeout is capped at period", @@ -841,6 +849,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Unset timeout is capped at period", @@ -869,6 +878,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "One of http, tcp, or exec must be present for check", @@ -989,6 +999,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Overriding log targets", @@ -1085,6 +1096,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, + Sections: map[string]plan.LayerSection{}, }, { Label: "layer-1", Order: 1, @@ -1123,6 +1135,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, + Sections: map[string]plan.LayerSection{}, }}, result: &plan.Layer{ Services: map[string]*plan.Service{ @@ -1168,6 +1181,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Log target requires type field", @@ -1277,6 +1291,7 @@ var planTests = []planTest{{ }, }, }, + Sections: map[string]plan.LayerSection{}, }, { Order: 1, Label: "layer-1", @@ -1302,6 +1317,7 @@ var planTests = []planTest{{ }, }, }, + Sections: map[string]plan.LayerSection{}, }}, result: &plan.Layer{ Services: map[string]*plan.Service{}, @@ -1329,6 +1345,7 @@ var planTests = []planTest{{ }, }, }, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Reserved log target labels", @@ -1379,6 +1396,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, + Sections: map[string]plan.LayerSection{}, }, }, { summary: "Three layers missing command", @@ -1402,7 +1420,7 @@ var planTests = []planTest{{ error: `plan must define "command" for service "srv1"`, }} -func (s *S) TestParseLayer(c *C) { +func (ps *planSuite) TestParseLayer(c *C) { for _, test := range planTests { c.Logf(test.summary) var sup plan.Plan @@ -1444,6 +1462,7 @@ func (s *S) TestParseLayer(c *C) { Services: result.Services, Checks: result.Checks, LogTargets: result.LogTargets, + Sections: result.Sections, } err = p.Validate() } @@ -1458,7 +1477,7 @@ func (s *S) TestParseLayer(c *C) { } } -func (s *S) TestCombineLayersCycle(c *C) { +func (ps *planSuite) TestCombineLayersCycle(c *C) { // Even if individual layers don't have cycles, combined layers might. layer1, err := plan.ParseLayer(1, "label1", []byte(` services: @@ -1486,6 +1505,7 @@ services: Services: combined.Services, Checks: combined.Checks, LogTargets: combined.LogTargets, + Sections: combined.Sections, } err = p.Validate() c.Assert(err, ErrorMatches, `services in before/after loop: .*`) @@ -1493,7 +1513,7 @@ services: c.Assert(ok, Equals, true, Commentf("error must be *plan.FormatError, not %T", err)) } -func (s *S) TestMissingOverride(c *C) { +func (ps *planSuite) TestMissingOverride(c *C) { layer1, err := plan.ParseLayer(1, "label1", []byte("{}")) c.Assert(err, IsNil) layer2, err := plan.ParseLayer(2, "label2", []byte(` @@ -1508,7 +1528,7 @@ services: c.Check(ok, Equals, true, Commentf("error must be *plan.FormatError, not %T", err)) } -func (s *S) TestMissingCommand(c *C) { +func (ps *planSuite) TestMissingCommand(c *C) { // Combine fails if no command in combined plan layer1, err := plan.ParseLayer(1, "label1", []byte("{}")) c.Assert(err, IsNil) @@ -1526,6 +1546,7 @@ services: Services: combined.Services, Checks: combined.Checks, LogTargets: combined.LogTargets, + Sections: combined.Sections, } err = p.Validate() c.Check(err, ErrorMatches, `plan must define "command" for service "srv1"`) @@ -1551,7 +1572,7 @@ services: c.Assert(combined.Services["srv1"].Command, Equals, "foo --bar") } -func (s *S) TestReadDir(c *C) { +func (ps *planSuite) TestReadDir(c *C) { tempDir := c.MkDir() for testIndex, test := range planTests { @@ -1607,7 +1628,7 @@ var readDirBadNames = []string{ "001-label--label.yaml", } -func (s *S) TestReadDirBadNames(c *C) { +func (ps *planSuite) TestReadDirBadNames(c *C) { pebbleDir := c.MkDir() layersDir := filepath.Join(pebbleDir, "layers") err := os.Mkdir(layersDir, 0755) @@ -1629,7 +1650,7 @@ var readDirDupNames = [][]string{ {"001-foo.yaml", "002-foo.yaml"}, } -func (s *S) TestReadDirDupNames(c *C) { +func (ps *planSuite) TestReadDirDupNames(c *C) { pebbleDir := c.MkDir() layersDir := filepath.Join(pebbleDir, "layers") err := os.Mkdir(layersDir, 0755) @@ -1651,7 +1672,7 @@ func (s *S) TestReadDirDupNames(c *C) { } } -func (s *S) TestMarshalLayer(c *C) { +func (ps *planSuite) TestMarshalLayer(c *C) { layerBytes := reindent(` summary: Simple layer description: A simple layer. @@ -1741,7 +1762,7 @@ var cmdTests = []struct { error: `cannot parse service "svc" command: cannot start command with \[ ... \] group`, }} -func (s *S) TestParseCommand(c *C) { +func (ps *planSuite) TestParseCommand(c *C) { for _, test := range cmdTests { service := plan.Service{Name: "svc", Command: test.command} @@ -1772,7 +1793,7 @@ func (s *S) TestParseCommand(c *C) { } } -func (s *S) TestLogsTo(c *C) { +func (ps *planSuite) TestLogsTo(c *C) { tests := []struct { services []string logsTo map[string]bool @@ -1859,7 +1880,7 @@ func (s *S) TestLogsTo(c *C) { } } -func (s *S) TestMergeServiceContextNoContext(c *C) { +func (ps *planSuite) TestMergeServiceContextNoContext(c *C) { userID, groupID := 10, 20 overrides := plan.ContextOptions{ Environment: map[string]string{"x": "y"}, @@ -1869,17 +1890,19 @@ func (s *S) TestMergeServiceContextNoContext(c *C) { Group: "grp", WorkingDir: "/working/dir", } + // This test ensures an empty service name results in no lookup, and + // simply leaves the provided context unchanged. merged, err := plan.MergeServiceContext(nil, "", overrides) c.Assert(err, IsNil) c.Check(merged, DeepEquals, overrides) } -func (s *S) TestMergeServiceContextBadService(c *C) { +func (ps *planSuite) TestMergeServiceContextBadService(c *C) { _, err := plan.MergeServiceContext(&plan.Plan{}, "nosvc", plan.ContextOptions{}) c.Assert(err, ErrorMatches, `context service "nosvc" not found`) } -func (s *S) TestMergeServiceContextNoOverrides(c *C) { +func (ps *planSuite) TestMergeServiceContextNoOverrides(c *C) { userID, groupID := 11, 22 p := &plan.Plan{Services: map[string]*plan.Service{"svc1": { Name: "svc1", @@ -1902,7 +1925,7 @@ func (s *S) TestMergeServiceContextNoOverrides(c *C) { }) } -func (s *S) TestMergeServiceContextOverrides(c *C) { +func (ps *planSuite) TestMergeServiceContextOverrides(c *C) { svcUserID, svcGroupID := 10, 20 p := &plan.Plan{Services: map[string]*plan.Service{"svc1": { Name: "svc1", @@ -1934,7 +1957,7 @@ func (s *S) TestMergeServiceContextOverrides(c *C) { }) } -func (s *S) TestPebbleLabelPrefixReserved(c *C) { +func (ps *planSuite) TestPebbleLabelPrefixReserved(c *C) { // Validate fails if layer label has the reserved prefix "pebble-" _, err := plan.ParseLayer(0, "pebble-foo", []byte("{}")) c.Check(err, ErrorMatches, `cannot use reserved label prefix "pebble-"`) diff --git a/internals/plan/suite_test.go b/internals/plan/suite_test.go deleted file mode 100644 index 1977d2394..000000000 --- a/internals/plan/suite_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// -// Copyright (c) 2020 Canonical Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package plan_test - -import ( - "testing" - - . "gopkg.in/check.v1" -) - -func Test(t *testing.T) { TestingT(t) } - -type S struct{} - -var _ = Suite(&S{})