diff --git a/.gitignore b/.gitignore index 6e2f206..6ccd937 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ cover.out examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js +.scp \ No newline at end of file diff --git a/cmd/scp/main.go b/cmd/scp/main.go new file mode 100644 index 0000000..52fc52a --- /dev/null +++ b/cmd/scp/main.go @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package main + +import ( + "os" + + "github.com/pion/scp/internal/cli" +) + +func main() { + if err := cli.Execute(os.Args[1:]); err != nil { + cli.PrintError(err) + os.Exit(1) + } +} diff --git a/go.mod b/go.mod index 35d4027..09e0cb4 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,27 @@ -module github.com/pion/template +module github.com/pion/scp -go 1.21 +go 1.24.0 + +toolchain go1.24.4 + +require ( + github.com/Masterminds/semver/v3 v3.4.0 + github.com/pion/logging v0.2.3 + github.com/pion/sctp v1.8.39 + github.com/pion/transport v0.14.1 + github.com/spf13/cobra v1.10.1 + github.com/stretchr/testify v1.10.0 + golang.org/x/mod v0.30.0 + golang.org/x/sys v0.22.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/transport/v3 v3.0.7 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.9 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e24d20f --- /dev/null +++ b/go.sum @@ -0,0 +1,73 @@ +github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= +github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI= +github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE= +github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= +github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= +github.com/pion/transport v0.14.1/go.mod h1:4tGmbk00NeYA3rUa9+n+dzCCoKkcy3YlYb99Jn2fNnI= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= +github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cli/cli.go b/internal/cli/cli.go new file mode 100644 index 0000000..30dcc23 --- /dev/null +++ b/internal/cli/cli.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package cli wires up the cobra command tree for the scp generator. +package cli + +import ( + "context" + "fmt" + "os" + + "github.com/spf13/cobra" +) + +func Execute(args []string) error { + root := newRootCmd() + root.SetArgs(args) + ctx := context.Background() + + if err := root.ExecuteContext(ctx); err != nil { + return err + } + + return nil +} + +func PrintError(err error) { + fmt.Fprintf(os.Stderr, "error: %v\n", err) +} + +func newRootCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "scp", + Short: "scp is a generator for multi-revision Pion SCTP testing", + Long: `scp resolves references to github.com/pion/sctp, generates deterministic +harnesses, and runs cross-version compatibility tests.`, + SilenceUsage: true, + SilenceErrors: true, + } + + cmd.PersistentFlags().BoolP("verbose", "v", false, "enable verbose logging") + cmd.PersistentFlags().Bool("dry-run", false, "show actions without writing results") + + cmd.AddCommand(newResolveCmd()) + cmd.AddCommand(newUpdateCmd()) + cmd.AddCommand(newGenerateCmd()) + cmd.AddCommand(newTestCmd()) + + return cmd +} diff --git a/internal/cli/generate.go b/internal/cli/generate.go new file mode 100644 index 0000000..31b8cee --- /dev/null +++ b/internal/cli/generate.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package cli + +import ( + "github.com/pion/scp/internal/generate" + "github.com/spf13/cobra" +) + +func newGenerateCmd() *cobra.Command { + opts := generate.DefaultOptions() + cmd := &cobra.Command{ + Use: "generate", + Short: "Generate runners, wrappers, and harness code from a lock file", + RunE: func(cmd *cobra.Command, args []string) error { + return generate.Run(cmd.Context(), opts) + }, + } + + cmd.Flags().StringVar(&opts.LockPath, "lock", opts.LockPath, "path to lock.json") + cmd.Flags().StringVar(&opts.FeaturesPath, "features", opts.FeaturesPath, "path to features.yaml") + cmd.Flags().StringVar(&opts.OutputDir, "out", opts.OutputDir, "output directory for generated code") + cmd.Flags().StringVar(&opts.APIName, "package", opts.APIName, "name of generated API package") + cmd.Flags().StringVar( + &opts.RunnerProtocol, + "runner-proto", + opts.RunnerProtocol, + "runner transport protocol (stdio-json|rpc)", + ) + cmd.Flags().StringVar( + &opts.ModuleMode, + "modmode", + opts.ModuleMode, + "module resolve mode (remote|local-cache)", + ) + cmd.Flags().StringVar( + &opts.LicensePath, + "license", + opts.LicensePath, + "optional license header file path", + ) + cmd.Flags().StringSliceVar( + &opts.OnlyNames, + "only", + nil, + "optional comma-separated list of lock entries to generate", + ) + + return cmd +} diff --git a/internal/cli/resolve.go b/internal/cli/resolve.go new file mode 100644 index 0000000..e300be7 --- /dev/null +++ b/internal/cli/resolve.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package cli + +import ( + "errors" + + "github.com/pion/scp/internal/resolve" + "github.com/pion/scp/internal/scp" + "github.com/spf13/cobra" +) + +var errNoRefsProvided = errors.New("no refs specified: use --refs or provide positional selectors") + +func newResolveCmd() *cobra.Command { + opts := resolve.Options{}.WithDefaults() + cmd := &cobra.Command{ + Use: "resolve", + Short: "Resolve ref selectors into manifest and lock files", + RunE: func(cmd *cobra.Command, args []string) error { + if len(opts.Refs) == 0 && len(args) == 0 { + return errNoRefsProvided + } + + if len(args) > 0 { + opts.Refs = append(opts.Refs, args...) + } + + opts.Refs = scp.SplitAndTrim(opts.Refs) + + ctx := cmd.Context() + if err := resolve.Run(ctx, opts); err != nil { + return err + } + + return nil + }, + } + + cmd.Flags().StringSliceVar(&opts.Refs, "refs", nil, "comma-separated selector list (may repeat)") + cmd.Flags().StringVar(&opts.Repository, "repo", resolve.DefaultRepository, "repository URL to mirror") + cmd.Flags().StringVar(&opts.CacheDir, "cache", scp.DefaultCacheDir(), "cache directory for mirrors and checkouts") + cmd.Flags().BoolVar( + &opts.IncludePreRelease, + "include-pre", + false, + "include pre-release tags when resolving ranges", + ) + cmd.Flags().StringVar( + &opts.ManifestPath, + "out-manifest", + scp.DefaultManifestPath(), + "output path for manifest JSON", + ) + cmd.Flags().StringVar( + &opts.LockPath, + "out-lock", + scp.DefaultLockPath(), + "output path for lock JSON", + ) + cmd.Flags().StringVar(&opts.FreezeAt, "freeze-at", "", "RFC3339 timestamp to pin moving refs") + cmd.Flags().BoolVar( + &opts.AllowDirtyLocal, + "local-allow-dirty", + false, + "permit path selectors with local modifications", + ) + + return cmd +} diff --git a/internal/cli/test.go b/internal/cli/test.go new file mode 100644 index 0000000..c488d88 --- /dev/null +++ b/internal/cli/test.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package cli + +import ( + "github.com/pion/scp/internal/testcmd" + "github.com/spf13/cobra" +) + +func newTestCmd() *cobra.Command { + opts := testcmd.DefaultOptions() + cmd := &cobra.Command{ + Use: "test", + Short: "Build runners and execute cross-revision scenarios", + RunE: func(cmd *cobra.Command, args []string) error { + return testcmd.Run(cmd.Context(), opts) + }, + } + + cmd.Flags().StringVar(&opts.LockPath, "lock", opts.LockPath, "path to lock.json") + cmd.Flags().StringVar( + &opts.PairMode, + "pairs", + opts.PairMode, + "pair selection mode (adjacent|latest-prev|matrix|explicit|self)", + ) + cmd.Flags().StringSliceVar( + &opts.IncludeNames, + "include", + nil, + "include only these entries (comma-separated)", + ) + cmd.Flags().StringSliceVar( + &opts.ExcludeNames, + "exclude", + nil, + "exclude these entries (comma-separated)", + ) + cmd.Flags().StringSliceVar( + &opts.ExplicitPairs, + "explicit", + nil, + "explicit pairs when --pairs=explicit (comma-separated A:B)", + ) + cmd.Flags().StringSliceVar( + &opts.Cases, + "cases", + nil, + "scenario IDs to run (comma-separated)", + ) + cmd.Flags().StringVar(&opts.Timeout, "timeout", opts.Timeout, "overall timeout for each pair") + cmd.Flags().Int64Var(&opts.Seed, "seed", opts.Seed, "random seed (0=random)") + cmd.Flags().StringVar(&opts.JUnitPath, "out", opts.JUnitPath, "path to write JUnit XML results") + cmd.Flags().IntVar(&opts.Repeat, "repeat", opts.Repeat, "number of times to run each pair (>=1)") + + return cmd +} diff --git a/internal/cli/update.go b/internal/cli/update.go new file mode 100644 index 0000000..b528139 --- /dev/null +++ b/internal/cli/update.go @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package cli + +import ( + "github.com/pion/scp/internal/update" + "github.com/spf13/cobra" +) + +func newUpdateCmd() *cobra.Command { + opts := update.DefaultOptions() + cmd := &cobra.Command{ + Use: "update", + Short: "Update lock files from manifest entries with floating refs", + RunE: func(cmd *cobra.Command, args []string) error { + return update.Run(cmd.Context(), opts) + }, + } + + cmd.Flags().StringVar(&opts.ManifestPath, "manifest", opts.ManifestPath, "path to manifest.json") + cmd.Flags().StringVar(&opts.LockPath, "lock", opts.LockPath, "path to lock.json to update") + cmd.Flags().StringSliceVar(&opts.OnlyNames, "only", nil, "comma-separated entry names to refresh") + cmd.Flags().StringVar(&opts.FreezeAt, "freeze-at", "", "RFC3339 timestamp to pin moving refs") + + return cmd +} diff --git a/internal/generate/errors.go b/internal/generate/errors.go new file mode 100644 index 0000000..9f2b3a2 --- /dev/null +++ b/internal/generate/errors.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package generate + +import "errors" + +var ( + errMissingLockPath = errors.New("generate: lock path is required") + errMissingOutputDir = errors.New("generate: output directory is required") + errMissingAPIName = errors.New("generate: package name is required") +) diff --git a/internal/generate/options.go b/internal/generate/options.go new file mode 100644 index 0000000..306d9c5 --- /dev/null +++ b/internal/generate/options.go @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package generate implements the generator CLI that builds multi-version SCTP wrappers and harnesses. +package generate + +import ( + "github.com/pion/scp/internal/scp" +) + +const ( + DefaultRunnerProtocol = "stdio-json" + DefaultModuleMode = "local-cache" + ModuleModeRemote = "remote" + DefaultAPIName = "sctpapi" +) + +type Options struct { + LockPath string + FeaturesPath string + OutputDir string + APIName string + RunnerProtocol string + ModuleMode string + LicensePath string + OnlyNames []string +} + +func DefaultOptions() Options { + return Options{ + LockPath: scp.DefaultLockPath(), + FeaturesPath: scp.DefaultFeaturesPath(), + OutputDir: scp.DefaultOutputDir(), + APIName: DefaultAPIName, + RunnerProtocol: DefaultRunnerProtocol, + ModuleMode: DefaultModuleMode, + } +} + +func (o Options) WithDefaults() Options { + def := DefaultOptions() + + if o.LockPath == "" { + o.LockPath = def.LockPath + } + if o.FeaturesPath == "" { + o.FeaturesPath = def.FeaturesPath + } + if o.OutputDir == "" { + o.OutputDir = def.OutputDir + } + if o.APIName == "" { + o.APIName = def.APIName + } + if o.RunnerProtocol == "" { + o.RunnerProtocol = def.RunnerProtocol + } + if o.ModuleMode == "" { + o.ModuleMode = def.ModuleMode + } + + return o +} diff --git a/internal/generate/run.go b/internal/generate/run.go new file mode 100644 index 0000000..8986bc8 --- /dev/null +++ b/internal/generate/run.go @@ -0,0 +1,739 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package generate + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + + "github.com/Masterminds/semver/v3" + "github.com/pion/scp/internal/scp" + "golang.org/x/mod/modfile" + "gopkg.in/yaml.v3" +) + +const ( + rootModulePath = "generated" + apiPackageName = "api" + harnessPackageName = "harness" + internalDepsDirName = "internaldeps" + wrappersDirName = "wrappers" + harnessCmdDir = "cmd/scp-harness" +) + +var ( + errEmptyLock = errors.New("generate: lock file is empty") + errNoSelectableEntries = errors.New("generate: no entries selected") + errRequestedEntryMissing = errors.New("generate: requested entry missing") + errLocalPathMissingLabel = errors.New("generate: local path entry missing path label") + errInvalidInputPath = errors.New("generate: invalid input path") + errUnsupportedModuleMode = errors.New("generate: unsupported module mode") +) + +type generationConfig struct { + Header string + APIName string + OutDir string + Entries []scp.LockEntry + FeatureSpec featureSpec + Repository string +} + +func Run(ctx context.Context, opts Options) error { + cfg, err := prepareConfig(opts) + if err != nil { + return err + } + + return executeGeneration(cfg) +} + +func validateOptions(opts Options) error { + if opts.LockPath == "" { + return errMissingLockPath + } + if opts.OutputDir == "" { + return errMissingOutputDir + } + if opts.APIName == "" { + return errMissingAPIName + } + if opts.ModuleMode != "" && opts.ModuleMode != DefaultModuleMode && opts.ModuleMode != ModuleModeRemote { + return fmt.Errorf("%w: %s", errUnsupportedModuleMode, opts.ModuleMode) + } + + return nil +} + +func prepareConfig(opts Options) (generationConfig, error) { + opts = opts.WithDefaults() + + if err := validateOptions(opts); err != nil { + return generationConfig{}, err + } + + lock, err := scp.ReadLock(opts.LockPath) + if err != nil { + return generationConfig{}, fmt.Errorf("generate: read lock: %w", err) + } + features, err := loadFeatureSpec(opts.FeaturesPath) + if err != nil { + return generationConfig{}, fmt.Errorf("generate: read features: %w", err) + } + header, err := buildHeader(opts.LicensePath) + if err != nil { + return generationConfig{}, fmt.Errorf("generate: build header: %w", err) + } + entries, err := selectEntries(lock, opts.OnlyNames) + if err != nil { + return generationConfig{}, err + } + + outDir, err := filepath.Abs(opts.OutputDir) + if err != nil { + return generationConfig{}, fmt.Errorf("generate: resolve output dir: %w", err) + } + + repo := "https://github.com/pion/sctp" + if lock.Metadata.Repository != "" { + repo = lock.Metadata.Repository + } + + return generationConfig{ + Header: header, + APIName: opts.APIName, + OutDir: outDir, + Entries: entries, + FeatureSpec: features, + Repository: repo, + }, nil +} + +func executeGeneration(cfg generationConfig) error { + if err := os.RemoveAll(cfg.OutDir); err != nil { + return fmt.Errorf("generate: clear output dir: %w", err) + } + if err := os.MkdirAll(cfg.OutDir, 0o750); err != nil { + return fmt.Errorf("generate: create output dir: %w", err) + } + + if err := writeRootModule(cfg.OutDir); err != nil { + return err + } + if err := writeAPIPackage(cfg.OutDir, cfg.Header, cfg.APIName); err != nil { + return err + } + + wrappers := make([]wrapperInfo, 0, len(cfg.Entries)) + for _, entry := range cfg.Entries { + depInfo, err := stageInternalDependency(cfg.OutDir, cfg.Repository, entry) + if err != nil { + return err + } + + wrapper, err := writeWrapper(cfg.OutDir, cfg.Header, entry, depInfo) + if err != nil { + return err + } + wrappers = append(wrappers, wrapper) + } + + featureMatrix := computeEntryFeatures(wrappers, cfg.FeatureSpec) + if err := writeHarness(cfg.OutDir, cfg.Header, wrappers, featureMatrix); err != nil { + return err + } + + return writeHarnessCommand(cfg.OutDir, cfg.Header) +} + +type wrapperInfo struct { + Entry scp.LockEntry + PackageName string + ImportPath string + FactoryAlias string +} + +type dependencyInfo struct { + ModulePath string + ImportPath string +} + +func selectEntries(lock *scp.Lockfile, only []string) ([]scp.LockEntry, error) { + if lock == nil { + return nil, errEmptyLock + } + + allow := normalizeAllowList(only) + var selected []scp.LockEntry + for _, entry := range lock.Entries { + if len(allow) > 0 { + if _, ok := allow[entry.Name]; !ok { + continue + } + } + selected = append(selected, entry) + } + + if err := ensureRequestedEntriesPresent(selected, allow); err != nil { + return nil, err + } + if len(selected) == 0 { + return nil, errNoSelectableEntries + } + + sort.Slice(selected, func(i, j int) bool { return selected[i].Name < selected[j].Name }) + + return selected, nil +} + +func writeRootModule(outDir string) error { + goModPath := filepath.Join(outDir, "go.mod") + content := "module " + rootModulePath + "\n\ngo 1.25\n\nrequire github.com/pion/transport v0.14.1\n" + data := []byte(content) + + return os.WriteFile(goModPath, data, 0o600) +} + +func writeAPIPackage(outDir, header, apiName string) error { + apiDir := filepath.Join(outDir, apiPackageName) + if err := os.MkdirAll(apiDir, 0o750); err != nil { + return fmt.Errorf("generate: create api dir: %w", err) + } + + pkg := fmt.Sprintf(`%spackage %s + +// Adapter represents a minimal SCTP adapter implementation. +type Adapter interface { + // Name returns a human readable adapter name. + Name() string +} + +// DialRequest contains placeholder dial parameters. +type DialRequest struct { + Remote string +} +`, header, apiName) + + return os.WriteFile(filepath.Join(apiDir, "sctp_api.go"), []byte(pkg), 0o600) +} + +func stageInternalDependency(outDir string, repo string, entry scp.LockEntry) (dependencyInfo, error) { + sourcePath, err := resolveSourcePath(repo, entry) + if err != nil { + return dependencyInfo{}, err + } + + depDir := filepath.Join(outDir, internalDepsDirName, entry.Name) + if err := copyTree(sourcePath, depDir); err != nil { + return dependencyInfo{}, fmt.Errorf("generate: copy %s: %w", entry.Name, err) + } + + modulePath := strings.Join([]string{rootModulePath, internalDepsDirName, entry.Name}, "/") + if err := rewriteModule(depDir, modulePath); err != nil { + return dependencyInfo{}, err + } + if err := rewriteImports(depDir, "github.com/pion/sctp", modulePath); err != nil { + return dependencyInfo{}, err + } + + return dependencyInfo{ModulePath: modulePath, ImportPath: modulePath}, nil +} + +func writeWrapper(outDir, header string, entry scp.LockEntry, dep dependencyInfo) (wrapperInfo, error) { + pkgName := sanitizePackage(entry.Name) + wrapperDir := filepath.Join(outDir, wrappersDirName, entry.Name) + if err := os.MkdirAll(wrapperDir, 0o750); err != nil { + return wrapperInfo{}, fmt.Errorf("generate: create wrapper dir: %w", err) + } + + content := fmt.Sprintf(`%spackage %s + +import ( + adapter %q + api %q +) + +// Adapter is a placeholder implementation for %s. +type Adapter struct{} + +// New returns a new Adapter instance. +func New() api.Adapter { + return &Adapter{} +} + +// Name returns the adapter identifier. +func (a *Adapter) Name() string { + return %q +} +`, header, pkgName, dep.ImportPath, rootModulePath+"/"+apiPackageName, entry.Name, entry.Name) + + if err := os.WriteFile(filepath.Join(wrapperDir, "adapter.go"), []byte(content), 0o600); err != nil { + return wrapperInfo{}, fmt.Errorf("generate: write wrapper for %s: %w", entry.Name, err) + } + + return wrapperInfo{ + Entry: entry, + PackageName: pkgName, + ImportPath: strings.Join([]string{rootModulePath, wrappersDirName, entry.Name}, "/"), + FactoryAlias: "wrapper_" + pkgName, + }, nil +} + +func writeHarness(outDir, header string, wrappers []wrapperInfo, featureMatrix map[string][]string) error { + harnessDir := filepath.Join(outDir, harnessPackageName) + if err := os.MkdirAll(harnessDir, 0o750); err != nil { + return fmt.Errorf("generate: create harness dir: %w", err) + } + + imports := make([]string, 0, len(wrappers)) + for _, wrapper := range wrappers { + imports = append(imports, fmt.Sprintf(" %s \"%s\"", wrapper.FactoryAlias, wrapper.ImportPath)) + } + sort.Strings(imports) + + registryEntries := make([]string, 0, len(wrappers)) + for _, wrapper := range wrappers { + registryEntries = append(registryEntries, fmt.Sprintf(" %q: %s.New,", wrapper.Entry.Name, wrapper.FactoryAlias)) + } + sort.Strings(registryEntries) + + matrixEntries := make([]string, 0, len(featureMatrix)) + for name, features := range featureMatrix { + matrixEntries = append(matrixEntries, fmt.Sprintf(" %q: %#v,", name, features)) + } + sort.Strings(matrixEntries) + + var builder strings.Builder + builder.WriteString(header) + builder.WriteString("package ") + builder.WriteString(harnessPackageName) + builder.WriteString("\n\nimport (\n") + builder.WriteString(strings.Join(imports, "\n")) + if len(imports) > 0 { + builder.WriteByte('\n') + } + builder.WriteString(" api \"") + builder.WriteString(rootModulePath + "/" + apiPackageName) + builder.WriteString("\"\n)\n\n") + builder.WriteString("type AdapterFactory func() api.Adapter\n\n") + builder.WriteString("var Registry = map[string]AdapterFactory{\n") + builder.WriteString(strings.Join(registryEntries, "\n")) + if len(registryEntries) > 0 { + builder.WriteByte('\n') + } + builder.WriteString("}\n\nvar EntryFeatures = map[string][]string{\n") + builder.WriteString(strings.Join(matrixEntries, "\n")) + if len(matrixEntries) > 0 { + builder.WriteByte('\n') + } + builder.WriteString("}\n") + + return os.WriteFile(filepath.Join(harnessDir, "registry.go"), []byte(builder.String()), 0o600) +} + +func writeHarnessCommand(outDir, header string) error { + cmdDir := filepath.Join(outDir, harnessCmdDir) + if err := os.MkdirAll(cmdDir, 0o750); err != nil { + return fmt.Errorf("generate: create harness cmd dir: %w", err) + } + + mainFile := fmt.Sprintf(`%spackage main + +import ( + "fmt" + + "github.com/pion/transport/vnet" + harness %q +) + +func main() { + router := vnet.NewRouter() + defer func() { + _ = router.Stop() + }() + + fmt.Println("registered adapters:") + for name := range harness.Registry { + fmt.Printf(" - %%s\n", name) + } +} +`, header, rootModulePath+"/"+harnessPackageName) + + return os.WriteFile(filepath.Join(cmdDir, "main.go"), []byte(mainFile), 0o600) +} + +func copyTree(src, dst string) error { + return filepath.WalkDir(src, func(path string, dirEntry fs.DirEntry, err error) error { + if err != nil { + return err + } + rel, err := filepath.Rel(src, path) + if err != nil { + return err + } + if rel == ".git" || strings.HasPrefix(rel, ".git/") { + if dirEntry.IsDir() { + return filepath.SkipDir + } + + return nil + } + target := filepath.Join(dst, rel) + if dirEntry.IsDir() { + return os.MkdirAll(target, 0o750) + } + if !dirEntry.Type().IsRegular() { + return nil + } + data, err := os.ReadFile(filepath.Clean(path)) + if err != nil { + return err + } + + return os.WriteFile(target, data, 0o600) + }) +} + +func rewriteModule(dir, modulePath string) error { + goModPath := filepath.Join(dir, "go.mod") + cleanPath := filepath.Clean(goModPath) + data, err := os.ReadFile(cleanPath) + if err != nil { + return fmt.Errorf("generate: read go.mod: %w", err) + } + file, err := modfile.Parse("go.mod", data, nil) + if err != nil { + return fmt.Errorf("generate: parse go.mod: %w", err) + } + + if file.Module == nil { + if addErr := file.AddModuleStmt(modulePath); addErr != nil { + return fmt.Errorf("generate: set module path: %w", addErr) + } + } else { + file.Module.Mod.Path = modulePath + } + if file.Go == nil { + if addErr := file.AddGoStmt("1.21"); addErr != nil { + return fmt.Errorf("generate: set go version: %w", addErr) + } + } else { + file.Go.Version = "1.21" + } + + newData, err := file.Format() + if err != nil { + return fmt.Errorf("generate: format go.mod: %w", err) + } + + return os.WriteFile(cleanPath, newData, 0o600) +} + +func rewriteImports(dir, oldPath, newPath string) error { + return filepath.WalkDir(dir, func(path string, dirEntry fs.DirEntry, err error) error { + if err != nil { + return err + } + if dirEntry.IsDir() { + return nil + } + if filepath.Ext(path) != ".go" { + return nil + } + + cleanPath := filepath.Clean(path) + data, err := os.ReadFile(cleanPath) + if err != nil { + return err + } + updated := strings.ReplaceAll(string(data), oldPath, newPath) + if updated == string(data) { + return nil + } + + return os.WriteFile(cleanPath, []byte(updated), 0o600) + }) +} + +func sanitizePackage(name string) string { + mapped := strings.Map(func(r rune) rune { + if r >= 'A' && r <= 'Z' { + return r - 'A' + 'a' + } + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { + return r + } + + return '_' + }, name) + + mapped = strings.Trim(mapped, "_") + if mapped == "" { + return "entry" + } + if mapped[0] >= '0' && mapped[0] <= '9' { + mapped = "x" + mapped + } + + return mapped +} + +func normalizeAllowList(values []string) map[string]struct{} { + if len(values) == 0 { + return nil + } + set := make(map[string]struct{}, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + set[value] = struct{}{} + } + + return set +} + +func ensureRequestedEntriesPresent(entries []scp.LockEntry, required map[string]struct{}) error { + if len(required) == 0 { + return nil + } + present := make(map[string]struct{}, len(entries)) + for _, entry := range entries { + present[entry.Name] = struct{}{} + } + for name := range required { + if _, ok := present[name]; !ok { + return fmt.Errorf("%w: %s", errRequestedEntryMissing, name) + } + } + + return nil +} + +func resolveSourcePath(repo string, entry scp.LockEntry) (string, error) { + if entry.Provenance == "local-path" { + if path, ok := entry.Labels["path"]; ok && path != "" { + return path, nil + } + + return "", fmt.Errorf("%w: %s", errLocalPathMissingLabel, entry.Name) + } + + cacheRoot, err := filepath.Abs(scp.DefaultCacheDir()) + if err != nil { + return "", fmt.Errorf("generate: resolve cache root: %w", err) + } + checkoutDir := filepath.Join( + cacheRoot, + "checkouts", + fmt.Sprintf("%s@%s", entry.Name, sanitizePathFragment(entry.Commit)), + ) + if _, statErr := os.Stat(checkoutDir); statErr != nil { + if !errors.Is(statErr, os.ErrNotExist) { + return "", fmt.Errorf("generate: stat checkout %s: %w", checkoutDir, statErr) + } + if err := cloneRevision(repo, entry.Commit, checkoutDir); err != nil { + return "", err + } + } + + return checkoutDir, nil +} + +func cloneRevision(repo, commit, dest string) error { + if err := os.MkdirAll(filepath.Dir(dest), 0o750); err != nil { + return fmt.Errorf("generate: prepare checkout: %w", err) + } + if _, err := os.Stat(dest); err == nil { + return nil + } + + clone := exec.Command("git", "clone", "--no-checkout", repo, dest) + clone.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0") + if output, err := clone.CombinedOutput(); err != nil { + return fmt.Errorf("generate: git clone %s: %w (output: %s)", repo, err, output) + } + + checkout := exec.Command("git", "-C", dest, "checkout", commit) + checkout.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0") + if output, err := checkout.CombinedOutput(); err != nil { + return fmt.Errorf("generate: git checkout %s: %w (output: %s)", commit, err, output) + } + + return nil +} + +func computeEntryFeatures(wrappers []wrapperInfo, spec featureSpec) map[string][]string { + since := parseFeatureVersions(spec.Features) + overrides := spec.Overrides + result := make(map[string][]string, len(wrappers)) + for _, wrapper := range wrappers { + result[wrapper.Entry.Name] = featuresForEntry(wrapper.Entry, since, overrides) + } + + return result +} + +func parseFeatureVersions(definitions []featureDefinition) map[string]*semver.Version { + result := make(map[string]*semver.Version, len(definitions)) + for _, definition := range definitions { + if definition.Since == "" { + continue + } + version, err := semver.NewVersion(normalizeSemver(definition.Since)) + if err != nil { + continue + } + result[definition.Name] = version + } + + return result +} + +func loadFeatureSpec(path string) (featureSpec, error) { + if path == "" { + return featureSpec{}, nil + } + data, err := readFileSafe(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return featureSpec{}, nil + } + + return featureSpec{}, err + } + + var spec featureSpec + if err := yaml.Unmarshal(data, &spec); err != nil { + return featureSpec{}, err + } + + return spec, nil +} + +func normalizeSemver(input string) string { + if strings.HasPrefix(input, "v") { + return input + } + + return "v" + input +} + +func featuresForEntry( + entry scp.LockEntry, + featureSince map[string]*semver.Version, + overrides map[string]featureOverride, +) []string { + enabled := map[string]struct{}{} + + if after, ok := strings.CutPrefix(entry.Selector, "tag:"); ok { + tag := after + version, err := semver.NewVersion(normalizeSemver(tag)) + if err == nil { + for name, since := range featureSince { + if !version.LessThan(since) { + enabled[name] = struct{}{} + } + } + } + } + + if override, ok := overrides[entry.Name]; ok { + for _, feat := range override.Enable { + enabled[feat] = struct{}{} + } + for _, feat := range override.Disable { + delete(enabled, feat) + } + } + + return setToSortedSlice(enabled) +} + +func setToSortedSlice(set map[string]struct{}) []string { + if len(set) == 0 { + return nil + } + items := make([]string, 0, len(set)) + for item := range set { + items = append(items, item) + } + sort.Strings(items) + + return items +} + +func readFileSafe(path string) ([]byte, error) { + cleaned := filepath.Clean(path) + if cleaned == "" || cleaned == "." { + return nil, fmt.Errorf("%w: %s", errInvalidInputPath, path) + } + if cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) { + return nil, fmt.Errorf("%w: %s", errInvalidInputPath, path) + } + + return os.ReadFile(cleaned) +} + +func buildHeader(licensePath string) (string, error) { + var buf strings.Builder + if licensePath != "" { + content, err := readFileSafe(licensePath) + if err != nil { + return "", err + } + lines := strings.Split(strings.TrimRight(string(content), "\n"), "\n") + for _, line := range lines { + buf.WriteString("// ") + buf.WriteString(line) + buf.WriteByte('\n') + } + if len(lines) > 0 { + buf.WriteByte('\n') + } + } + buf.WriteString("// Code generated by scp generate; DO NOT EDIT.\n\n") + + return buf.String(), nil +} + +func sanitizePathFragment(value string) string { + value = strings.ReplaceAll(value, string(filepath.Separator), "_") + value = strings.ReplaceAll(value, ":", "_") + if value == "" { + return "unknown" + } + + return value +} + +type featureSpec struct { + Schema int `yaml:"schema"` + Features []featureDefinition `yaml:"features"` + Overrides map[string]featureOverride `yaml:"overrides"` + Scenarios []scenarioDefinition `yaml:"scenarios"` +} + +type featureDefinition struct { + Name string `yaml:"name"` + Since string `yaml:"since"` +} + +type featureOverride struct { + Enable []string `yaml:"enable"` + Disable []string `yaml:"disable"` +} + +type scenarioDefinition struct { + ID string `yaml:"id"` + Requires []string `yaml:"requires"` +} diff --git a/internal/resolve/errors.go b/internal/resolve/errors.go new file mode 100644 index 0000000..33cfb67 --- /dev/null +++ b/internal/resolve/errors.go @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package resolve resolves SCTP references into manifest and lock entries. +package resolve + +import "errors" + +var ( + errNoRefs = errors.New("resolve: no refs specified") + errNoRefsAfterParsing = errors.New("resolve: no refs specified after parsing") + errDuplicateEntry = errors.New("resolve: duplicate entry name") + errUnsupportedType = errors.New("resolve: unsupported selector type") + errBranchNotFound = errors.New("resolve: branch not found") + errInvalidPRNumber = errors.New("resolve: invalid PR number") + errEmptyPathSelector = errors.New("resolve: empty path selector") + errRangeNoMatches = errors.New("resolve: no matching tags for range") +) diff --git a/internal/resolve/options.go b/internal/resolve/options.go new file mode 100644 index 0000000..1fb916a --- /dev/null +++ b/internal/resolve/options.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package resolve + +import "github.com/pion/scp/internal/scp" + +const DefaultRepository = "https://github.com/pion/sctp" + +type Options struct { + Refs []string + Repository string + CacheDir string + ManifestPath string + LockPath string + IncludePreRelease bool + FreezeAt string + AllowDirtyLocal bool +} + +func defaultOptions() Options { + return Options{ + Repository: DefaultRepository, + CacheDir: scp.DefaultCacheDir(), + ManifestPath: scp.DefaultManifestPath(), + LockPath: scp.DefaultLockPath(), + IncludePreRelease: false, + } +} + +func (o Options) WithDefaults() Options { + def := defaultOptions() + if o.Repository == "" { + o.Repository = def.Repository + } + if o.CacheDir == "" { + o.CacheDir = def.CacheDir + } + if o.ManifestPath == "" { + o.ManifestPath = def.ManifestPath + } + if o.LockPath == "" { + o.LockPath = def.LockPath + } + + return o +} diff --git a/internal/resolve/resolver.go b/internal/resolve/resolver.go new file mode 100644 index 0000000..0ecd515 --- /dev/null +++ b/internal/resolve/resolver.go @@ -0,0 +1,291 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package resolve + +import ( + "context" + "fmt" + "path/filepath" + "sort" + "strconv" + "strings" + + "github.com/Masterminds/semver/v3" + "github.com/pion/scp/internal/scp" +) + +type resolver struct { + opts Options + mirror *scp.Mirror +} + +func newResolver(opts Options, mirror *scp.Mirror) *resolver { + return &resolver{ + opts: opts, + mirror: mirror, + } +} + +func (r *resolver) ResolveAll(ctx context.Context, raws []string) ([]resolvedEntry, error) { + var results []resolvedEntry + seenNames := map[string]struct{}{} + + for _, raw := range raws { + sel, err := scp.ParseSelector(raw) + if err != nil { + return nil, err + } + + entries, err := r.resolveSelector(ctx, sel) + if err != nil { + return nil, fmt.Errorf("%s: %w", raw, err) + } + for _, ent := range entries { + if _, exists := seenNames[ent.Name]; exists { + return nil, fmt.Errorf("%w: %q derived from %q", errDuplicateEntry, ent.Name, ent.Selector.Raw) + } + seenNames[ent.Name] = struct{}{} + results = append(results, ent) + } + } + + return results, nil +} + +func (r *resolver) resolveSelector(ctx context.Context, sel scp.Selector) ([]resolvedEntry, error) { + switch sel.Type { + case scp.SelectorTag: + return r.resolveTag(ctx, sel) + case scp.SelectorBranch: + return r.resolveBranch(ctx, sel) + case scp.SelectorCommit: + return r.resolveCommit(ctx, sel) + case scp.SelectorPR: + return r.resolvePR(ctx, sel) + case scp.SelectorPath: + return r.resolvePath(ctx, sel) + case scp.SelectorRange: + return r.resolveRange(ctx, sel) + default: + return nil, fmt.Errorf("%w: %s", errUnsupportedType, sel.Type) + } +} + +func (r *resolver) resolveTag(ctx context.Context, sel scp.Selector) ([]resolvedEntry, error) { + ref := "refs/tags/" + sel.Value + info, err := r.mirror.ResolveRef(ctx, ref) + if err != nil { + return nil, err + } + name := scp.NameForSelector(sel.Raw, sel.Type, sel.Value, info.Object) + + return []resolvedEntry{{ + Name: name, + Selector: sel, + Commit: info.Object, + Provenance: "tag", + Labels: map[string]string{ + "selector": "tag", + "tag": sel.Value, + }, + }}, nil +} + +func (r *resolver) resolveBranch(ctx context.Context, sel scp.Selector) ([]resolvedEntry, error) { + names, err := r.matchBranches(ctx, sel.Value) + if err != nil { + return nil, err + } + if len(names) == 0 { + return nil, fmt.Errorf("%w: %s", errBranchNotFound, sel.Value) + } + sort.Strings(names) + var result []resolvedEntry + orig := sel.Value + hasGlob := strings.ContainsAny(orig, "*?[") + for _, name := range names { + info, err := r.mirror.ResolveRemoteBranchBefore(ctx, name, r.opts.FreezeAt) + if err != nil { + return nil, err + } + safeName := scp.NameForSelector(sel.Raw, scp.SelectorBranch, name, info.Object) + labels := map[string]string{ + "selector": "branch", + "branch": name, + } + if hasGlob { + labels["pattern"] = orig + } + result = append(result, resolvedEntry{ + Name: safeName, + Selector: selWithValue(sel, name), + Commit: info.Object, + Provenance: provenanceBranch(r.opts.FreezeAt), + Labels: labels, + }) + } + + return result, nil +} + +func (r *resolver) resolveCommit(ctx context.Context, sel scp.Selector) ([]resolvedEntry, error) { + sha, err := r.mirror.RevParse(ctx, sel.Value) + if err != nil { + return nil, err + } + name := scp.NameForSelector(sel.Raw, sel.Type, sel.Value, sha) + + return []resolvedEntry{{ + Name: name, + Selector: sel, + Commit: sha, + Provenance: "commit", + Labels: map[string]string{ + "selector": "commit", + }, + }}, nil +} + +func (r *resolver) resolvePR(ctx context.Context, sel scp.Selector) ([]resolvedEntry, error) { + num, err := strconv.Atoi(sel.Value) + if err != nil { + return nil, fmt.Errorf("%w: %s", errInvalidPRNumber, sel.Value) + } + info, err := r.mirror.ResolvePRHead(ctx, num) + if err != nil { + return nil, err + } + name := scp.NameForSelector(sel.Raw, sel.Type, sel.Value, info.Object) + + return []resolvedEntry{{ + Name: name, + Selector: sel, + Commit: info.Object, + Provenance: fmt.Sprintf("pr#%d", num), + Labels: map[string]string{ + "selector": "pr", + "pr": strconv.Itoa(num), + }, + }}, nil +} + +func (r *resolver) resolvePath(ctx context.Context, sel scp.Selector) ([]resolvedEntry, error) { + if sel.Value == "" { + return nil, errEmptyPathSelector + } + info, err := scp.InspectLocalPath(ctx, sel.Value, r.opts.AllowDirtyLocal) + if err != nil { + return nil, err + } + name := scp.NameForSelector(sel.Raw, sel.Type, sel.Value, info.Commit) + + return []resolvedEntry{{ + Name: name, + Selector: sel, + Commit: info.Commit, + Provenance: "local-path", + Labels: map[string]string{ + "selector": "path", + "path": info.DisplayPath, + "baseHead": info.BaseCommit, + }, + }}, nil +} + +func (r *resolver) resolveRange(ctx context.Context, sel scp.Selector) ([]resolvedEntry, error) { + value := sel.Value + rng, err := semver.NewConstraint(strings.TrimSpace(value)) + if err != nil { + return nil, fmt.Errorf("invalid semver range %q: %w", value, err) + } + + tags, err := r.mirror.ListTags(ctx) + if err != nil { + return nil, err + } + type tagged struct { + info scp.RefInfo + ver *semver.Version + } + var matches []tagged + for _, t := range tags { + v, err := semver.NewVersion(strings.TrimPrefix(t.Name, "v")) + if err != nil { + continue + } + if v.Prerelease() != "" && !r.opts.IncludePreRelease { + // skip prerelease tags unless explicitly allowed + continue + } + if rng.Check(v) { + matches = append(matches, tagged{info: t, ver: v}) + } + } + if len(matches) == 0 { + return nil, fmt.Errorf("%w: %s", errRangeNoMatches, value) + } + sort.Slice(matches, func(i, j int) bool { + return matches[i].ver.LessThan(matches[j].ver) + }) + + var results []resolvedEntry + for _, match := range matches { + tagSel := scp.Selector{ + Raw: "tag:" + match.info.Name, + Type: scp.SelectorTag, + Value: match.info.Name, + IsFloating: false, + } + name := scp.NameForSelector(tagSel.Raw, tagSel.Type, tagSel.Value, match.info.Object) + results = append(results, resolvedEntry{ + Name: name, + Selector: tagSel, + Commit: match.info.Object, + Provenance: "tag", + Labels: map[string]string{ + "selector": "tag", + "tag": match.info.Name, + }, + }) + } + + return results, nil +} + +func (r *resolver) matchBranches(ctx context.Context, pattern string) ([]string, error) { + if !strings.ContainsAny(pattern, "*?[") { + return []string{pattern}, nil + } + branches, err := r.mirror.ListBranches(ctx) + if err != nil { + return nil, err + } + var matches []string + for _, br := range branches { + ok, err := filepath.Match(pattern, br.Name) + if err != nil { + return nil, err + } + if ok { + matches = append(matches, br.Name) + } + } + + return matches, nil +} + +func provenanceBranch(freeze string) string { + if freeze == "" { + return "branch" + } + + return fmt.Sprintf("branch@%s", freeze) +} + +func selWithValue(sel scp.Selector, value string) scp.Selector { + sel.Value = value + sel.Raw = string(sel.Type) + ":" + value + + return sel +} diff --git a/internal/resolve/run.go b/internal/resolve/run.go new file mode 100644 index 0000000..4d404e2 --- /dev/null +++ b/internal/resolve/run.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package resolve + +import ( + "context" + + "github.com/pion/scp/internal/scp" +) + +func Run(ctx context.Context, opts Options) error { + opts = opts.WithDefaults() + if len(opts.Refs) == 0 { + return errNoRefs + } + + refs := scp.SplitAndTrim(opts.Refs) + if len(refs) == 0 { + return errNoRefsAfterParsing + } + + mirror, err := scp.EnsureMirror(ctx, opts.Repository, opts.CacheDir) + if err != nil { + return err + } + + resolver := newResolver(opts, mirror) + resolved, err := resolver.ResolveAll(ctx, refs) + if err != nil { + return err + } + + manifest := &scp.Manifest{ + Schema: 2, + Repo: opts.Repository, + Entries: make([]scp.ManifestEntry, 0, len(resolved)), + } + + lock := &scp.Lockfile{ + Schema: 2, + Metadata: scp.LockMetadata{ + Repository: opts.Repository, + }, + Entries: make([]scp.LockEntry, 0, len(resolved)), + } + + for _, entry := range resolved { + manifest.Entries = append(manifest.Entries, scp.ManifestEntry{ + Name: entry.Name, + Selector: entry.Selector.Raw, + }) + lock.Entries = append(lock.Entries, scp.LockEntry{ + Name: entry.Name, + Selector: entry.Selector.Raw, + Commit: entry.Commit, + Provenance: entry.Provenance, + Labels: entry.Labels, + }) + } + + if err := scp.WriteManifest(opts.ManifestPath, manifest); err != nil { + return err + } + if err := scp.WriteLock(opts.LockPath, lock); err != nil { + return err + } + + return nil +} + +type resolvedEntry struct { + Name string + Selector scp.Selector + Commit string + Provenance string + Labels map[string]string +} diff --git a/internal/scp/git_mirror.go b/internal/scp/git_mirror.go new file mode 100644 index 0000000..f29e181 --- /dev/null +++ b/internal/scp/git_mirror.go @@ -0,0 +1,138 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +type Mirror struct { + URL string + Path string +} + +func EnsureMirror(ctx context.Context, repoURL, cacheDir string) (*Mirror, error) { + cacheAbs, err := filepath.Abs(cacheDir) + if err != nil { + return nil, fmt.Errorf("git cache abs: %w", err) + } + mirrorPath := filepath.Join(cacheAbs, mirrorDirName(repoURL)) + if err := os.MkdirAll(cacheAbs, 0o750); err != nil { + return nil, fmt.Errorf("git cache: %w", err) + } + + if _, err := os.Stat(mirrorPath); err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("git mirror stat: %w", err) + } + if err := runGit(ctx, cacheAbs, "clone", "--mirror", repoURL, mirrorPath); err != nil { + return nil, fmt.Errorf("git clone mirror: %w", err) + } + } else { + if err := runGit(ctx, mirrorPath, "remote", "update", "--prune"); err != nil { + return nil, fmt.Errorf("git remote update: %w", err) + } + } + + return &Mirror{ + URL: repoURL, + Path: mirrorPath, + }, nil +} + +func (m *Mirror) RevParse(ctx context.Context, rev string) (string, error) { + out, err := runGitStdout(ctx, m.Path, "rev-parse", rev) + if err != nil { + return "", err + } + + return strings.TrimSpace(out), nil +} + +func (m *Mirror) ForEachRef(ctx context.Context, pattern string, format string) ([]string, error) { + args := []string{"for-each-ref"} + if pattern != "" { + args = append(args, pattern) + } + if format != "" { + args = append(args, "--format="+format) + } + out, err := runGitStdout(ctx, m.Path, args...) + if err != nil { + return nil, err + } + lines := strings.Split(strings.TrimSpace(out), "\n") + if len(lines) == 1 && lines[0] == "" { + return nil, nil + } + + return lines, nil +} + +func (m *Mirror) ResolveBefore(ctx context.Context, revPattern string, before string) (string, error) { + if before == "" { + return m.RevParse(ctx, revPattern) + } + out, err := runGitStdout(ctx, m.Path, "rev-list", "-n", "1", "--before="+before, revPattern) + if err != nil { + return "", err + } + + return strings.TrimSpace(out), nil +} + +func mirrorDirName(repoURL string) string { + safe := repoURL + safe = strings.ReplaceAll(safe, "://", "_") + safe = strings.ReplaceAll(safe, "/", "_") + safe = strings.ReplaceAll(safe, "@", "_") + + return safe + ".git" +} + +func runGit(ctx context.Context, dir string, args ...string) error { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return gitCommandError(args, err, stderr.String()) + } + + return nil +} + +func runGitStdout(ctx context.Context, dir string, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return "", gitCommandError(args, err, stderr.String()) + } + + return stdout.String(), nil +} + +var errGitCommand = errors.New("git command failed") + +func gitCommandError(args []string, cmdErr error, stderr string) error { + summary := strings.Join(args, " ") + + wrapped := fmt.Errorf("git %s: %w", summary, cmdErr) + stderr = strings.TrimSpace(stderr) + if stderr != "" { + wrapped = fmt.Errorf("%w (%s)", wrapped, stderr) + } + + return errors.Join(errGitCommand, wrapped) +} diff --git a/internal/scp/git_refs.go b/internal/scp/git_refs.go new file mode 100644 index 0000000..a986a64 --- /dev/null +++ b/internal/scp/git_refs.go @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "context" + "errors" + "fmt" + "strings" +) + +type RefInfo struct { + Name string + Object string +} + +var errMalformedRefLine = errors.New("git reference line malformed") + +func (m *Mirror) ListBranches(ctx context.Context) ([]RefInfo, error) { + lines, err := m.ForEachRef(ctx, "refs/heads", "%(refname:strip=2) %(objectname)") + if err != nil { + return nil, err + } + var infos []RefInfo + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + parts := strings.Fields(line) + if len(parts) < 2 { + return nil, fmt.Errorf("%w: %q", errMalformedRefLine, line) + } + infos = append(infos, RefInfo{ + Name: parts[0], + Object: parts[1], + }) + } + + return infos, nil +} + +func (m *Mirror) ResolveBranch(ctx context.Context, name string) (RefInfo, error) { + return m.ResolveRef(ctx, "refs/heads/"+name) +} + +func (m *Mirror) ResolveRemoteBranch(ctx context.Context, name string) (RefInfo, error) { + // Mirrors have local refs/heads entries equivalent to origin/. + return m.ResolveBranch(ctx, name) +} + +func (m *Mirror) ResolveRemoteBranchBefore(ctx context.Context, name string, before string) (RefInfo, error) { + sha, err := m.ResolveBefore(ctx, "refs/heads/"+name, before) + if err != nil { + return RefInfo{}, err + } + + return RefInfo{Name: name, Object: sha}, nil +} + +func (m *Mirror) ResolvePRHead(ctx context.Context, number int) (RefInfo, error) { + ref := fmt.Sprintf("refs/pull/%d/head", number) + + return m.ResolveRef(ctx, ref) +} diff --git a/internal/scp/git_tags.go b/internal/scp/git_tags.go new file mode 100644 index 0000000..f145a98 --- /dev/null +++ b/internal/scp/git_tags.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "context" + "fmt" + "strings" +) + +func (m *Mirror) ListTags(ctx context.Context) ([]RefInfo, error) { + lines, err := m.ForEachRef(ctx, "refs/tags", "%(refname:strip=2) %(target)") + if err != nil { + return nil, err + } + var infos []RefInfo + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + parts := strings.Fields(line) + if len(parts) < 2 { + return nil, fmt.Errorf("%w: %q", errMalformedRefLine, line) + } + infos = append(infos, RefInfo{ + Name: parts[0], + Object: parts[1], + }) + } + + return infos, nil +} + +func (m *Mirror) ResolveRef(ctx context.Context, ref string) (RefInfo, error) { + sha, err := m.RevParse(ctx, ref) + if err != nil { + return RefInfo{}, err + } + + return RefInfo{Name: ref, Object: sha}, nil +} diff --git a/internal/scp/jsonio.go b/internal/scp/jsonio.go new file mode 100644 index 0000000..24a9b84 --- /dev/null +++ b/internal/scp/jsonio.go @@ -0,0 +1,180 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +var ( + errPathOutsideWorkspace = errors.New("state path escapes working directory") + errEmptyStatePath = errors.New("state path empty") +) + +func ReadManifest(path string) (*Manifest, error) { + safePath, err := cleanStatePath(path) + if err != nil { + return nil, err + } + + f, err := os.Open(safePath) + if err != nil { + return nil, fmt.Errorf("read manifest: %w", err) + } + defer func() { + _ = f.Close() + }() + + var m Manifest + if err := json.NewDecoder(f).Decode(&m); err != nil { + return nil, fmt.Errorf("parse manifest: %w", err) + } + + return &m, nil +} + +func WriteManifest(path string, m *Manifest) error { + safePath, err := cleanStatePath(path) + if err != nil { + return err + } + + if parentErr := makeParent(safePath); parentErr != nil { + return parentErr + } + + return writeJSON(safePath, m) +} + +func ReadLock(path string) (*Lockfile, error) { + safePath, err := cleanStatePath(path) + if err != nil { + return nil, err + } + + f, err := os.Open(safePath) + if err != nil { + return nil, fmt.Errorf("read lock: %w", err) + } + defer func() { + _ = f.Close() + }() + + var l Lockfile + if err := json.NewDecoder(f).Decode(&l); err != nil { + return nil, fmt.Errorf("parse lock: %w", err) + } + + return &l, nil +} + +func WriteLock(path string, l *Lockfile) error { + safePath, err := cleanStatePath(path) + if err != nil { + return err + } + + if parentErr := makeParent(safePath); parentErr != nil { + return parentErr + } + + return writeJSON(safePath, l) +} + +func CopyJSON(dst string, r io.Reader) error { + safePath, err := cleanStatePath(dst) + if err != nil { + return err + } + + if parentErr := makeParent(safePath); parentErr != nil { + return parentErr + } + + f, err := openWritableFile(safePath, 0o640) + if err != nil { + return fmt.Errorf("create %s: %w", safePath, err) + } + defer func() { + _ = f.Close() + }() + + if _, err := io.Copy(f, r); err != nil { + return fmt.Errorf("write %s: %w", safePath, err) + } + + return nil +} + +func makeParent(path string) error { + dir := filepath.Dir(path) + if dir == "." || dir == "" { + return nil + } + + return os.MkdirAll(dir, 0o750) +} + +func writeJSON(path string, v any) error { + tmp := path + ".tmp" + tmpFile, err := openWritableFile(tmp, 0o640) + if err != nil { + return fmt.Errorf("create temp: %w", err) + } + + enc := json.NewEncoder(tmpFile) + enc.SetIndent("", " ") + enc.SetEscapeHTML(false) + if err := enc.Encode(v); err != nil { + closeErr := tmpFile.Close() + removeErr := os.Remove(tmp) + + combined := fmt.Errorf("encode json: %w", err) + if closeErr != nil { + combined = errors.Join(combined, fmt.Errorf("close temp file: %w", closeErr)) + } + if removeErr != nil { + combined = errors.Join(combined, fmt.Errorf("remove temp file: %w", removeErr)) + } + + return combined + } + + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("close temp: %w", err) + } + + if err := os.Rename(tmp, path); err != nil { + return fmt.Errorf("rename temp: %w", err) + } + + return nil +} + +func cleanStatePath(path string) (string, error) { + cleaned := filepath.Clean(path) + if cleaned == "" { + return "", errEmptyStatePath + } + + if filepath.IsAbs(cleaned) { + return cleaned, nil + } + + if cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) { + return "", errPathOutsideWorkspace + } + + return cleaned, nil +} + +func openWritableFile(path string, perm os.FileMode) (*os.File, error) { + return os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) +} diff --git a/internal/scp/jsonio_test.go b/internal/scp/jsonio_test.go new file mode 100644 index 0000000..7d223a6 --- /dev/null +++ b/internal/scp/jsonio_test.go @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "encoding/json" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCleanStatePath(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + abs := filepath.Join(tmpDir, "manifest.json") + + tests := []struct { + name string + input string + want string + wantErr error + }{ + {"Relative", "state/lock.json", filepath.Clean("state/lock.json"), nil}, + {"Absolute", abs, abs, nil}, + {"ParentTraversal", "../lock.json", "", errPathOutsideWorkspace}, + {"Empty", "", ".", nil}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := cleanStatePath(tc.input) + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + + return + } + + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestManifestRoundTrip(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "manifest.json") + + manifest := &Manifest{ + Schema: 2, + Repo: "https://example.com/repo.git", + Entries: []ManifestEntry{ + {Name: "v1", Selector: "tag:v1.0.0"}, + }, + } + + require.NoError(t, WriteManifest(path, manifest)) + + readManifest, err := ReadManifest(path) + require.NoError(t, err) + require.Equal(t, manifest, readManifest) +} + +func TestLockfileRoundTrip(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "lock.json") + + lock := &Lockfile{ + Schema: 2, + Entries: []LockEntry{ + { + Name: "v1", + Selector: "tag:v1.0.0", + Commit: "abc123", + Provenance: "tag", + }, + }, + } + + require.NoError(t, WriteLock(path, lock)) + + got, err := ReadLock(path) + require.NoError(t, err) + require.Equal(t, lock, got) +} + +func TestCopyJSON(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + dst := filepath.Join(dir, "copy.json") + + data := `{"hello":"world"}` + require.NoError(t, CopyJSON(dst, strings.NewReader(data))) + + contentFile, err := os.Open(dst) + require.NoError(t, err) + defer func() { + _ = contentFile.Close() + }() + + content, err := io.ReadAll(contentFile) + require.NoError(t, err) + require.Equal(t, data, string(content)) + + var parsed map[string]string + require.NoError(t, json.Unmarshal(content, &parsed)) + require.Equal(t, "world", parsed["hello"]) +} diff --git a/internal/scp/local.go b/internal/scp/local.go new file mode 100644 index 0000000..24b7d7b --- /dev/null +++ b/internal/scp/local.go @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +type LocalInfo struct { + Commit string + BaseCommit string + DisplayPath string +} + +var ( + errLocalDirty = errors.New("local path has uncommitted changes") + errLocalNotRepo = errors.New("path is not a git repository") +) + +func InspectLocalPath(ctx context.Context, path string, allowDirty bool) (LocalInfo, error) { + abs, err := filepath.Abs(path) + if err != nil { + return LocalInfo{}, fmt.Errorf("local path: %w", err) + } + if _, statErr := os.Stat(abs); statErr != nil { + return LocalInfo{}, fmt.Errorf("local path: %w", statErr) + } + + if repoErr := ensureGitRepo(ctx, abs); repoErr != nil { + return LocalInfo{}, repoErr + } + + status, err := gitStatusPorcelain(ctx, abs) + if err != nil { + return LocalInfo{}, err + } + + isDirty := len(status) > 0 + if isDirty && !allowDirty { + return LocalInfo{}, fmt.Errorf("%w: %s", errLocalDirty, abs) + } + + head, err := gitRevParse(ctx, abs, "HEAD") + if err != nil { + return LocalInfo{}, err + } + + commit := head + if isDirty { + diff, err := gitDiff(ctx, abs) + if err != nil { + return LocalInfo{}, err + } + sum := sha256.Sum256(diff) + commit = "dirty:" + hex.EncodeToString(sum[:4]) + } + + return LocalInfo{ + Commit: commit, + BaseCommit: head, + DisplayPath: abs, + }, nil +} + +func ensureGitRepo(ctx context.Context, dir string) error { + _, err := gitCommand(ctx, dir, "rev-parse", "--is-inside-work-tree") + if err != nil { + return fmt.Errorf("%w: %s", errLocalNotRepo, dir) + } + + return nil +} + +func gitStatusPorcelain(ctx context.Context, dir string) (string, error) { + out, err := gitCommand(ctx, dir, "status", "--porcelain") + + return strings.TrimSpace(out), err +} + +func gitRevParse(ctx context.Context, dir string, rev string) (string, error) { + out, err := gitCommand(ctx, dir, "rev-parse", rev) + + return strings.TrimSpace(out), err +} + +func gitDiff(ctx context.Context, dir string) ([]byte, error) { + out, err := gitCommandBytes(ctx, dir, "diff") + + return out, err +} + +func gitCommand(ctx context.Context, dir string, args ...string) (string, error) { + buf, err := gitCommandBytes(ctx, dir, args...) + + return string(buf), err +} + +func gitCommandBytes(ctx context.Context, dir string, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return nil, gitCommandError(args, err, stderr.String()) + } + + return stdout.Bytes(), nil +} diff --git a/internal/scp/lockfile.go b/internal/scp/lockfile.go new file mode 100644 index 0000000..a44658d --- /dev/null +++ b/internal/scp/lockfile.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +type Lockfile struct { + Schema int `json:"schema"` + GeneratedAt string `json:"generatedAt"` + Entries []LockEntry `json:"entries"` + Version string `json:"version,omitempty"` + Metadata LockMetadata `json:"metadata,omitempty"` +} + +type LockEntry struct { + Name string `json:"name"` + Selector string `json:"selector"` + Commit string `json:"commit"` + Provenance string `json:"provenance"` + Labels map[string]string `json:"labels,omitempty"` +} + +type LockMetadata struct { + Repository string `json:"repository,omitempty"` +} diff --git a/internal/scp/manifest.go b/internal/scp/manifest.go new file mode 100644 index 0000000..9b1a350 --- /dev/null +++ b/internal/scp/manifest.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +type Manifest struct { + Schema int `json:"schema"` + Repo string `json:"repo"` + Entries []ManifestEntry `json:"entries"` +} + +type ManifestEntry struct { + Name string `json:"name"` + Selector string `json:"selector"` +} diff --git a/internal/scp/naming.go b/internal/scp/naming.go new file mode 100644 index 0000000..486e640 --- /dev/null +++ b/internal/scp/naming.go @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "regexp" + "strings" +) + +var ( + nonAlnum = regexp.MustCompile(`[^a-zA-Z0-9]+`) + suffixAllowed = regexp.MustCompile(`[a-f0-9]+`) +) + +func Slugify(input string) string { + if input == "" { + return "entry" + } + s := nonAlnum.ReplaceAllString(input, "_") + s = strings.Trim(s, "_") + s = strings.ToLower(s) + if s == "" { + return "entry" + } + + return s +} + +func WithSuffix(base, sha string) string { + if len(sha) > 7 { + sha = sha[:7] + } + sha = sanitizeSuffix(sha) + base = Slugify(base) + if sha == "" { + return base + } + if strings.HasSuffix(base, sha) { + return base + } + + return base + "_" + sha +} + +func NameForSelector(raw string, selType SelectorType, value string, commit string) string { + switch selType { + case SelectorTag: + return Slugify(value) + case SelectorBranch: + return WithSuffix("branch_"+Slugify(value), commit) + case SelectorPR: + return WithSuffix("pr_"+Slugify(value), commit) + case SelectorCommit: + return WithSuffix("sha", commit) + case SelectorPath: + return WithSuffix("local_"+Slugify(value), commit) + case SelectorRange: + return WithSuffix(Slugify(value), commit) + default: + return WithSuffix(Slugify(raw), commit) + } +} + +func sanitizeSuffix(sha string) string { + sha = strings.ToLower(sha) + if matches := suffixAllowed.FindAllString(sha, -1); len(matches) > 0 { + return matches[0] + } + + return "" +} diff --git a/internal/scp/naming_test.go b/internal/scp/naming_test.go new file mode 100644 index 0000000..1d25e4c --- /dev/null +++ b/internal/scp/naming_test.go @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSlugify(t *testing.T) { + t.Parallel() + + tests := map[string]string{ + "": "entry", + "Hello World!": "hello_world", + " multiple ": "multiple", + "***": "entry", + "MiXeD-Case123": "mixed_case123", + } + + for input, want := range tests { + got := Slugify(input) + require.Equal(t, want, got, "Slugify(%q)", input) + } +} + +func TestWithSuffix(t *testing.T) { + t.Parallel() + + got := WithSuffix("Feature", "ABCDEF123456") + want := "feature_abcdef1" + require.Equal(t, want, got) + + got = WithSuffix("already_suffix_abcdef1", "abcdef1234") + require.Equal(t, "already_suffix_abcdef1", got) +} + +func TestNameForSelector(t *testing.T) { + t.Parallel() + + tests := []struct { + desc string + raw string + typ SelectorType + value string + commit string + want string + }{ + {"Tag", "tag:v1.0.0", SelectorTag, "v1.0.0", "abc", "v1_0_0"}, + {"Branch", "branch:main", SelectorBranch, "main", "abcdef1", "branch_main_abcdef1"}, + {"PR", "pr:42", SelectorPR, "42", "ff00ee", "pr_42_ff00ee"}, + {"Commit", "commit:facefeed", SelectorCommit, "facefeed", "facefeed", "sha_facefee"}, + {"Path", "path:/tmp", SelectorPath, "/tmp", "1234567", "local_tmp_1234567"}, + {"Default", "custom:thing", SelectorType("custom"), "thing", "deadbeef", "custom_thing_deadbee"}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + got := NameForSelector(tc.raw, tc.typ, tc.value, tc.commit) + require.Equal(t, tc.want, got, "NameForSelector(%q, %v, %q, %q)", + tc.raw, tc.typ, tc.value, tc.commit) + }) + } +} diff --git a/internal/scp/paths.go b/internal/scp/paths.go new file mode 100644 index 0000000..1332bd0 --- /dev/null +++ b/internal/scp/paths.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package scp contains shared helper utilities for the scp CLI tools. +package scp + +import "path/filepath" + +const ( + DefaultStateDir = ".scp" + DefaultManifestFile = "manifest.json" + DefaultLockFile = "lock.json" + DefaultCacheDirName = "cache" + DefaultFeaturesFile = "features.yaml" + DefaultOutputDirName = "generated" +) + +func DefaultManifestPath() string { + return filepath.Join(DefaultStateDir, DefaultManifestFile) +} + +func DefaultLockPath() string { + return filepath.Join(DefaultStateDir, DefaultLockFile) +} + +func DefaultCacheDir() string { + return filepath.Join(DefaultStateDir, DefaultCacheDirName) +} + +func DefaultOutputDir() string { + return DefaultOutputDirName +} + +func DefaultFeaturesPath() string { + return DefaultFeaturesFile +} diff --git a/internal/scp/selectors.go b/internal/scp/selectors.go new file mode 100644 index 0000000..f5f5b02 --- /dev/null +++ b/internal/scp/selectors.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "errors" + "fmt" + "path/filepath" + "strings" +) + +type SelectorType string + +const ( + SelectorTag SelectorType = "tag" + SelectorBranch SelectorType = "branch" + SelectorCommit SelectorType = "commit" + SelectorPR SelectorType = "pr" + SelectorPath SelectorType = "path" + SelectorRange SelectorType = "range" +) + +type Selector struct { + Raw string + Type SelectorType + Value string + Flags map[string]string + IsFloating bool +} + +var ( + errSelectorEmpty = errors.New("selector empty") + errSelectorMissingType = errors.New("selector missing type") + errSelectorUnsupported = errors.New("selector type unsupported") +) + +func ParseSelector(raw string) (Selector, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return Selector{}, errSelectorEmpty + } + + prefix, value, err := splitSelector(trimmed) + if err != nil { + return Selector{}, err + } + + sel := Selector{ + Raw: trimmed, + Type: SelectorType(strings.ToLower(prefix)), + Value: value, + Flags: map[string]string{}, + IsFloating: false, + } + + if err := normaliseSelector(&sel); err != nil { + return Selector{}, err + } + + return sel, nil +} + +func splitSelector(raw string) (prefix string, value string, err error) { + colon := strings.IndexByte(raw, ':') + if colon <= 0 { + return "", "", fmt.Errorf("%w: %q", errSelectorMissingType, raw) + } + + return raw[:colon], raw[colon+1:], nil +} + +func normaliseSelector(sel *Selector) error { + switch sel.Type { + case SelectorTag: + case SelectorBranch: + sel.IsFloating = true + case SelectorCommit: + case SelectorPR: + sel.IsFloating = true + case SelectorPath: + cleaned := strings.TrimSpace(sel.Value) + if cleaned != "" { + if !filepath.IsAbs(cleaned) { + cleaned = filepath.Join(".", cleaned) + } + sel.Value = filepath.Clean(cleaned) + } + case SelectorRange: + sel.IsFloating = true + default: + return fmt.Errorf("%w: %s", errSelectorUnsupported, sel.Type) + } + + return nil +} diff --git a/internal/scp/selectors_test.go b/internal/scp/selectors_test.go new file mode 100644 index 0000000..49b002c --- /dev/null +++ b/internal/scp/selectors_test.go @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseSelector(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + wantType SelectorType + wantValue string + wantFloating bool + wantErr error + }{ + { + name: "Tag", + raw: "tag:v1.2.3", + wantType: SelectorTag, + wantValue: "v1.2.3", + wantFloating: false, + }, + { + name: "BranchFloating", + raw: "branch:main", + wantType: SelectorBranch, + wantValue: "main", + wantFloating: true, + }, + { + name: "Commit", + raw: "commit:abc123", + wantType: SelectorCommit, + wantValue: "abc123", + }, + { + name: "PullRequest", + raw: "pr:42", + wantType: SelectorPR, + wantValue: "42", + wantFloating: true, + }, + { + name: "PathRelative", + raw: "path:foo/bar", + wantType: SelectorPath, + wantValue: filepath.Join(".", "foo", "bar"), + }, + { + name: "RangeFloating", + raw: "range:>=1.0.0 <1.1.0", + wantType: SelectorRange, + wantValue: ">=1.0.0 <1.1.0", + wantFloating: true, + }, + { + name: "Empty", + raw: " ", + wantErr: errSelectorEmpty, + }, + { + name: "MissingType", + raw: "no-prefix", + wantErr: errSelectorMissingType, + }, + { + name: "Unsupported", + raw: "foo:bar", + wantErr: errSelectorUnsupported, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + sel, err := ParseSelector(tc.raw) + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + + return + } + + require.NoError(t, err) + require.Equal(t, tc.wantType, sel.Type) + require.Equal(t, tc.wantValue, sel.Value) + require.Equal(t, tc.wantFloating, sel.IsFloating) + require.Equal(t, tc.raw, sel.Raw) + }) + } +} diff --git a/internal/scp/slice.go b/internal/scp/slice.go new file mode 100644 index 0000000..35359a4 --- /dev/null +++ b/internal/scp/slice.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import "strings" + +func SplitAndTrim(fields []string) []string { + var out []string + for _, field := range fields { + for _, part := range strings.Split(field, ",") { + if trimmed := strings.TrimSpace(part); trimmed != "" { + out = append(out, trimmed) + } + } + } + + return out +} diff --git a/internal/scp/slice_test.go b/internal/scp/slice_test.go new file mode 100644 index 0000000..213d3ce --- /dev/null +++ b/internal/scp/slice_test.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package scp + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSplitAndTrim(t *testing.T) { + t.Parallel() + + input := []string{" alpha ,beta", "gamma,,", " delta "} + got := SplitAndTrim(input) + want := []string{"alpha", "beta", "gamma", "delta"} + + require.Equal(t, want, got) +} diff --git a/internal/testcmd/cases.go b/internal/testcmd/cases.go new file mode 100644 index 0000000..907e09f --- /dev/null +++ b/internal/testcmd/cases.go @@ -0,0 +1,130 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "context" + "fmt" + "strings" +) + +const ( + caseMaxBurst = "max-burst" + caseHandshake = "handshake" + caseUnorderedLowRTT = "unordered-late-low-rtt" + caseUnorderedHighRTT = "unordered-late-high-rtt" + caseUnorderedDynamicRTT = "unordered-late-dynamic-rtt" + caseCongestionRack = "congestion" + caseRetransmission = "retransmission" + caseRackReorderLow = "reorder-low" + caseRackReorderHigh = "reorder-high" + caseRackBurstLoss = "burst-loss" +) + +type scenarioResult struct { + CaseName string + Pair pair + ForwardBurst int + ReverseBurst int + Passed bool + Errored bool + Details string + Iteration int + Metrics resultMetrics +} + +func runCases(ctx context.Context, caseNames []string, pairs []pair, seed int64, repeat int) ([]scenarioResult, error) { + names := normalizeCases(caseNames) + if len(names) == 0 { + return nil, errNoCases + } + + var results []scenarioResult + for _, name := range names { + switch name { + case caseMaxBurst: + res, err := runMaxBurstCase(ctx, pairs, seed, repeat) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseHandshake: + res, err := runHandshakeCase(ctx, pairs, seed, repeat) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseUnorderedLowRTT: + res, err := runUnorderedCase(ctx, pairs, seed, repeat, lowRTTProfile()) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseUnorderedHighRTT: + res, err := runUnorderedCase(ctx, pairs, seed, repeat, highRTTProfile()) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseUnorderedDynamicRTT: + res, err := runUnorderedCase(ctx, pairs, seed, repeat, dynamicRTTProfile()) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseCongestionRack: + res, err := runCongestionCase(ctx, pairs, seed, repeat) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseRetransmission: + res, err := runRetransmissionCase(ctx, pairs, seed, repeat) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseRackReorderLow: + res, err := runUnorderedCase(ctx, pairs, seed, repeat, rackReorderLowProfile()) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseRackReorderHigh: + res, err := runUnorderedCase(ctx, pairs, seed, repeat, rackReorderHighProfile()) + if err != nil { + return nil, err + } + results = append(results, res...) + case caseRackBurstLoss: + res, err := runUnorderedCase(ctx, pairs, seed, repeat, rackBurstLossProfile()) + if err != nil { + return nil, err + } + results = append(results, res...) + default: + return nil, fmt.Errorf("%w: %s", errUnknownCase, name) + } + } + + return results, nil +} + +func normalizeCases(names []string) []string { + seen := make(map[string]struct{}, len(names)) + var ordered []string + for _, name := range names { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + continue + } + if _, exists := seen[trimmed]; exists { + continue + } + seen[trimmed] = struct{}{} + ordered = append(ordered, trimmed) + } + + return ordered +} diff --git a/internal/testcmd/cases_extra.go b/internal/testcmd/cases_extra.go new file mode 100644 index 0000000..b549817 --- /dev/null +++ b/internal/testcmd/cases_extra.go @@ -0,0 +1,140 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "context" + "time" +) + +type networkProfile struct { + MinDelay time.Duration + MaxJitter time.Duration + DropPercent float64 + Unordered bool + Name string +} + +func lowRTTProfile() networkProfile { + return networkProfile{ + MinDelay: 10 * time.Millisecond, + MaxJitter: 10 * time.Millisecond, + DropPercent: 0.0, + Unordered: true, + Name: "low-rtt-late", + } +} + +func highRTTProfile() networkProfile { + return networkProfile{ + MinDelay: 180 * time.Millisecond, + MaxJitter: 60 * time.Millisecond, + DropPercent: 0.0, + Unordered: true, + Name: "high-rtt-late", + } +} + +func dynamicRTTProfile() networkProfile { + return networkProfile{ + MinDelay: 40 * time.Millisecond, + MaxJitter: 180 * time.Millisecond, + DropPercent: 0.0, + Unordered: true, + Name: "dynamic-rtt-late", + } +} + +func congestionProfile() networkProfile { + return networkProfile{ + MinDelay: 60 * time.Millisecond, + MaxJitter: 40 * time.Millisecond, + DropPercent: 0.02, + Unordered: false, + Name: "congestion", + } +} + +func lossProfile() networkProfile { + return networkProfile{ + MinDelay: 40 * time.Millisecond, + MaxJitter: 20 * time.Millisecond, + DropPercent: 0.05, + Unordered: false, + Name: "retransmission", + } +} + +func rackReorderLowProfile() networkProfile { + return networkProfile{ + MinDelay: 15 * time.Millisecond, + MaxJitter: 25 * time.Millisecond, + DropPercent: 1.5, // light loss with reordering + Unordered: true, + Name: "reorder-low", + } +} + +func rackReorderHighProfile() networkProfile { + return networkProfile{ + MinDelay: 140 * time.Millisecond, + MaxJitter: 120 * time.Millisecond, + DropPercent: 2.5, + Unordered: true, + Name: "reorder-high", + } +} + +func rackBurstLossProfile() networkProfile { + return networkProfile{ + MinDelay: 50 * time.Millisecond, + MaxJitter: 50 * time.Millisecond, + DropPercent: 4.0, + Unordered: true, + Name: "burst-loss", + } +} + +func runHandshakeCase(ctx context.Context, pairs []pair, seed int64, repeat int) ([]scenarioResult, error) { + // handshake is effectively an unordered case with no data; we reuse unordered runner with a tiny payload count. + return runUnorderedCase(ctx, pairs, seed, repeat, networkProfile{Name: "handshake"}) +} + +func runUnorderedCase(ctx context.Context, pairs []pair, seed int64, repeat int, profile networkProfile) ([]scenarioResult, error) { + if len(pairs) == 0 { + return nil, errInsufficientEntries + } + + var results []scenarioResult + for idx, p := range pairs { + for iter := range repeat { + seq := idx*repeat + iter + forward, reverse, metrics, err := runBurstTrafficProfile(ctx, p, seed, seq, profile) + res := scenarioResult{ + CaseName: profile.Name, + Pair: p, + Iteration: iter + 1, + ForwardBurst: forward, + ReverseBurst: reverse, + Metrics: metrics, + Passed: err == nil && forward >= minBurstPackets && reverse >= minBurstPackets, + Details: formatMetrics(metrics), + } + if err != nil { + res.Details += " err=" + err.Error() + } + results = append(results, res) + } + } + + return results, nil +} + +func runCongestionCase(ctx context.Context, pairs []pair, seed int64, repeat int) ([]scenarioResult, error) { + return runUnorderedCase(ctx, pairs, seed, repeat, congestionProfile()) +} + +func runRetransmissionCase(ctx context.Context, pairs []pair, seed int64, repeat int) ([]scenarioResult, error) { + return runUnorderedCase(ctx, pairs, seed, repeat, lossProfile()) +} diff --git a/internal/testcmd/cases_test.go b/internal/testcmd/cases_test.go new file mode 100644 index 0000000..c34520c --- /dev/null +++ b/internal/testcmd/cases_test.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "context" + "testing" + + "github.com/pion/scp/internal/scp" + "github.com/stretchr/testify/require" +) + +func TestRunCasesDefaultsToMaxBurst(t *testing.T) { + t.Parallel() + + pairs := []pair{{ + Left: scp.LockEntry{Name: "v1", Commit: "aaaaaaaa"}, + Right: scp.LockEntry{Name: "v2", Commit: "bbbbbbbb"}, + }} + + _, err := runCases(context.Background(), nil, pairs, 123, 1) + require.ErrorIs(t, err, errNoCases) +} + +func TestNormalizeCases(t *testing.T) { + t.Parallel() + + input := []string{" max-burst", "max-burst", "other", " ", "other"} + got := normalizeCases(input) + + require.Equal(t, []string{"max-burst", "other"}, got) +} + +func TestRunCasesUnknownCase(t *testing.T) { + t.Parallel() + + pairs := []pair{{ + Left: scp.LockEntry{Name: "v1", Commit: "aaaaaaaa"}, + Right: scp.LockEntry{Name: "v2", Commit: "bbbbbbbb"}, + }} + + _, err := runCases(context.Background(), []string{"nope"}, pairs, 123, 1) + require.ErrorIs(t, err, errUnknownCase) +} diff --git a/internal/testcmd/errors.go b/internal/testcmd/errors.go new file mode 100644 index 0000000..457f857 --- /dev/null +++ b/internal/testcmd/errors.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package testcmd provides the scaffolding for the scp test command. +package testcmd + +import "errors" + +var ( + errMissingLockPath = errors.New("test: lock path is required") + errEmptyLock = errors.New("test: lock file has no entries") + errNoSelectableEntries = errors.New("test: no entries selected after filtering") + errRequestedEntryMissing = errors.New("test: requested entry missing") + errNoCases = errors.New("test: no cases specified") + errInsufficientEntries = errors.New("test: at least two entries are required") + errUnknownPairMode = errors.New("test: unknown pair mode") + errMissingExplicitPairs = errors.New("test: explicit pairs required") + errUnknownCase = errors.New("test: unknown scenario case") + errInvalidRepeat = errors.New("test: repeat must be >= 1") + errScenarioFailed = errors.New("test: scenario failed") +) diff --git a/internal/testcmd/junit.go b/internal/testcmd/junit.go new file mode 100644 index 0000000..bc6d1fc --- /dev/null +++ b/internal/testcmd/junit.go @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "encoding/xml" + "fmt" + "os" + "path/filepath" +) + +type junitSuite struct { + XMLName xml.Name `xml:"testsuite"` + Name string `xml:"name,attr"` + Tests int `xml:"tests,attr"` + Failures int `xml:"failures,attr"` + TestCases []junitCase `xml:"testcase"` +} + +type junitCase struct { + Classname string `xml:"classname,attr"` + Name string `xml:"name,attr"` + SystemOut string `xml:"system-out,omitempty"` + Failure *junitFailure `xml:"failure,omitempty"` +} + +type junitFailure struct { + Message string `xml:"message,attr,omitempty"` + Details string `xml:",chardata"` +} + +func writeJUnitReport(path string, results []scenarioResult) error { + if path == "" { + return nil + } + if err := ensureJUnitDir(path); err != nil { + return err + } + + suite := junitSuite{ + Name: "scp-smoke", + Tests: len(results), + Failures: countFailures(results), + } + + for _, res := range results { + caseName := fmt.Sprintf("%s_vs_%s", res.Pair.Left.Name, res.Pair.Right.Name) + if res.Iteration > 1 { + caseName = fmt.Sprintf("%s#%d", caseName, res.Iteration) + } + metricsLine := formatMetrics(res.Metrics) + jc := junitCase{ + Classname: "scp." + res.CaseName, + Name: caseName, + SystemOut: res.Details + "\n" + metricsLine, + } + if !res.Passed { + jc.Failure = &junitFailure{ + Message: "max burst threshold not met", + Details: res.Details + "\n" + metricsLine, + } + } + suite.TestCases = append(suite.TestCases, jc) + } + + data, err := xml.MarshalIndent(suite, "", " ") + if err != nil { + return err + } + data = append([]byte(xml.Header), data...) + + return os.WriteFile(filepath.Clean(path), data, 0o600) +} + +func ensureJUnitDir(path string) error { + dir := filepath.Dir(path) + if dir == "." || dir == "" { + return nil + } + + return os.MkdirAll(dir, 0o750) +} diff --git a/internal/testcmd/options.go b/internal/testcmd/options.go new file mode 100644 index 0000000..43aa092 --- /dev/null +++ b/internal/testcmd/options.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package testcmd defines harness test command configuration and helpers. +package testcmd + +import "github.com/pion/scp/internal/scp" + +const ( + DefaultPairMode = "adjacent" + DefaultTimeout = "2m" +) + +type Options struct { + LockPath string + PairMode string + IncludeNames []string + ExcludeNames []string + ExplicitPairs []string + Cases []string + Timeout string + Seed int64 + JUnitPath string + Repeat int +} + +func DefaultOptions() Options { + return Options{ + LockPath: scp.DefaultLockPath(), + PairMode: DefaultPairMode, + Timeout: DefaultTimeout, + JUnitPath: "", + Repeat: 1, + } +} diff --git a/internal/testcmd/pairs.go b/internal/testcmd/pairs.go new file mode 100644 index 0000000..35e99e8 --- /dev/null +++ b/internal/testcmd/pairs.go @@ -0,0 +1,186 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "fmt" + "sort" + "strings" + + "github.com/pion/scp/internal/scp" +) + +type pair struct { + Left scp.LockEntry + Right scp.LockEntry +} + +func buildNameSet(values []string) map[string]struct{} { + names := scp.SplitAndTrim(values) + if len(names) == 0 { + return nil + } + + result := make(map[string]struct{}, len(names)) + for _, name := range names { + result[name] = struct{}{} + } + + return result +} + +func selectEntries(entries []scp.LockEntry, include, exclude map[string]struct{}) ([]scp.LockEntry, error) { + if len(entries) == 0 { + return nil, errEmptyLock + } + if err := ensureIncluded(entries, include); err != nil { + return nil, err + } + + var filtered []scp.LockEntry + for _, entry := range entries { + if len(include) > 0 { + if _, ok := include[entry.Name]; !ok { + continue + } + } + if len(exclude) > 0 { + if _, ok := exclude[entry.Name]; ok { + continue + } + } + filtered = append(filtered, entry) + } + + if len(filtered) == 0 { + return nil, errNoSelectableEntries + } + + return filtered, nil +} + +func ensureIncluded(entries []scp.LockEntry, include map[string]struct{}) error { + if len(include) == 0 { + return nil + } + + present := make(map[string]struct{}, len(entries)) + for _, entry := range entries { + present[entry.Name] = struct{}{} + } + + var missing []string + for name := range include { + if _, ok := present[name]; !ok { + missing = append(missing, name) + } + } + + if len(missing) > 0 { + sort.Strings(missing) + + return fmt.Errorf("%w: %s", errRequestedEntryMissing, strings.Join(missing, ", ")) + } + + return nil +} + +func buildPairs(entries []scp.LockEntry, mode string, explicit []string) ([]pair, error) { + if err := validatePairs(entries, mode); err != nil { + return nil, err + } + + switch mode { + case "", DefaultPairMode: + return adjacentPairs(entries), nil + case "latest-prev": + return latestPrevPair(entries), nil + case "matrix": + return matrixPairs(entries), nil + case "explicit": + return explicitPairs(entries, explicit) + case "self": + return selfPairs(entries), nil + default: + return nil, fmt.Errorf("%w: %s", errUnknownPairMode, mode) + } +} + +func validatePairs(pairs []scp.LockEntry, mode string) error { + if len(pairs) == 0 { + return errEmptyLock + } + if len(pairs) < 2 && mode != "self" && mode != "explicit" { + return errInsufficientEntries + } + + return nil +} + +func adjacentPairs(entries []scp.LockEntry) []pair { + pairs := make([]pair, 0, len(entries)-1) + for i := 1; i < len(entries); i++ { + pairs = append(pairs, pair{Left: entries[i-1], Right: entries[i]}) + } + + return pairs +} + +func latestPrevPair(entries []scp.LockEntry) []pair { + last := len(entries) - 1 + + return []pair{{Left: entries[last-1], Right: entries[last]}} +} + +func matrixPairs(entries []scp.LockEntry) []pair { + estimated := len(entries) * (len(entries) - 1) / 2 + pairs := make([]pair, 0, estimated) + for i := 0; i < len(entries); i++ { + for j := i + 1; j < len(entries); j++ { + pairs = append(pairs, pair{Left: entries[i], Right: entries[j]}) + } + } + + return pairs +} + +func explicitPairs(entries []scp.LockEntry, specs []string) ([]pair, error) { + flattened := scp.SplitAndTrim(specs) + if len(flattened) == 0 { + return nil, errMissingExplicitPairs + } + + lookup := make(map[string]scp.LockEntry, len(entries)) + for _, entry := range entries { + lookup[entry.Name] = entry + } + + pairs := make([]pair, 0, len(flattened)) + for _, spec := range flattened { + parts := strings.SplitN(spec, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("%w: %s", errMissingExplicitPairs, spec) + } + left, ok := lookup[parts[0]] + if !ok { + return nil, fmt.Errorf("%w: %s", errRequestedEntryMissing, parts[0]) + } + right, ok := lookup[parts[1]] + if !ok { + return nil, fmt.Errorf("%w: %s", errRequestedEntryMissing, parts[1]) + } + pairs = append(pairs, pair{Left: left, Right: right}) + } + + return pairs, nil +} + +func selfPairs(entries []scp.LockEntry) []pair { + pairs := make([]pair, 0, len(entries)) + for _, entry := range entries { + pairs = append(pairs, pair{Left: entry, Right: entry}) + } + + return pairs +} diff --git a/internal/testcmd/pairs_test.go b/internal/testcmd/pairs_test.go new file mode 100644 index 0000000..7e9967b --- /dev/null +++ b/internal/testcmd/pairs_test.go @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "testing" + + "github.com/pion/scp/internal/scp" + "github.com/stretchr/testify/require" +) + +func TestSelectEntries(t *testing.T) { + t.Parallel() + + entries := []scp.LockEntry{ + {Name: "v1", Commit: "a"}, + {Name: "v2", Commit: "b"}, + {Name: "v3", Commit: "c"}, + } + + t.Run("Include", func(t *testing.T) { + selected, err := selectEntries(entries, nameSet("v1", "v3"), nil) + require.NoError(t, err) + require.Len(t, selected, 2) + require.Equal(t, "v1", selected[0].Name) + require.Equal(t, "v3", selected[1].Name) + }) + + t.Run("MissingInclude", func(t *testing.T) { + _, err := selectEntries(entries, nameSet("v4"), nil) + require.ErrorIs(t, err, errRequestedEntryMissing) + }) + + t.Run("Exclude", func(t *testing.T) { + selected, err := selectEntries(entries, nil, nameSet("v2")) + require.NoError(t, err) + require.Len(t, selected, 2) + require.Equal(t, []string{"v1", "v3"}, []string{selected[0].Name, selected[1].Name}) + }) +} + +func TestBuildPairs(t *testing.T) { + t.Parallel() + + entries := []scp.LockEntry{ + {Name: "v1", Commit: "a"}, + {Name: "v2", Commit: "b"}, + {Name: "v3", Commit: "c"}, + } + + t.Run("Adjacent", func(t *testing.T) { + pairs, err := buildPairs(entries, "adjacent", nil) + require.NoError(t, err) + require.Len(t, pairs, 2) + require.Equal(t, "v1", pairs[0].Left.Name) + require.Equal(t, "v2", pairs[0].Right.Name) + }) + + t.Run("LatestPrev", func(t *testing.T) { + pairs, err := buildPairs(entries, "latest-prev", nil) + require.NoError(t, err) + require.Len(t, pairs, 1) + require.Equal(t, "v2", pairs[0].Left.Name) + require.Equal(t, "v3", pairs[0].Right.Name) + }) + + t.Run("Matrix", func(t *testing.T) { + pairs, err := buildPairs(entries, "matrix", nil) + require.NoError(t, err) + require.Len(t, pairs, 3) + }) + + t.Run("Explicit", func(t *testing.T) { + pairs, err := buildPairs(entries, "explicit", []string{"v1:v3", "v2:v3"}) + require.NoError(t, err) + require.Len(t, pairs, 2) + require.Equal(t, "v1", pairs[0].Left.Name) + require.Equal(t, "v3", pairs[0].Right.Name) + }) + + t.Run("Self", func(t *testing.T) { + pairs, err := buildPairs(entries[:1], "self", nil) + require.NoError(t, err) + require.Len(t, pairs, 1) + require.Equal(t, pairs[0].Left.Name, pairs[0].Right.Name) + }) +} + +func nameSet(names ...string) map[string]struct{} { + if len(names) == 0 { + return nil + } + set := make(map[string]struct{}, len(names)) + for _, name := range names { + set[name] = struct{}{} + } + + return set +} diff --git a/internal/testcmd/run.go b/internal/testcmd/run.go new file mode 100644 index 0000000..a3f87c7 --- /dev/null +++ b/internal/testcmd/run.go @@ -0,0 +1,141 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "context" + "fmt" + + "github.com/pion/scp/internal/scp" +) + +func Run(ctx context.Context, opts Options) error { + if err := validateOptions(opts); err != nil { + return err + } + + lock, err := loadAndValidateLock(opts.LockPath) + if err != nil { + return err + } + + pairs, caseNames, err := prepareTestData(opts, lock) + if err != nil { + return err + } + + results, err := runCases(ctx, caseNames, pairs, opts.Seed, opts.Repeat) + if err != nil { + return err + } + + if err := reportResults(results, opts.JUnitPath); err != nil { + return err + } + + return checkFailures(results) +} + +func validateOptions(opts Options) error { + if opts.LockPath == "" { + return errMissingLockPath + } + if opts.Repeat <= 0 { + return errInvalidRepeat + } + + return nil +} + +func loadAndValidateLock(lockPath string) (*scp.Lockfile, error) { + lock, err := scp.ReadLock(lockPath) + if err != nil { + return nil, fmt.Errorf("test: read lock: %w", err) + } + if lock == nil || len(lock.Entries) == 0 { + return nil, errEmptyLock + } + + return lock, nil +} + +func prepareTestData(opts Options, lock *scp.Lockfile) ([]pair, []string, error) { + include := buildNameSet(opts.IncludeNames) + exclude := buildNameSet(opts.ExcludeNames) + entries, err := selectEntries(lock.Entries, include, exclude) + if err != nil { + return nil, nil, err + } + + pairs, err := buildPairs(entries, opts.PairMode, opts.ExplicitPairs) + if err != nil { + return nil, nil, err + } + + caseNames := scp.SplitAndTrim(opts.Cases) + + return pairs, caseNames, nil +} + +func reportResults(results []scenarioResult, junitPath string) error { + printResults(results) + + if junitPath != "" { + if err := writeJUnitReport(junitPath, results); err != nil { + return err + } + } + + return nil +} + +func checkFailures(results []scenarioResult) error { + if failures := countFailures(results); failures > 0 { + return fmt.Errorf("%w: %d failing cases", errScenarioFailed, failures) + } + + return nil +} + +func printResults(results []scenarioResult) { + if len(results) == 0 { + fmt.Println("test: no cases executed") + + return + } + + for _, res := range results { + label := res.CaseName + if res.Iteration > 1 { + label = fmt.Sprintf("%s#%d", label, res.Iteration) + } + fmt.Printf("[%s] %s ↔ %s :: forward=%d packets, reverse=%d packets\n", + label, + res.Pair.Left.Name, + res.Pair.Right.Name, + res.ForwardBurst, + res.ReverseBurst, + ) + if !res.Passed && res.Details != "" { + fmt.Printf(" details: %s\n", res.Details) + } + if (res.Metrics != resultMetrics{}) { + fmt.Printf(" metrics: %s\n", formatMetrics(res.Metrics)) + } + if res.Errored { + fmt.Printf(" errored=1\n") + } + } +} + +func countFailures(results []scenarioResult) int { + failures := 0 + for _, res := range results { + if !res.Passed { + failures++ + } + } + + return failures +} diff --git a/internal/testcmd/run_test.go b/internal/testcmd/run_test.go new file mode 100644 index 0000000..5dfc0f9 --- /dev/null +++ b/internal/testcmd/run_test.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/pion/scp/internal/scp" + "github.com/stretchr/testify/require" +) + +func TestRunMaxBurstWritesJUnit(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + lockPath := filepath.Join(dir, "lock.json") + lock := &scp.Lockfile{ + Entries: []scp.LockEntry{ + {Name: "v1", Selector: "tag:v1.0.0", Commit: "aaaaaaaa", Provenance: "tag"}, + {Name: "v2", Selector: "tag:v2.0.0", Commit: "bbbbbbbb", Provenance: "tag"}, + }, + } + require.NoError(t, scp.WriteLock(lockPath, lock)) + + junitPath := filepath.Join(dir, "reports", "junit.xml") + opts := Options{ + LockPath: lockPath, + PairMode: "matrix", + Cases: []string{caseMaxBurst}, + Seed: 42, + JUnitPath: junitPath, + Repeat: 1, + } + + require.NoError(t, Run(context.Background(), opts)) + data, err := os.ReadFile(filepath.Clean(junitPath)) + require.NoError(t, err) + require.Contains(t, string(data), "testsuite") +} diff --git a/internal/testcmd/smoke.go b/internal/testcmd/smoke.go new file mode 100644 index 0000000..deff37d --- /dev/null +++ b/internal/testcmd/smoke.go @@ -0,0 +1,708 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "math/rand/v2" + "net" + "sort" + "strings" + "time" + + "github.com/pion/logging" + "github.com/pion/sctp" + "github.com/pion/transport/vnet" + "golang.org/x/sys/unix" +) + +const ( + minBurstPackets = 64 + burstRange = 512 - minBurstPackets + 1 + burstPayloadOctets = 1200 + minPacketsPerSecond = 1.0 +) + +func runMaxBurstCase(ctx context.Context, pairs []pair, baseSeed int64, repeat int) ([]scenarioResult, error) { + if len(pairs) == 0 { + return nil, errInsufficientEntries + } + + resolvedSeed := baseSeed + if resolvedSeed == 0 { + resolvedSeed = deriveDefaultSeed(pairs) + } + + results := make([]scenarioResult, 0, len(pairs)*repeat) + for idx, pair := range pairs { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + for iter := range repeat { + result := scenarioResult{ + CaseName: caseMaxBurst, + Pair: pair, + Iteration: iter + 1, + } + + seq := idx*repeat + iter + if isSelfPair(pair) { + seq = iter + } + + forward, reverse, metrics, err := runBurstTraffic(ctx, pair, resolvedSeed, seq) + result.ForwardBurst = forward + result.ReverseBurst = reverse + result.Metrics = metrics + result.Passed = forward >= minBurstPackets && reverse >= minBurstPackets + result.Details = fmt.Sprintf("run=%d %s->%s=%d packets, %s->%s=%d packets", + iter+1, pair.Left.Name, pair.Right.Name, forward, pair.Right.Name, pair.Left.Name, reverse, + ) + if err != nil || !result.Passed { + result.Details += fmt.Sprintf(" err=%v", err) + } + if err != nil { + result.Errored = true + } + if metrics.PacketsPerSecond > 0 && metrics.PacketsPerSecond < minPacketsPerSecond { + result.Passed = false + result.Details += fmt.Sprintf(" rate=%.2fpps>1))) //nolint:gosec // not cryptographic purpose + target := minBurstPackets + rng.IntN(burstRange) + targetForward := target + targetReverse := target + payload := make([]byte, burstPayloadOctets) + for i := range payload { + payload[i] = byte(rng.IntN(256)) + } + // ensure payload has room for timestamp + sequence + if len(payload) < 16 { + payload = make([]byte, 16) + } + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + QueueSize: 4096, + MinDelay: profile.MinDelay, + MaxJitter: profile.MaxJitter, + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + return 0, 0, resultMetrics{}, fmt.Errorf("burst: router: %w", err) + } + if profile.DropPercent > 0 { + rng := rand.New(rand.NewPCG(uint64(seed), uint64(seed>>1))) //nolint:gosec + router.AddChunkFilter(func(c vnet.Chunk) bool { + value := rng.IntN(1000) + return float64(value)/10.0 >= profile.DropPercent + }) + } + leftNet := vnet.NewNet(&vnet.NetConfig{StaticIP: "10.0.0.1"}) + rightNet := vnet.NewNet(&vnet.NetConfig{StaticIP: "10.0.0.2"}) + if err := router.AddNet(leftNet); err != nil { + return 0, 0, resultMetrics{}, fmt.Errorf("burst: add left net: %w", err) + } + if err := router.AddNet(rightNet); err != nil { + return 0, 0, resultMetrics{}, fmt.Errorf("burst: add right net: %w", err) + } + if err := router.Start(); err != nil { + return 0, 0, resultMetrics{}, fmt.Errorf("burst: start router: %w", err) + } + defer func() { + _ = router.Stop() + }() + + startCPU := readCPUSeconds() + startTime := time.Now() + + forward, reverse, latencies, stats, err := runSCTPBurst(ctx, leftNet, rightNet, targetForward, targetReverse, payload, profile.Unordered) + + duration := time.Since(startTime) + cpu := readCPUSeconds() - startCPU + packets := forward + reverse + pps := 0.0 + if duration > 0 { + pps = float64(packets) / duration.Seconds() + } + + latP50, latP90, latP99 := computePercentiles(latencies) + metrics := resultMetrics{ + Duration: duration, + PacketsPerSecond: pps, + CPUSeconds: cpu, + LatencyP50: latP50, + LatencyP90: latP90, + LatencyP99: latP99, + BytesSent: stats.BytesSent, + BytesReceived: stats.BytesReceived, + Reordered: stats.Reordered, + Retransmitted: stats.Retransmitted, + GoodputBps: goodput(stats.BytesReceived, duration), + TailRecovery: stats.TailRecovery, + Target: target, + } + + if err != nil { + if (errors.Is(err, context.DeadlineExceeded) || isTimeoutErr(err)) && forward >= minBurstPackets && reverse >= minBurstPackets { + err = nil + } else { + return forward, reverse, metrics, err + } + } + if forward < minBurstPackets || reverse < minBurstPackets { + return forward, reverse, metrics, fmt.Errorf("burst: incomplete forward=%d reverse=%d target=%d", forward, reverse, target) + } + + return forward, reverse, metrics, nil +} + +type sctpSession struct { + clientAssoc *sctp.Association + serverAssoc *sctp.Association + clientStream *sctp.Stream + serverStream *sctp.Stream + clientConn net.Conn + serverConn net.Conn +} + +type assocResult struct { + assoc *sctp.Association + err error +} + +type streamResult struct { + stream *sctp.Stream + err error +} + +type resultMetrics struct { + Duration time.Duration + PacketsPerSecond float64 + CPUSeconds float64 + LatencyP50 time.Duration + LatencyP90 time.Duration + LatencyP99 time.Duration + BytesSent uint64 + BytesReceived uint64 + Reordered int + Retransmitted int + Dropped int + GoodputBps float64 + TailRecovery time.Duration + Target int +} + +type assocStats struct { + BytesSent uint64 + BytesReceived uint64 + Reordered int + Retransmitted int + Dropped int + TailRecovery time.Duration +} + +type receiveResult struct { + count int + latencies []time.Duration + reordered int + retrans int + tailRecovery time.Duration + dropped int +} + +func runSCTPBurst(ctx context.Context, leftNet, rightNet *vnet.Net, forwardPackets, reversePackets int, payload []byte, unordered bool) (int, int, []time.Duration, assocStats, error) { + session, err := establishSCTPSession(ctx, leftNet, rightNet) + if err != nil { + return 0, 0, nil, assocStats{}, err + } + defer session.close() + + if err := warmupStreams(ctx, session.clientStream, session.serverStream); err != nil { + var netErr net.Error + if !(isTimeoutErr(err) || (errors.As(err, &netErr) && netErr.Timeout())) { + return 0, 0, nil, assocStats{}, fmt.Errorf("sctp: warmup: %w", err) + } + } + + forwardCh := make(chan receiveResult, 1) + reverseCh := make(chan receiveResult, 1) + + go func() { + count, lats, reorder, retrans, tail, dropped := receivePackets(ctx, session.serverStream, forwardPackets, len(payload)) + forwardCh <- receiveResult{count: count, latencies: lats, reordered: reorder, retrans: retrans, tailRecovery: tail, dropped: dropped} + }() + go func() { + count, lats, reorder, retrans, tail, dropped := receivePackets(ctx, session.clientStream, reversePackets, len(payload)) + reverseCh <- receiveResult{count: count, latencies: lats, reordered: reorder, retrans: retrans, tailRecovery: tail, dropped: dropped} + }() + + if unordered { + session.clientStream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, 0) //nolint:errcheck + session.serverStream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, 0) //nolint:errcheck + } + + sendErr := transmitPackets(ctx, session.clientStream, forwardPackets, payload) + sendErr = errors.Join(sendErr, transmitPackets(ctx, session.serverStream, reversePackets, payload)) + + forwardRes := <-forwardCh + reverseRes := <-reverseCh + forward := forwardRes.count + reverse := reverseRes.count + forwardLat := forwardRes.latencies + reverseLat := reverseRes.latencies + + if sendErr != nil { + return forward, reverse, append(forwardLat, reverseLat...), collectStats(session, forwardRes, reverseRes), sendErr + } + if ctxErr := ctx.Err(); ctxErr != nil { + return forward, reverse, append(forwardLat, reverseLat...), collectStats(session, forwardRes, reverseRes), ctxErr + } + + return forward, reverse, append(forwardLat, reverseLat...), collectStats(session, forwardRes, reverseRes), nil +} + +func establishSCTPSession(ctx context.Context, leftNet, rightNet *vnet.Net) (*sctpSession, error) { + serverAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 5000} + clientAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5001} + + serverConn, err := rightNet.DialUDP("udp4", serverAddr, clientAddr) + if err != nil { + return nil, fmt.Errorf("sctp: server dial: %w", err) + } + + clientConn, err := leftNet.DialUDP("udp4", clientAddr, serverAddr) + if err != nil { + serverConn.Close() + return nil, fmt.Errorf("sctp: client dial: %w", err) + } + + serverConfig := sctp.Config{ + NetConn: serverConn, + LoggerFactory: logging.NewDefaultLoggerFactory(), + } + clientConfig := sctp.Config{ + NetConn: clientConn, + LoggerFactory: logging.NewDefaultLoggerFactory(), + } + + serverAssocCh := make(chan *sctp.Association, 1) + serverErrCh := make(chan error, 1) + go func() { + assoc, serveErr := sctp.Server(serverConfig) + if serveErr != nil { + serverErrCh <- serveErr + return + } + serverAssocCh <- assoc + }() + + clientAssocCh := make(chan assocResult, 1) + go func() { + assoc, assocErr := sctp.Client(clientConfig) + clientAssocCh <- assocResult{assoc: assoc, err: assocErr} + }() + + var clientAssoc *sctp.Association + select { + case res := <-clientAssocCh: + if res.err != nil { + serverConn.Close() + clientConn.Close() + return nil, fmt.Errorf("sctp: client: %w", res.err) + } + clientAssoc = res.assoc + case <-ctx.Done(): + serverConn.Close() + clientConn.Close() + return nil, ctx.Err() + } + + var serverAssoc *sctp.Association + select { + case serverErr := <-serverErrCh: + clientAssoc.Close() + serverConn.Close() + clientConn.Close() + return nil, fmt.Errorf("sctp: server: %w", serverErr) + case serverAssoc = <-serverAssocCh: + case <-ctx.Done(): + clientAssoc.Close() + serverConn.Close() + clientConn.Close() + return nil, ctx.Err() + } + + clientStreamCh := make(chan streamResult, 1) + go func() { + stream, streamErr := clientAssoc.OpenStream(0, sctp.PayloadTypeWebRTCBinary) + clientStreamCh <- streamResult{stream: stream, err: streamErr} + }() + serverStreamCh := make(chan streamResult, 1) + go func() { + stream, streamErr := serverAssoc.AcceptStream() + serverStreamCh <- streamResult{stream: stream, err: streamErr} + }() + + var clientStream *sctp.Stream + var serverStream *sctp.Stream + for clientStream == nil || serverStream == nil { + select { + case res := <-clientStreamCh: + if res.err != nil { + serverAssoc.Close() + clientAssoc.Close() + serverConn.Close() + clientConn.Close() + return nil, fmt.Errorf("sctp: open stream: %w", res.err) + } + clientStream = res.stream + // Kick the server side by sending an initial warmup packet. + _ = clientStream.SetWriteDeadline(time.Now().Add(500 * time.Millisecond)) + _, _ = clientStream.Write(make([]byte, 8)) + case res := <-serverStreamCh: + if res.err != nil { + serverAssoc.Close() + clientAssoc.Close() + serverConn.Close() + clientConn.Close() + return nil, fmt.Errorf("sctp: accept stream: %w", res.err) + } + serverStream = res.stream + case <-ctx.Done(): + serverAssoc.Close() + clientAssoc.Close() + serverConn.Close() + clientConn.Close() + return nil, ctx.Err() + } + } + + return &sctpSession{ + clientAssoc: clientAssoc, + serverAssoc: serverAssoc, + clientStream: clientStream, + serverStream: serverStream, + clientConn: clientConn, + serverConn: serverConn, + }, nil +} + +func (s *sctpSession) close() { + _ = s.clientStream.Close() + _ = s.serverStream.Close() + _ = s.clientAssoc.Close() + _ = s.serverAssoc.Close() + _ = s.clientConn.Close() + _ = s.serverConn.Close() +} + +func warmupStreams(ctx context.Context, clientStream, serverStream *sctp.Stream) error { + handshake := []byte("warmup") + const attempts = 3 + for i := 0; i < attempts; i++ { + if err := sendOne(ctx, clientStream, handshake); err != nil { + if isTimeoutErr(err) { + continue + } + return err + } + if err := recvOne(serverStream, len(handshake)); err != nil { + if isTimeoutErr(err) { + continue + } + return err + } + if err := sendOne(ctx, serverStream, handshake); err != nil { + if isTimeoutErr(err) { + continue + } + return err + } + if err := recvOne(clientStream, len(handshake)); err != nil { + if isTimeoutErr(err) { + continue + } + return err + } + return nil + } + + return fmt.Errorf("warmup: exceeded retries") +} + +func sendOne(ctx context.Context, stream *sctp.Stream, payload []byte) error { + if err := stream.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { + return err + } + if _, err := stream.Write(payload); err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return err + } + } + + return err + } + + return nil +} + +func recvOne(stream *sctp.Stream, payloadSize int) error { + buf := make([]byte, payloadSize+16) + if err := stream.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + return err + } + if _, err := stream.Read(buf); err != nil { + return err + } + + return nil +} + +func receivePackets(ctx context.Context, stream *sctp.Stream, packets int, payloadSize int) (int, []time.Duration, int, int, time.Duration, int) { + buf := make([]byte, payloadSize+16) + count := 0 + latencies := make([]time.Duration, 0, packets) + seen := make(map[uint64]int, packets) + expectedSeq := uint64(1) + reordered := 0 + retrans := 0 + var lastSendTS int64 + var lastRecvTS time.Time + dropped := 0 + for count < packets { + _ = stream.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + n, err := stream.Read(buf) + if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + select { + case <-ctx.Done(): + return count, latencies, reordered, retrans, recvTail(lastSendTS, lastRecvTS), dropped + default: + continue + } + } + + return count, latencies, reordered, retrans, recvTail(lastSendTS, lastRecvTS), dropped + } + count++ + if n >= 8 { + sendTS := int64(binary.LittleEndian.Uint64(buf[:8])) + if sendTS > 0 { + latencies = append(latencies, time.Since(time.Unix(0, sendTS))) + if sendTS > lastSendTS { + lastSendTS = sendTS + } + lastRecvTS = time.Now() + } + } + if n >= 16 { + seq := binary.LittleEndian.Uint64(buf[8:16]) + if seq != expectedSeq { + reordered++ + } + if seen[seq] > 0 { + retrans++ + } + seen[seq]++ + if seq == expectedSeq { + expectedSeq++ + } + if seq > uint64(packets) { + dropped++ + } + } + } + + return count, latencies, reordered, retrans, recvTail(lastSendTS, lastRecvTS), dropped +} + +func recvTail(lastSendTS int64, lastRecvTS time.Time) time.Duration { + if lastSendTS == 0 || lastRecvTS.IsZero() { + return 0 + } + return lastRecvTS.Sub(time.Unix(0, lastSendTS)) +} + +func transmitPackets(ctx context.Context, stream *sctp.Stream, packets int, payload []byte) error { + seq := uint64(1) + for i := 0; i < packets; { + if err := stream.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)); err != nil { + return fmt.Errorf("burst: set write deadline: %w", err) + } + now := uint64(time.Now().UnixNano()) + binary.LittleEndian.PutUint64(payload[:8], now) + binary.LittleEndian.PutUint64(payload[8:16], seq) + if _, err := stream.Write(payload); err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + continue + } + } + + return fmt.Errorf("burst: write: %w", err) + } + i++ + seq++ + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + + return nil +} + +func derivePairSeed(base int64, p pair, idx int) int64 { + payload := fmt.Sprintf("%d:%s:%s:%s:%s:%d", base, p.Left.Name, p.Left.Commit, p.Right.Name, p.Right.Commit, idx) + sum := sha256.Sum256([]byte(payload)) + + return int64(binary.LittleEndian.Uint64(sum[:8])) //nolint:gosec // not cryptographic purpose +} + +func deriveSelfSeed(base int64, idx int) int64 { + payload := fmt.Sprintf("%d:self:%d", base, idx) + sum := sha256.Sum256([]byte(payload)) + + return int64(binary.LittleEndian.Uint64(sum[:8])) //nolint:gosec // not cryptographic purpose +} + +func isSelfPair(p pair) bool { + return p.Left.Name == p.Right.Name && p.Left.Commit == p.Right.Commit +} + +func isTimeoutErr(err error) bool { + var netErr net.Error + return errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) +} + +func collectStats(session *sctpSession, forward receiveResult, reverse receiveResult) assocStats { + return assocStats{ + BytesSent: session.clientAssoc.BytesSent() + session.serverAssoc.BytesSent(), + BytesReceived: session.clientAssoc.BytesReceived() + session.serverAssoc.BytesReceived(), + Reordered: forward.reordered + reverse.reordered, + Retransmitted: forward.retrans + reverse.retrans, + Dropped: forward.dropped + reverse.dropped, + TailRecovery: maxDuration(forward.tailRecovery, reverse.tailRecovery), + } +} + +func computePercentiles(latencies []time.Duration) (time.Duration, time.Duration, time.Duration) { + if len(latencies) == 0 { + return 0, 0, 0 + } + values := make([]time.Duration, len(latencies)) + copy(values, latencies) + sort.Slice(values, func(i, j int) bool { return values[i] < values[j] }) + + p50 := values[len(values)*50/100] + p90 := values[len(values)*90/100] + p99 := values[len(values)*99/100] + + return p50, p90, p99 +} + +func readCPUSeconds() float64 { + var ru unix.Rusage + if err := unix.Getrusage(unix.RUSAGE_SELF, &ru); err != nil { + return 0 + } + + user := float64(ru.Utime.Sec) + float64(ru.Utime.Usec)/1_000_000 + sys := float64(ru.Stime.Sec) + float64(ru.Stime.Usec)/1_000_000 + + return user + sys +} + +func formatMetrics(m resultMetrics) string { + return fmt.Sprintf("duration=%s pps=%.2f cpu=%.4fs p50=%s p90=%s p99=%s bytes_sent=%d bytes_recv=%d dropped=%d reordered=%d retrans=%d goodput=%.2fbps tail=%s target=%d", + m.Duration, + m.PacketsPerSecond, + m.CPUSeconds, + m.LatencyP50, + m.LatencyP90, + m.LatencyP99, + m.BytesSent, + m.BytesReceived, + m.Dropped, + m.Reordered, + m.Retransmitted, + m.GoodputBps, + m.TailRecovery, + m.Target, + ) +} + +func goodput(bytes uint64, d time.Duration) float64 { + if d <= 0 { + return 0 + } + return float64(bytes) * 8 / d.Seconds() +} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +func deriveDefaultSeed(pairs []pair) int64 { + var builder strings.Builder + for _, p := range pairs { + builder.WriteString(p.Left.Name) + builder.WriteByte(':') + builder.WriteString(p.Left.Commit) + builder.WriteByte('|') + builder.WriteString(p.Right.Name) + builder.WriteByte(':') + builder.WriteString(p.Right.Commit) + builder.WriteByte(';') + } + sum := sha256.Sum256([]byte(builder.String())) + seed := int64(binary.LittleEndian.Uint64(sum[:8])) //nolint:gosec // not cryptographic purpose + if seed == 0 { + seed = time.Now().UnixNano() + } + + return seed +} diff --git a/internal/testcmd/smoke_test.go b/internal/testcmd/smoke_test.go new file mode 100644 index 0000000..9745bc9 --- /dev/null +++ b/internal/testcmd/smoke_test.go @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package testcmd + +import ( + "context" + "testing" + + "github.com/pion/scp/internal/scp" + "github.com/stretchr/testify/require" +) + +func TestMaxBurstDeterministicWithoutSeed(t *testing.T) { + t.Parallel() + + pairs := []pair{{ + Left: scp.LockEntry{Name: "v1", Commit: "aaaaaaaa"}, + Right: scp.LockEntry{Name: "v2", Commit: "bbbbbbbb"}, + }} + + first, err := runCases(context.Background(), []string{caseMaxBurst}, pairs, 0, 1) + require.NoError(t, err) + second, err := runCases(context.Background(), []string{caseMaxBurst}, pairs, 0, 1) + require.NoError(t, err) + require.Len(t, first, len(second)) + for i := range first { + require.Equal(t, first[i].ForwardBurst, second[i].ForwardBurst) + require.Equal(t, first[i].ReverseBurst, second[i].ReverseBurst) + require.Equal(t, first[i].Iteration, second[i].Iteration) + } +} + +func TestMaxBurstRepeat(t *testing.T) { + t.Parallel() + + pairs := []pair{{ + Left: scp.LockEntry{Name: "v1", Commit: "aaaaaaaa"}, + Right: scp.LockEntry{Name: "v2", Commit: "bbbbbbbb"}, + }} + + results, err := runCases(context.Background(), []string{caseMaxBurst}, pairs, 123, 2) + require.NoError(t, err) + require.Len(t, results, 2) + require.Equal(t, 1, results[0].Iteration) + require.Equal(t, 2, results[1].Iteration) + require.NotEqual(t, results[0].ForwardBurst, results[1].ForwardBurst) + require.NotEqual(t, results[0].ReverseBurst, results[1].ReverseBurst) +} + +func TestMaxBurstSelfRepeatStable(t *testing.T) { + t.Parallel() + + entry := scp.LockEntry{Name: "v1", Commit: "aaaaaaaa"} + pairs := []pair{ + {Left: entry, Right: entry}, + {Left: scp.LockEntry{Name: "v2", Commit: "aaaaaaaa"}, Right: scp.LockEntry{Name: "v2", Commit: "aaaaaaaa"}}, + } + + results, err := runCases(context.Background(), []string{caseMaxBurst}, pairs, 123, 2) + require.NoError(t, err) + require.Len(t, results, 4) + + // self pairs should have symmetric forward/reverse per run + require.Equal(t, results[0].ForwardBurst, results[0].ReverseBurst) + require.Equal(t, results[1].ForwardBurst, results[1].ReverseBurst) + require.Equal(t, results[2].ForwardBurst, results[2].ReverseBurst) + require.Equal(t, results[3].ForwardBurst, results[3].ReverseBurst) + + // iteration N should be identical across different self pairs + require.Equal(t, results[0].ForwardBurst, results[2].ForwardBurst) + require.Equal(t, results[1].ForwardBurst, results[3].ForwardBurst) + + // iterations should differ from each other + require.NotEqual(t, results[0].ForwardBurst, results[1].ForwardBurst) +} diff --git a/internal/update/errors.go b/internal/update/errors.go new file mode 100644 index 0000000..4e9abdc --- /dev/null +++ b/internal/update/errors.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package update + +import "errors" + +var ( + errMissingManifestPath = errors.New("update: manifest path is required") + errMissingLockPath = errors.New("update: lock path is required") + errNotImplemented = errors.New("update: implementation pending") +) diff --git a/internal/update/options.go b/internal/update/options.go new file mode 100644 index 0000000..de98ed5 --- /dev/null +++ b/internal/update/options.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package update + +import "github.com/pion/scp/internal/scp" + +type Options struct { + ManifestPath string + LockPath string + OnlyNames []string + FreezeAt string +} + +func DefaultOptions() Options { + return Options{ + ManifestPath: scp.DefaultManifestPath(), + LockPath: scp.DefaultLockPath(), + } +} diff --git a/internal/update/run.go b/internal/update/run.go new file mode 100644 index 0000000..58886f5 --- /dev/null +++ b/internal/update/run.go @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package update + +import "context" + +func Run(ctx context.Context, opts Options) error { + if opts.ManifestPath == "" { + return errMissingManifestPath + } + if opts.LockPath == "" { + return errMissingLockPath + } + + return errNotImplemented +} diff --git a/internal/update/update.go b/internal/update/update.go new file mode 100644 index 0000000..42a92e3 --- /dev/null +++ b/internal/update/update.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package update refreshes lockfiles from manifests with floating selectors. +package update