diff --git a/cmd/wait-for-github/ci.go b/cmd/wait-for-github/ci.go index 7a3d30a..f68b44c 100644 --- a/cmd/wait-for-github/ci.go +++ b/cmd/wait-for-github/ci.go @@ -44,6 +44,14 @@ var ( commitRegexp = regexp.MustCompile(`.*github\.com/(?P[^/]+)/(?P[^/]+)/commit/(?P[abcdef\d]+)/?.*`) ) +type ErrInvalidCommitURL struct { + url string +} + +func (e ErrInvalidCommitURL) Error() string { + return fmt.Sprintf("invalid commit URL: %s", e.url) +} + func parseCIArguments(c *cli.Context) (ciConfig, error) { var owner, repo, ref string @@ -53,7 +61,7 @@ func parseCIArguments(c *cli.Context) (ciConfig, error) { url := c.Args().Get(0) match := commitRegexp.FindStringSubmatch(url) if match == nil { - return ciConfig{}, fmt.Errorf("invalid commit URL: %s", url) + return ciConfig{}, ErrInvalidCommitURL{url} } owner = match[1] @@ -71,10 +79,12 @@ func parseCIArguments(c *cli.Context) (ciConfig, error) { // but it doesn't work, says "unknown command ci". So we go through the parent command. lineage := c.Lineage() parent := lineage[1] - cli.ShowCommandHelpAndExit(parent, "ci", 1) + err := cli.ShowCommandHelp(parent, "ci") + if err != nil { + return ciConfig{}, err + } - // shouldn't get here, the previous line should exit - return ciConfig{}, nil + return ciConfig{}, cli.Exit("invalid number of arguments", 1) } return ciConfig{ @@ -143,12 +153,7 @@ func (ci checkSpecificCI) Check(ctx context.Context, recheckInterval time.Durati return handleCIStatus(status, recheckInterval) } -func checkCIStatus(timeoutCtx context.Context, cfg *config, ciConf *ciConfig) error { - githubClient, err := github.NewGithubClient(timeoutCtx, cfg.AuthInfo) - if err != nil { - return err - } - +func checkCIStatus(timeoutCtx context.Context, githubClient github.CheckCIStatus, cfg *config, ciConf *ciConfig) error { log.Infof("Checking CI status on %s/%s@%s", ciConf.owner, ciConf.repo, ciConf.ref) all := checkAllCI{ @@ -184,7 +189,14 @@ func ciCommand(cfg *config) *cli.Command { return err }, - Action: func(c *cli.Context) error { return checkCIStatus(c.Context, cfg, &ciConf) }, + Action: func(c *cli.Context) error { + githubClient, err := github.NewGithubClient(c.Context, cfg.AuthInfo) + if err != nil { + return err + } + + return checkCIStatus(c.Context, githubClient, cfg, &ciConf) + }, Flags: []cli.Flag{ &cli.StringSliceFlag{ Name: "check", diff --git a/cmd/wait-for-github/ci_test.go b/cmd/wait-for-github/ci_test.go new file mode 100644 index 0000000..58534d8 --- /dev/null +++ b/cmd/wait-for-github/ci_test.go @@ -0,0 +1,229 @@ +// wait-for-github +// Copyright (C) 2023, Grafana Labs + +// This program is free software: you can redistribute it and/or modify it under +// the terms of the GNU Affero General Public License as published by the Free +// Software Foundation, either version 3 of the License, or (at your option) any +// later version. + +// This program is distributed in the hope that it will be useful, but WITHOUT +// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +// details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "context" + "flag" + "testing" + "time" + + "github.com/grafana/wait-for-github/internal/github" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" +) + +// FakeCIStatusChecker implements the CheckCIStatus interface. +type FakeCIStatusChecker struct { + status github.CIStatus + err error +} + +func (c *FakeCIStatusChecker) GetCIStatus(ctx context.Context, owner, repo string, commitHash string) (github.CIStatus, error) { + return c.status, c.err +} + +func (c *FakeCIStatusChecker) GetCIStatusForChecks(ctx context.Context, owner, repo string, commitHash string, checkNames []string) (github.CIStatus, []string, error) { + return c.status, checkNames, c.err +} + +var ( + zero = 0 + one = 1 +) + +func TestHandleCIStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + status github.CIStatus + expectedExitCode *int + }{ + { + name: "passed", + status: github.CIStatusPassed, + expectedExitCode: &zero, + }, + { + name: "failed", + status: github.CIStatusFailed, + expectedExitCode: &one, + }, + { + name: "pending", + status: github.CIStatusPending, + expectedExitCode: nil, + }, + { + name: "unknown", + status: github.CIStatusUnknown, + expectedExitCode: nil, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + output := handleCIStatus(tt.status, 1) + if tt.expectedExitCode == nil { + require.Nil(t, output) + } else { + require.NotNil(t, output) + require.Equal(t, *tt.expectedExitCode, output.ExitCode()) + } + }) + } +} +func TestCheckCIStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + checks []string + status github.CIStatus + err error + recheckInterval time.Duration + expectedExitCode *int + }{ + { + name: "All checks pass", + checks: []string{}, + status: github.CIStatusPassed, + err: cli.Exit("CI successful", 0), + recheckInterval: time.Second * 2, + expectedExitCode: &zero, + }, + { + name: "Specific checks pass", + checks: []string{"check1", "check2"}, + status: github.CIStatusPassed, + err: cli.Exit("CI successful", 0), + expectedExitCode: &zero, + }, + { + name: "All checks fail", + checks: []string{}, + status: github.CIStatusFailed, + err: cli.Exit("CI failed", 1), + expectedExitCode: &one, + }, + { + name: "Specific checks fail", + checks: []string{"check1", "check2"}, + status: github.CIStatusFailed, + err: cli.Exit("CI failed", 1), + expectedExitCode: &one, + }, + { + name: "All checks pending", + checks: []string{}, + status: github.CIStatusPending, + err: nil, + recheckInterval: 1, + expectedExitCode: &one, + }, + { + name: "Specific checks pending", + checks: []string{"check1", "check2"}, + status: github.CIStatusPending, + err: nil, + recheckInterval: 1, + expectedExitCode: &one, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeCIStatusChecker := &FakeCIStatusChecker{status: tt.status, err: tt.err} + cfg := &config{ + recheckInterval: 1, + } + ciConf := &ciConfig{ + owner: "owner", + repo: "repo", + ref: "ref", + checks: tt.checks, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := checkCIStatus(ctx, fakeCIStatusChecker, cfg, ciConf) + + var exitErr cli.ExitCoder + require.ErrorAs(t, err, &exitErr) + require.Equal(t, *tt.expectedExitCode, exitErr.ExitCode()) + }) + } +} + +func TestParseCIArguments(t *testing.T) { + tests := []struct { + name string + args []string + want ciConfig + wantErr error + }{ + { + name: "Valid commit URL", + args: []string{"https://github.com/owner/repo/commit/abc123"}, + want: ciConfig{ + owner: "owner", + repo: "repo", + ref: "abc123", + }, + }, + { + name: "Valid arguments owner, repo, ref", + args: []string{"owner", "repo", "abc123"}, + want: ciConfig{ + owner: "owner", + repo: "repo", + ref: "abc123", + }, + }, + { + name: "Invalid commit URL", + args: []string{"https://invalid_url"}, + wantErr: ErrInvalidCommitURL{}, + }, + { + name: "Invalid number of arguments", + args: []string{"owner", "repo"}, + wantErr: cli.Exit("invalid number of arguments", 1), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + err := flagSet.Parse(tt.args) + require.NoError(t, err) + parentCliContext := cli.NewContext(nil, nil, nil) + parentCliContext.App = cli.NewApp() + cliContext := cli.NewContext(nil, flagSet, parentCliContext) + + got, err := parseCIArguments(cliContext) + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + } + require.Equal(t, tt.want, got) + }) + } +}