diff --git a/bitbucket/bitbucket.go b/bitbucket/bitbucket.go index d3f6ff9..0f06097 100644 --- a/bitbucket/bitbucket.go +++ b/bitbucket/bitbucket.go @@ -17,6 +17,23 @@ var bitbucketAPI = "https://api.bitbucket.org/2.0" // setBitbucketAPI overrides the Bitbucket API base URL (for testing). func setBitbucketAPI(url string) { bitbucketAPI = url } +type bbCloneLink struct { + Href string `json:"href"` + Name string `json:"name"` // "https" or "ssh" +} + +func parseCloneURLs(links []bbCloneLink) (cloneURL, sshURL string) { + for _, link := range links { + switch link.Name { + case "https": + cloneURL = link.Href + case "ssh": + sshURL = link.Href + } + } + return +} + type bitbucketForge struct { token string httpClient *http.Client @@ -69,10 +86,7 @@ type bbRepository struct { Avatar struct { Href string `json:"href"` } `json:"avatar"` - Clone []struct { - Href string `json:"href"` - Name string `json:"name"` - } `json:"clone"` + Clone []bbCloneLink `json:"clone"` } `json:"links"` CreatedOn string `json:"created_on"` UpdatedOn string `json:"updated_on"` @@ -157,14 +171,7 @@ func convertBitbucketRepo(bb bbRepository) forge.Repository { LogoURL: bb.Links.Avatar.Href, } - for _, c := range bb.Links.Clone { - switch c.Name { - case "https": - result.CloneURL = c.Href - case "ssh": - result.SSHURL = c.Href - } - } + result.CloneURL, result.SSHURL = parseCloneURLs(bb.Links.Clone) if bb.Owner != nil { result.Owner = bb.Owner.Username diff --git a/bitbucket/bitbucket_test.go b/bitbucket/bitbucket_test.go index a95e6a3..b99f683 100644 --- a/bitbucket/bitbucket_test.go +++ b/bitbucket/bitbucket_test.go @@ -42,10 +42,7 @@ func TestBitbucketGetRepo(t *testing.T) { Avatar struct { Href string `json:"href"` } `json:"avatar"` - Clone []struct { - Href string `json:"href"` - Name string `json:"name"` - } `json:"clone"` + Clone []bbCloneLink `json:"clone"` }{ HTML: struct { Href string `json:"href"` diff --git a/bitbucket/prs.go b/bitbucket/prs.go index 7d8ae61..e32d289 100644 --- a/bitbucket/prs.go +++ b/bitbucket/prs.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" forge "github.com/git-pkgs/forge" @@ -21,22 +22,29 @@ func (f *bitbucketForge) PullRequests() forge.PullRequestService { return &bitbucketPRService{token: f.token, httpClient: f.httpClient} } +type bbPRBranch struct { + Branch struct { + Name string `json:"name"` + } `json:"branch"` + Commit *struct { + Hash string `json:"hash"` + } `json:"commit"` + Repository *struct { + FullName string `json:"full_name"` + Links struct { + Clone []bbCloneLink `json:"clone"` + } `json:"links"` + } `json:"repository"` +} + type bbPullRequest struct { - ID int `json:"id"` - Title string `json:"title"` - Description string `json:"description"` - State string `json:"state"` // OPEN, MERGED, DECLINED, SUPERSEDED - Source struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - } `json:"source"` - Destination struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - } `json:"destination"` - Author *struct { + ID int `json:"id"` + Title string `json:"title"` + Description string `json:"description"` + State string `json:"state"` // OPEN, MERGED, DECLINED, SUPERSEDED + Source bbPRBranch `json:"source"` + Destination bbPRBranch `json:"destination"` + Author *struct { Username string `json:"username"` DisplayName string `json:"display_name"` Links struct { @@ -81,13 +89,40 @@ func convertBitbucketPR(bb bbPullRequest) forge.PullRequest { Number: bb.ID, Title: bb.Title, Body: bb.Description, - Head: bb.Source.Branch.Name, - Base: bb.Destination.Branch.Name, + Head: forge.PRBranch{Ref: bb.Source.Branch.Name}, + Base: forge.PRBranch{Ref: bb.Destination.Branch.Name}, Comments: bb.CommentCount, HTMLURL: bb.Links.HTML.Href, DiffURL: bb.Links.Diff.Href, } + var destFullName string + if bb.Destination.Commit != nil { + result.Base.SHA = bb.Destination.Commit.Hash + } + if bb.Destination.Repository != nil { + destFullName = bb.Destination.Repository.FullName + } + + if bb.Source.Commit != nil { + result.Head.SHA = bb.Source.Commit.Hash + } + if bb.Source.Repository != nil && bb.Source.Repository.FullName != destFullName { + cloneURL, sshURL := parseCloneURLs(bb.Source.Repository.Links.Clone) + parts := strings.Split(bb.Source.Repository.FullName, "/") + var owner, name string + if len(parts) >= 2 { + owner = parts[0] + name = parts[1] + } + result.Head.Fork = &forge.ForkInfo{ + Owner: owner, + Name: name, + CloneURL: cloneURL, + SSHURL: sshURL, + } + } + switch bb.State { case "OPEN": result.State = "open" diff --git a/bitbucket/prs_test.go b/bitbucket/prs_test.go index b40c76b..bcdce32 100644 --- a/bitbucket/prs_test.go +++ b/bitbucket/prs_test.go @@ -17,18 +17,10 @@ func TestBitbucketGetPR(t *testing.T) { Title: "Add feature", Description: "New feature PR", State: "OPEN", - Source: struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - }{Branch: struct { + Source: bbPRBranch{Branch: struct { Name string `json:"name"` }{Name: "feature-branch"}}, - Destination: struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - }{Branch: struct { + Destination: bbPRBranch{Branch: struct { Name string `json:"name"` }{Name: "main"}}, Author: &struct { @@ -82,8 +74,8 @@ func TestBitbucketGetPR(t *testing.T) { assertEqual(t, "Title", "Add feature", pr.Title) assertEqual(t, "Body", "New feature PR", pr.Body) assertEqual(t, "State", "open", pr.State) - assertEqual(t, "Head", "feature-branch", pr.Head) - assertEqual(t, "Base", "main", pr.Base) + assertEqual(t, "Head", "feature-branch", pr.Head.Ref) + assertEqual(t, "Base", "main", pr.Base.Ref) assertEqual(t, "Author.Login", "author1", pr.Author.Login) assertEqualInt(t, "Comments", 3, pr.Comments) assertEqualBool(t, "Merged", false, pr.Merged) diff --git a/gitea/prs.go b/gitea/prs.go index ceeb335..a0920ae 100644 --- a/gitea/prs.go +++ b/gitea/prs.go @@ -48,11 +48,27 @@ func convertGiteaPR(pr *gitea.PullRequest) forge.PullRequest { result.State = stateOpen } - if pr.Head != nil { - result.Head = pr.Head.Name - } + var baseRepoID int64 if pr.Base != nil { - result.Base = pr.Base.Name + result.Base = forge.PRBranch{ + Ref: pr.Base.Ref, + SHA: pr.Base.Sha, + } + baseRepoID = pr.Base.RepoID + } + if pr.Head != nil { + result.Head = forge.PRBranch{ + Ref: pr.Head.Ref, + SHA: pr.Head.Sha, + } + if pr.Head.RepoID != baseRepoID && pr.Head.Repository != nil && pr.Head.Repository.Owner != nil { + result.Head.Fork = &forge.ForkInfo{ + Owner: pr.Head.Repository.Owner.UserName, + Name: pr.Head.Repository.Name, + CloneURL: pr.Head.Repository.CloneURL, + SSHURL: pr.Head.Repository.SSHURL, + } + } } if pr.Poster != nil { diff --git a/github/prs.go b/github/prs.go index 0e1be56..b4f97f3 100644 --- a/github/prs.go +++ b/github/prs.go @@ -81,11 +81,31 @@ func convertGitHubPR(pr *github.PullRequest) forge.PullRequest { } } - if h := pr.GetHead(); h != nil { - result.Head = h.GetRef() - } + var baseFullName string if b := pr.GetBase(); b != nil { - result.Base = b.GetRef() + result.Base = forge.PRBranch{ + Ref: b.GetRef(), + SHA: b.GetSHA(), + } + if repo := b.GetRepo(); repo != nil { + baseFullName = repo.GetFullName() + } + } + if h := pr.GetHead(); h != nil { + result.Head = forge.PRBranch{ + Ref: h.GetRef(), + SHA: h.GetSHA(), + } + if repo := h.GetRepo(); repo != nil { + if repo.GetFullName() != baseFullName { + result.Head.Fork = &forge.ForkInfo{ + Owner: repo.GetOwner().GetLogin(), + Name: repo.GetName(), + CloneURL: repo.GetCloneURL(), + SSHURL: repo.GetSSHURL(), + } + } + } } if u := pr.GetMergedBy(); u != nil { diff --git a/github/prs_test.go b/github/prs_test.go index 6444011..8ff6a49 100644 --- a/github/prs_test.go +++ b/github/prs_test.go @@ -69,8 +69,8 @@ func TestGitHubGetPR(t *testing.T) { assertEqualBool(t, "Draft", false, pr.Draft) assertEqualBool(t, "Merged", false, pr.Merged) assertEqualBool(t, "Mergeable", true, pr.Mergeable) - assertEqual(t, "Head", "feature-branch", pr.Head) - assertEqual(t, "Base", "main", pr.Base) + assertEqual(t, "Head", "feature-branch", pr.Head.Ref) + assertEqual(t, "Base", "main", pr.Base.Ref) assertEqual(t, "Author.Login", "octocat", pr.Author.Login) assertEqualInt(t, "Comments", 2, pr.Comments) assertEqualInt(t, "Additions", 10, pr.Additions) @@ -174,7 +174,7 @@ func TestGitHubListPRs(t *testing.T) { t.Fatalf("expected 2 PRs, got %d", len(prs)) } assertEqual(t, "prs[0].Title", "First PR", prs[0].Title) - assertEqual(t, "prs[0].Head", "feature-1", prs[0].Head) + assertEqual(t, "prs[0].Head", "feature-1", prs[0].Head.Ref) assertEqual(t, "prs[1].Title", "Second PR", prs[1].Title) } diff --git a/gitlab/prs.go b/gitlab/prs.go index bc26a03..6251ab9 100644 --- a/gitlab/prs.go +++ b/gitlab/prs.go @@ -2,10 +2,11 @@ package gitlab import ( "context" - forge "github.com/git-pkgs/forge" + "fmt" "net/http" "time" + forge "github.com/git-pkgs/forge" gitlab "gitlab.com/gitlab-org/api/client-go" ) @@ -28,8 +29,8 @@ func convertGitLabMR(mr *gitlab.MergeRequest) forge.PullRequest { Body: mr.Description, State: mr.State, // "opened", "closed", "merged" Draft: mr.Draft, - Head: mr.SourceBranch, - Base: mr.TargetBranch, + Head: forge.PRBranch{Ref: mr.SourceBranch, SHA: mr.SHA}, + Base: forge.PRBranch{Ref: mr.TargetBranch}, Merged: mr.State == "merged", Comments: int(mr.UserNotesCount), // ChangesCount is a string in the GitLab API @@ -117,8 +118,8 @@ func convertBasicGitLabMR(mr *gitlab.BasicMergeRequest) forge.PullRequest { Body: mr.Description, State: mr.State, Draft: mr.Draft, - Head: mr.SourceBranch, - Base: mr.TargetBranch, + Head: forge.PRBranch{Ref: mr.SourceBranch}, + Base: forge.PRBranch{Ref: mr.TargetBranch}, Merged: mr.State == "merged", HTMLURL: mr.WebURL, } @@ -184,6 +185,22 @@ func (s *gitLabPRService) Get(ctx context.Context, owner, repo string, number in return nil, err } result := convertGitLabMR(mr) + + if mr.SourceProjectID != mr.TargetProjectID { + sourceProject, _, err := s.client.Projects.GetProject(mr.SourceProjectID, nil) + if err != nil { + return nil, fmt.Errorf("getting source project: %w", err) + } + if sourceProject != nil { + result.Head.Fork = &forge.ForkInfo{ + Owner: sourceProject.Namespace.Path, + Name: sourceProject.Path, + CloneURL: sourceProject.HTTPURLToRepo, + SSHURL: sourceProject.SSHURLToRepo, + } + } + } + return &result, nil } diff --git a/internal/cli/pr.go b/internal/cli/pr.go index e17b501..ca22cf1 100644 --- a/internal/cli/pr.go +++ b/internal/cli/pr.go @@ -1,8 +1,10 @@ package cli import ( + "context" "fmt" "os" + "os/exec" "strconv" "strings" @@ -36,6 +38,7 @@ func init() { prCmd.AddCommand(prCommentCmd()) prCmd.AddCommand(prReactionsCmd()) prCmd.AddCommand(prReactCmd()) + prCmd.AddCommand(prCheckoutCmd()) } func prViewCmd() *cobra.Command { @@ -100,7 +103,7 @@ func printPRDetails(pr *forges.PullRequest) { _, _ = fmt.Fprintf(os.Stdout, "#%d %s\n", pr.Number, output.Sanitize(pr.Title)) _, _ = fmt.Fprintf(os.Stdout, "State: %s\n", pr.State) _, _ = fmt.Fprintf(os.Stdout, "Author: %s\n", output.Sanitize(pr.Author.Login)) - _, _ = fmt.Fprintf(os.Stdout, "Branch: %s -> %s\n", pr.Head, pr.Base) + _, _ = fmt.Fprintf(os.Stdout, "Branch: %s -> %s\n", pr.Head.Ref, pr.Base.Ref) if pr.Draft { _, _ = fmt.Fprintln(os.Stdout, "Draft: yes") @@ -208,7 +211,7 @@ func prListCmd() *cobra.Command { strconv.Itoa(pr.Number), title, output.Sanitize(pr.Author.Login), - pr.Head, + pr.Head.Ref, pr.UpdatedAt.Format("2006-01-02"), } } @@ -538,3 +541,169 @@ func prCommentCmd() *cobra.Command { cmd.Flags().StringVarP(&flagBody, "body", "b", "", "Comment body") return cmd } + +func prCheckoutCmd() *cobra.Command { + var ( + flagRemoteName string + flagBranch string + flagDetach bool + flagForce bool + ) + + cmd := &cobra.Command{ + Use: "checkout ", + Short: "Check out a pull request locally", + Long: `Check out a pull request's head branch locally. + +If the PR is from a fork, the fork repository is added as a remote +(named after the fork owner by default), and the branch is fetched +and checked out with upstream tracking configured. + +For same-repo PRs, the branch is fetched and checked out.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + number, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("invalid PR number: %s", args[0]) + } + + forge, owner, repoName, _, err := resolve.Repo(flagRepo, flagForgeType) + if err != nil { + return err + } + + ctx := cmd.Context() + + pr, err := forge.PullRequests().Get(ctx, owner, repoName, number) + if err != nil { + return fmt.Errorf("getting PR #%d: %w", number, err) + } + + // remoteRef is the branch name on the remote (PR's head branch) + remoteRef := pr.Head.Ref + + // localBranch is what we'll name the local branch (defaults to remote ref) + localBranch := remoteRef + if flagBranch != "" { + localBranch = flagBranch + } + + if pr.Head.Fork != nil { + return checkoutForkPR(ctx, pr, remoteRef, localBranch, flagRemoteName, flagDetach, flagForce) + } + + return checkoutSameRepoPR(ctx, remoteRef, localBranch, flagDetach, flagForce) + }, + } + + cmd.Flags().StringVar(&flagRemoteName, "remote-name", "", "Name for fork remote (default: fork owner)") + cmd.Flags().StringVarP(&flagBranch, "branch", "b", "", "Local branch name (default: same as remote)") + cmd.Flags().BoolVar(&flagDetach, "detach", false, "Checkout in detached HEAD mode") + cmd.Flags().BoolVarP(&flagForce, "force", "f", false, "Reset the local branch to the remote state even if it has diverged") + return cmd +} + +func checkoutForkPR(ctx context.Context, pr *forges.PullRequest, remoteRef, localBranch, flagRemoteName string, detach, force bool) error { + fork := pr.Head.Fork + remoteName := flagRemoteName + if remoteName == "" { + remoteName = fork.Owner + } + if remoteName == "" { + remoteName = "fork" + } + + cloneURL := fork.CloneURL + if cloneURL == "" { + cloneURL = fork.SSHURL + } + if cloneURL == "" { + return fmt.Errorf("no clone URL available for fork repository") + } + + remoteName, err := ensureRemote(ctx, remoteName, cloneURL) + if err != nil { + return err + } + + return gitCheckout(ctx, remoteName, remoteRef, localBranch, detach, force) +} + +func checkoutSameRepoPR(ctx context.Context, remoteRef, localBranch string, detach, force bool) error { + return gitCheckout(ctx, resolve.RemoteName(), remoteRef, localBranch, detach, force) +} + +func ensureRemote(ctx context.Context, preferredName, cloneURL string) (string, error) { + remotes, err := exec.CommandContext(ctx, "git", "remote", "-v").Output() + if err == nil { + for _, line := range strings.Split(string(remotes), "\n") { + fields := strings.Fields(line) + if len(fields) >= 2 && fields[1] == cloneURL { + return fields[0], nil + } + } + } + + existingURL, err := exec.CommandContext(ctx, "git", "remote", "get-url", preferredName).Output() + if err != nil { + addCmd := exec.CommandContext(ctx, "git", "remote", "add", "--", preferredName, cloneURL) + addCmd.Stdout = os.Stdout + addCmd.Stderr = os.Stderr + if err := addCmd.Run(); err != nil { + return "", fmt.Errorf("adding remote %s: %w", preferredName, err) + } + return preferredName, nil + } + + if strings.TrimSpace(string(existingURL)) == cloneURL { + return preferredName, nil + } + + return "", fmt.Errorf("remote %q already exists with a different URL; use --remote-name to specify a different name", preferredName) +} + +func gitCheckout(ctx context.Context, remote, remoteRef, localBranch string, detach, force bool) error { + refspec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s/%s", remoteRef, remote, remoteRef) + fetchCmd := exec.CommandContext(ctx, "git", "fetch", "--", remote, refspec) + fetchCmd.Stdout = os.Stdout + fetchCmd.Stderr = os.Stderr + if err := fetchCmd.Run(); err != nil { + return fmt.Errorf("fetching %s/%s: %w", remote, remoteRef, err) + } + + ref := remote + "/" + remoteRef + + if detach { + cmd := exec.CommandContext(ctx, "git", "checkout", "--detach", ref) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() + } + + // Try creating a new branch + if exec.CommandContext(ctx, "git", "checkout", "-b", localBranch, ref).Run() == nil { + return nil + } + + // Branch exists - switch to it and try to fast-forward + cmd := exec.CommandContext(ctx, "git", "checkout", localBranch) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("checking out %s: %w", localBranch, err) + } + + if exec.CommandContext(ctx, "git", "merge", "--ff-only", ref).Run() == nil { + return nil + } + + if !force { + return fmt.Errorf("local branch %q has diverged from %s; use --force to reset it", localBranch, ref) + } + + _, _ = fmt.Fprintf(os.Stderr, "warning: resetting %q to %s (local commits will be lost)\n", localBranch, ref) + resetCmd := exec.CommandContext(ctx, "git", "reset", "--hard", ref) + resetCmd.Stdout = os.Stdout + resetCmd.Stderr = os.Stderr + return resetCmd.Run() +} diff --git a/internal/cli/pr_checkout_test.go b/internal/cli/pr_checkout_test.go new file mode 100644 index 0000000..58272c0 --- /dev/null +++ b/internal/cli/pr_checkout_test.go @@ -0,0 +1,423 @@ +package cli + +import ( + "bytes" + "context" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + forges "github.com/git-pkgs/forge" + "github.com/git-pkgs/forge/internal/resolve" +) + +// mockPRService implements forges.PullRequestService for testing. +type mockPRService struct { + pr *forges.PullRequest + err error +} + +func (m *mockPRService) Get(_ context.Context, _, _ string, _ int) (*forges.PullRequest, error) { + return m.pr, m.err +} + +func (m *mockPRService) List(_ context.Context, _, _ string, _ forges.ListPROpts) ([]forges.PullRequest, error) { + return nil, nil +} + +func (m *mockPRService) Create(_ context.Context, _, _ string, _ forges.CreatePROpts) (*forges.PullRequest, error) { + return nil, nil +} + +func (m *mockPRService) Update(_ context.Context, _, _ string, _ int, _ forges.UpdatePROpts) (*forges.PullRequest, error) { + return nil, nil +} + +func (m *mockPRService) Close(_ context.Context, _, _ string, _ int) error { + return nil +} + +func (m *mockPRService) Reopen(_ context.Context, _, _ string, _ int) error { + return nil +} + +func (m *mockPRService) Merge(_ context.Context, _, _ string, _ int, _ forges.MergePROpts) error { + return nil +} + +func (m *mockPRService) Diff(_ context.Context, _, _ string, _ int) (string, error) { + return "", nil +} + +func (m *mockPRService) CreateComment(_ context.Context, _, _ string, _ int, _ string) (*forges.Comment, error) { + return nil, nil +} + +func (m *mockPRService) ListComments(_ context.Context, _, _ string, _ int) ([]forges.Comment, error) { + return nil, nil +} + +func (m *mockPRService) ListReactions(_ context.Context, _, _ string, _ int, _ int64) ([]forges.Reaction, error) { + return nil, nil +} + +func (m *mockPRService) AddReaction(_ context.Context, _, _ string, _ int, _ int64, _ string) (*forges.Reaction, error) { + return nil, nil +} + +func (m *mockPRService) ListURL(_ string) string { + return "" +} + +// mockForge implements forges.Forge for testing. +type mockForge struct { + prService *mockPRService +} + +func (m *mockForge) Repos() forges.RepoService { return nil } +func (m *mockForge) Issues() forges.IssueService { return nil } +func (m *mockForge) PullRequests() forges.PullRequestService { return m.prService } +func (m *mockForge) Labels() forges.LabelService { return nil } +func (m *mockForge) Milestones() forges.MilestoneService { return nil } +func (m *mockForge) Releases() forges.ReleaseService { return nil } +func (m *mockForge) CI() forges.CIService { return nil } +func (m *mockForge) Branches() forges.BranchService { return nil } +func (m *mockForge) DeployKeys() forges.DeployKeyService { return nil } +func (m *mockForge) Secrets() forges.SecretService { return nil } +func (m *mockForge) Notifications() forges.NotificationService { return nil } +func (m *mockForge) Reviews() forges.ReviewService { return nil } +func (m *mockForge) Files() forges.FileService { return nil } +func (m *mockForge) Collaborators() forges.CollaboratorService { return nil } +func (m *mockForge) CommitStatuses() forges.CommitStatusService { return nil } +func (m *mockForge) GetRateLimit(_ context.Context) (*forges.RateLimit, error) { + return nil, forges.ErrNotSupported +} + +// setupTestRepo creates a temporary git repository with an initial commit +// and an origin remote pointing to a fake URL. +func setupTestRepo(t *testing.T, originURL string) string { + t.Helper() + dir := t.TempDir() + + commands := [][]string{ + {"git", "init"}, + {"git", "config", "user.email", "test@test.com"}, + {"git", "config", "user.name", "Test User"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git command %v failed: %v\n%s", args, err, out) + } + } + + // Create an initial commit so we have a valid HEAD + testFile := filepath.Join(dir, "README.md") + if err := os.WriteFile(testFile, []byte("# Test\n"), 0644); err != nil { + t.Fatalf("writing test file: %v", err) + } + + commands = [][]string{ + {"git", "add", "README.md"}, + {"git", "commit", "-m", "Initial commit"}, + {"git", "remote", "add", "origin", originURL}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git command %v failed: %v\n%s", args, err, out) + } + } + + return dir +} + +// setupBareRepo creates a bare git repository that can be used as a remote. +func setupBareRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + + cmd := exec.Command("git", "init", "--bare") + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git init --bare failed: %v\n%s", err, out) + } + + return dir +} + +// pushBranchToRemote creates a branch and pushes it to a remote. +func pushBranchToRemote(t *testing.T, repoDir, remoteName, branchName string) { + t.Helper() + + // Create a file and commit on a new branch + testFile := filepath.Join(repoDir, branchName+".txt") + if err := os.WriteFile(testFile, []byte("content for "+branchName+"\n"), 0644); err != nil { + t.Fatalf("writing test file: %v", err) + } + + commands := [][]string{ + {"git", "checkout", "-b", branchName}, + {"git", "add", "."}, + {"git", "commit", "-m", "Add " + branchName}, + {"git", "push", remoteName, branchName}, + {"git", "checkout", "-"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = repoDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git command %v failed: %v\n%s", args, err, out) + } + } +} + +func TestPRCheckout(t *testing.T) { + tests := []struct { + name string + pr *forges.PullRequest + args []string + setupOrigin bool // whether to create and push to origin + setupFork bool // whether to create a fork remote + wantBranch string + wantRemote string // expected remote name for fork PRs + wantErr string + }{ + { + name: "same-repo PR checks out branch", + pr: &forges.PullRequest{ + Number: 42, + Head: forges.PRBranch{ + Ref: "feature-branch", + SHA: "abc123", + }, + }, + args: []string{"pr", "checkout", "42"}, + setupOrigin: true, + wantBranch: "feature-branch", + }, + { + name: "fork PR adds remote and checks out", + pr: &forges.PullRequest{ + Number: 42, + Head: forges.PRBranch{ + Ref: "feature", + SHA: "abc123", + Fork: &forges.ForkInfo{ + Owner: "contributor", + Name: "repo", + CloneURL: "FORK_URL_PLACEHOLDER", // will be replaced + }, + }, + }, + args: []string{"pr", "checkout", "42"}, + setupFork: true, + wantBranch: "feature", + wantRemote: "contributor", + }, + { + name: "fork PR with custom remote name", + pr: &forges.PullRequest{ + Number: 42, + Head: forges.PRBranch{ + Ref: "feature", + SHA: "abc123", + Fork: &forges.ForkInfo{ + Owner: "contributor", + Name: "repo", + CloneURL: "FORK_URL_PLACEHOLDER", + }, + }, + }, + args: []string{"pr", "checkout", "42", "--remote-name", "upstream"}, + setupFork: true, + wantBranch: "feature", + wantRemote: "upstream", + }, + { + name: "detach mode", + pr: &forges.PullRequest{ + Number: 42, + Head: forges.PRBranch{ + Ref: "feature-branch", + SHA: "abc123", + }, + }, + args: []string{"pr", "checkout", "42", "--detach"}, + setupOrigin: true, + wantBranch: "", // detached HEAD + }, + { + name: "checkout with custom branch name", + pr: &forges.PullRequest{ + Number: 42, + Head: forges.PRBranch{ + Ref: "feature-branch", + SHA: "abc123", + }, + }, + args: []string{"pr", "checkout", "42", "-b", "my-local-branch"}, + setupOrigin: true, + wantBranch: "my-local-branch", + }, + { + name: "invalid PR number", + args: []string{"pr", "checkout", "notanumber"}, + wantErr: "invalid PR number", + }, + { + name: "fork PR without clone URL", + pr: &forges.PullRequest{ + Number: 42, + Head: forges.PRBranch{ + Ref: "feature", + SHA: "abc123", + Fork: &forges.ForkInfo{ + Owner: "contributor", + Name: "repo", + // CloneURL and SSHURL both empty + }, + }, + }, + args: []string{"pr", "checkout", "42"}, + wantErr: "no clone URL available", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip tests that need real git operations in short mode + if testing.Short() && (tt.setupOrigin || tt.setupFork) { + t.Skip("skipping git integration test in short mode") + } + + // Reset flags to defaults before each test + // Find the checkout command and reset its flags + checkoutCmd, _, _ := rootCmd.Find([]string{"pr", "checkout"}) + if checkoutCmd != nil { + _ = checkoutCmd.Flags().Set("detach", "false") + _ = checkoutCmd.Flags().Set("force", "false") + _ = checkoutCmd.Flags().Set("branch", "") + _ = checkoutCmd.Flags().Set("remote-name", "") + } + + var workDir string + + // For git integration tests, set up repos + if tt.setupOrigin || tt.setupFork { + originDir := setupBareRepo(t) + workDir = setupTestRepo(t, originDir) + + if tt.setupOrigin { + branchName := tt.pr.Head.Ref + pushBranchToRemote(t, workDir, "origin", branchName) + } + + if tt.setupFork { + forkDir := setupBareRepo(t) + tt.pr.Head.Fork.CloneURL = forkDir + + branchName := tt.pr.Head.Ref + cmd := exec.Command("git", "remote", "add", "tempfork", forkDir) + cmd.Dir = workDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("adding temp fork remote: %v\n%s", err, out) + } + pushBranchToRemote(t, workDir, "tempfork", branchName) + cmd = exec.Command("git", "remote", "remove", "tempfork") + cmd.Dir = workDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("removing temp fork remote: %v\n%s", err, out) + } + } + } else if tt.pr != nil { + // For error tests that still need a git context, create a minimal repo + originDir := setupBareRepo(t) + workDir = setupTestRepo(t, originDir) + } + + // Change to work directory for the test + if workDir != "" { + t.Chdir(workDir) + } + + // Set up mock forge + if tt.pr != nil { + resolve.SetTestForge( + &mockForge{prService: &mockPRService{pr: tt.pr}}, + "testowner", "testrepo", "github.com", + ) + t.Cleanup(resolve.ResetTestForge) + } + + // Execute command + var buf bytes.Buffer + rootCmd.SetOut(&buf) + rootCmd.SetErr(&buf) + rootCmd.SetArgs(tt.args) + + err := rootCmd.Execute() + + // Check error + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, buf.String()) + } + + if workDir == "" { + return // no git state to verify + } + + // Verify branch + if tt.wantBranch != "" { + cmd := exec.Command("git", "branch", "--show-current") + cmd.Dir = workDir + out, err := cmd.Output() + if err != nil { + t.Fatalf("getting current branch: %v", err) + } + gotBranch := strings.TrimSpace(string(out)) + if gotBranch != tt.wantBranch { + t.Errorf("branch: want %q, got %q", tt.wantBranch, gotBranch) + } + } else { + // Detached HEAD - verify no branch + cmd := exec.Command("git", "branch", "--show-current") + cmd.Dir = workDir + out, _ := cmd.Output() + if strings.TrimSpace(string(out)) != "" { + t.Errorf("expected detached HEAD, but on branch %q", strings.TrimSpace(string(out))) + } + } + + // Verify remote for fork PRs + if tt.wantRemote != "" { + cmd := exec.Command("git", "remote", "-v") + cmd.Dir = workDir + out, err := cmd.Output() + if err != nil { + t.Fatalf("listing remotes: %v", err) + } + if !strings.Contains(string(out), tt.wantRemote) { + t.Errorf("expected remote %q in output:\n%s", tt.wantRemote, out) + } + } + }) + } +} diff --git a/internal/resolve/resolve.go b/internal/resolve/resolve.go index a5b0f0f..953262e 100644 --- a/internal/resolve/resolve.go +++ b/internal/resolve/resolve.go @@ -20,6 +20,13 @@ var ( remoteName = "origin" hostOverride string forgeTypeOverride string + + // testForge allows tests to inject a mock forge. When set, Repo() returns + // this forge directly without network or git resolution. + testForge forges.Forge + testOwner string + testRepo string + testDomain string ) // SetRemote sets which git remote to read when resolving the current @@ -32,6 +39,12 @@ func SetRemote(name string) { } } +// RemoteName returns the name of the git remote being used for resolution. +// This is "origin" by default, or whatever was set via SetRemote. +func RemoteName() string { + return remoteName +} + // SetHost forces a specific forge domain, taking precedence over FORGE_HOST, // --forge-type, and git remote detection. The CLI calls this from the --host // persistent flag. An empty string is ignored. @@ -50,6 +63,23 @@ func SetForgeType(forgeType string) { } } +// SetTestForge configures a mock forge for testing. When set, Repo() returns +// this forge directly without network or git resolution. +func SetTestForge(forge forges.Forge, owner, repo, domain string) { + testForge = forge + testOwner = owner + testRepo = repo + testDomain = domain +} + +// ResetTestForge clears the test forge configuration. +func ResetTestForge() { + testForge = nil + testOwner = "" + testRepo = "" + testDomain = "" +} + var builders = forges.ForgeBuilders{ GitHub: ghforge.NewWithBase, GitLab: glforge.New, @@ -60,6 +90,9 @@ var builders = forges.ForgeBuilders{ // git remote. The -R flag takes precedence; otherwise we read the "origin" // remote URL and parse it. func Repo(flagRepo, flagForgeType string) (forge forges.Forge, owner, repo, domain string, err error) { + if testForge != nil { + return testForge, testOwner, testRepo, testDomain, nil + } if flagRepo != "" { return repoFromFlag(flagRepo, flagForgeType) } diff --git a/types.go b/types.go index b7e0738..bc04c18 100644 --- a/types.go +++ b/types.go @@ -229,6 +229,22 @@ type UpdateIssueOpts struct { Milestone *string } +// ForkInfo holds minimal repository info needed for PR checkout from forks. +type ForkInfo struct { + Owner string `json:"owner"` + Name string `json:"name,omitempty"` + CloneURL string `json:"clone_url,omitempty"` + SSHURL string `json:"ssh_url,omitempty"` +} + +// PRBranch holds branch info including the repository it belongs to. +// For same-repo PRs, Fork is nil. For fork PRs, Fork points to the source repo. +type PRBranch struct { + Ref string `json:"ref"` // branch name + SHA string `json:"sha,omitempty"` // commit SHA + Fork *ForkInfo `json:"fork,omitempty"` // nil if same repo as target +} + // PullRequest holds normalized metadata about a pull request (or merge request). type PullRequest struct { Number int `json:"number"` @@ -241,8 +257,8 @@ type PullRequest struct { Reviewers []User `json:"reviewers,omitempty"` Labels []Label `json:"labels,omitempty"` Milestone *Milestone `json:"milestone,omitempty"` - Head string `json:"head"` // head branch - Base string `json:"base"` // base branch + Head PRBranch `json:"head"` + Base PRBranch `json:"base"` Mergeable bool `json:"mergeable"` Merged bool `json:"merged"` MergedBy *User `json:"merged_by,omitempty"`