From 2960e9fc86e1bed2a3f819078ba2a77b05ac7095 Mon Sep 17 00:00:00 2001 From: Iain Lane Date: Wed, 28 Jun 2023 15:14:30 +0100 Subject: [PATCH] pr: Add testing Some small refactoring to allow for it, nothing major. --- cmd/wait-for-github/pr.go | 49 ++++-- cmd/wait-for-github/pr_test.go | 265 +++++++++++++++++++++++++++++++++ 2 files changed, 302 insertions(+), 12 deletions(-) create mode 100644 cmd/wait-for-github/pr_test.go diff --git a/cmd/wait-for-github/pr.go b/cmd/wait-for-github/pr.go index dffac08..2111318 100644 --- a/cmd/wait-for-github/pr.go +++ b/cmd/wait-for-github/pr.go @@ -31,12 +31,25 @@ import ( "github.com/urfave/cli/v2" ) +// for testing. We could use `io.Writer`, but then we have to handle opening and +// closing the file. +type fileWriter interface { + WriteFile(filename string, data []byte, perm os.FileMode) error +} + +type osFileWriter struct{} + +func (f osFileWriter) WriteFile(filename string, data []byte, perm os.FileMode) error { + return os.WriteFile(filename, data, perm) +} + type prConfig struct { owner string repo string pr int commitInfoFile string + writer fileWriter } var ( @@ -44,6 +57,14 @@ var ( pullRequestRegexp = regexp.MustCompile(`.*github\.com/(?P[^/]+)/(?P[^/]+)/pull/(?P\d+)/?.*`) ) +type ErrInvalidPRURL struct { + url string +} + +func (e ErrInvalidPRURL) Error() string { + return fmt.Sprintf("invalid pull request URL: %s", e.url) +} + func parsePRArguments(c *cli.Context) (prConfig, error) { var owner, repo, number string @@ -53,7 +74,7 @@ func parsePRArguments(c *cli.Context) (prConfig, error) { url := c.Args().Get(0) match := pullRequestRegexp.FindStringSubmatch(url) if match == nil { - return prConfig{}, fmt.Errorf("invalid pull request URL: %s", url) + return prConfig{}, ErrInvalidPRURL{url} } owner = match[1] @@ -71,10 +92,12 @@ func parsePRArguments(c *cli.Context) (prConfig, error) { // but it doesn't work, says "unknown command pr". So we go through the parent command. lineage := c.Lineage() parent := lineage[1] - cli.ShowCommandHelpAndExit(parent, "pr", 1) + err := cli.ShowCommandHelp(parent, "pr") + if err != nil { + return prConfig{}, err + } - // shouldn't get here, the previous line should exit - return prConfig{}, nil + return prConfig{}, cli.Exit("invalid number of arguments", 1) } n, err := strconv.Atoi(number) @@ -88,6 +111,7 @@ func parsePRArguments(c *cli.Context) (prConfig, error) { repo: repo, pr: n, commitInfoFile: c.String("commit-info-file"), + writer: osFileWriter{}, }, nil } @@ -126,7 +150,7 @@ func (pr prCheck) Check(ctx context.Context, recheckInterval time.Duration) erro } log.Debugf("Writing commit info to file %s", pr.commitInfoFile) - if err := os.WriteFile(pr.commitInfoFile, jsonCommit, 0644); err != nil { + if err := pr.writer.WriteFile(pr.commitInfoFile, jsonCommit, 0644); err != nil { return fmt.Errorf("failed to write commit info to file: %w", err) } } @@ -141,12 +165,7 @@ func (pr prCheck) Check(ctx context.Context, recheckInterval time.Duration) erro return nil } -func checkPRMerged(timeoutCtx context.Context, cfg *config, prConf *prConfig) error { - githubClient, err := github.NewGithubClient(timeoutCtx, cfg.AuthInfo) - if err != nil { - return err - } - +func checkPRMerged(timeoutCtx context.Context, githubClient github.CheckPRMerged, cfg *config, prConf *prConfig) error { checkPRMergedOrClosed := prCheck{ githubClient: githubClient, prConfig: *prConf, @@ -175,6 +194,12 @@ func prCommand(cfg *config) *cli.Command { return err }, - Action: func(c *cli.Context) error { return checkPRMerged(c.Context, cfg, &prConf) }, + Action: func(c *cli.Context) error { + githubClient, err := github.NewGithubClient(c.Context, cfg.AuthInfo) + if err != nil { + return err + } + return checkPRMerged(c.Context, githubClient, cfg, &prConf) + }, } } diff --git a/cmd/wait-for-github/pr_test.go b/cmd/wait-for-github/pr_test.go new file mode 100644 index 0000000..49e75e0 --- /dev/null +++ b/cmd/wait-for-github/pr_test.go @@ -0,0 +1,265 @@ +// wait-for-github +// Copyright (C) 2023, Grafana Labs + +// This program is free software: you can redistribute it and/or modify it under +// the terms of the GNU Affero General Public License as published by the Free +// Software Foundation, either version 3 of the License, or (at your option) any +// later version. + +// This program is distributed in the hope that it will be useful, but WITHOUT +// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +// details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io/fs" + "testing" + + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" +) + +// FakeGithubClientPRCheck implements CheckPRMerged +type FakeGithubClientPRCheck struct { + MergedCommit string + Closed bool + MergedAt int64 + Error error +} + +func (fg *FakeGithubClientPRCheck) IsPRMergedOrClosed(ctx context.Context, owner, repo string, pr int) (string, bool, int64, error) { + return fg.MergedCommit, fg.Closed, fg.MergedAt, fg.Error +} + +var ( + zero = 0 + one = 1 +) + +func TestPRCheck(t *testing.T) { + tests := []struct { + name string + fakeClient FakeGithubClientPRCheck + err error + expectedExitCode *int + }{ + { + name: "PR is merged", + fakeClient: FakeGithubClientPRCheck{ + MergedCommit: "abc123", + MergedAt: 1234567890, + }, + expectedExitCode: &zero, + }, + { + name: "PR is closed", + fakeClient: FakeGithubClientPRCheck{ + Closed: true, + }, + expectedExitCode: &one, + }, + { + name: "PR is open", + fakeClient: FakeGithubClientPRCheck{ + MergedCommit: "", + Closed: false, + }, + expectedExitCode: &one, + }, + { + name: "Error from IsPRMergedOrClosed", + fakeClient: FakeGithubClientPRCheck{ + Error: fmt.Errorf("an error occurred"), + }, + err: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakePRStatusChecker := &tt.fakeClient + cfg := &config{ + recheckInterval: 1, + } + + ctx, cancel := context.WithTimeout(context.Background(), 1) + cancel() + + prConfig := prConfig{ + owner: "owner", + repo: "repo", + pr: 1, + } + + err := checkPRMerged(ctx, fakePRStatusChecker, cfg, &prConfig) + if tt.expectedExitCode != nil { + var exitErr cli.ExitCoder + require.ErrorAs(t, err, &exitErr) + require.Equal(t, *tt.expectedExitCode, exitErr.ExitCode()) + } else if err != nil { + require.ErrorAs(t, err, &tt.err) + } + }) + } +} + +type fakeFileWriter struct { + filename string + data []byte + perm fs.FileMode + + err error +} + +func (f *fakeFileWriter) WriteFile(filename string, data []byte, perm fs.FileMode) error { + f.filename = filename + f.data = data + f.perm = perm + + return f.err +} + +func TestWriteCommitInfoFile(t *testing.T) { + prConfig := prConfig{ + owner: "owner", + repo: "repo", + pr: 1, + + commitInfoFile: "commit_info_file", + writer: &fakeFileWriter{}, + } + + prCheck := &prCheck{ + prConfig: prConfig, + + githubClient: &FakeGithubClientPRCheck{ + MergedCommit: "abc123", + MergedAt: 1234567890, + }, + } + + err := prCheck.Check(context.TODO(), 1) + var cliExitErr cli.ExitCoder + require.ErrorAs(t, err, &cliExitErr) + require.Equal(t, 0, cliExitErr.ExitCode()) + + ffw := prConfig.writer.(*fakeFileWriter) + + var gotCommitInfo commitInfo + err = json.Unmarshal(ffw.data, &gotCommitInfo) + require.NoError(t, err) + + require.Equal(t, "commit_info_file", ffw.filename) + require.Equal(t, commitInfo{ + Owner: "owner", + Repo: "repo", + Commit: "abc123", + MergedAt: 1234567890, + }, gotCommitInfo) + require.Equal(t, fs.FileMode(0644), ffw.perm) +} + +type erroringFileWriter struct{} + +func (e erroringFileWriter) WriteFile(filename string, data []byte, perm fs.FileMode) error { + return fmt.Errorf("error") +} + +func TestWriteCommitInfoFileError(t *testing.T) { + prConfig := prConfig{ + owner: "owner", + repo: "repo", + pr: 1, + + commitInfoFile: "commit_info_file", + writer: erroringFileWriter{}, + } + + prCheck := &prCheck{ + prConfig: prConfig, + + githubClient: &FakeGithubClientPRCheck{ + MergedCommit: "abc123", + MergedAt: 1234567890, + }, + } + + err := prCheck.Check(context.Background(), 1) + require.Error(t, err) +} + +func TestParsePRArguments(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args []string + want prConfig + wantErr error + }{ + { + name: "Valid pull request URL", + args: []string{"https://github.com/owner/repo/pull/1"}, + want: prConfig{ + owner: "owner", + repo: "repo", + pr: 1, + writer: osFileWriter{}, + }, + }, + { + name: "Valid arguments owner, repo, pr", + args: []string{"owner", "repo", "1"}, + want: prConfig{ + owner: "owner", + repo: "repo", + pr: 1, + writer: osFileWriter{}, + }, + }, + { + name: "Invalid pull request URL", + args: []string{"https://invalid_url"}, + wantErr: ErrInvalidPRURL{}, + }, + { + name: "Invalid number of arguments", + args: []string{"owner", "repo"}, + wantErr: cli.Exit("invalid number of arguments", 1), + }, + { + name: "Invalid PR number", + args: []string{"owner", "repo", "invalid"}, + wantErr: cli.Exit("invalid PR number", 1), + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + err := flagSet.Parse(tt.args) + require.NoError(t, err) + parentCliContext := cli.NewContext(nil, nil, nil) + parentCliContext.App = cli.NewApp() + cliContext := cli.NewContext(nil, flagSet, parentCliContext) + + got, err := parsePRArguments(cliContext) + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + require.Equal(t, err, tt.wantErr) + } + require.Equal(t, tt.want, got) + }) + } +}