diff --git a/cmd/link.go b/cmd/link.go index 1367e51..9740580 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -68,10 +68,11 @@ the new PRs (existing PRs are never removed).`, // resolvedArg holds the result of resolving a single CLI argument to a PR. type resolvedArg struct { - branch string // head branch name - prNumber int // PR number - prURL string // PR URL (for display) - created bool // true if we created this PR (skip base-fix re-fetch) + branch string // head branch name + prNumber int // PR number + prURL string // PR URL (for display) + created bool // true if we created this PR (skip base-fix re-fetch) + pr *github.PullRequest // full PR data (only set for existing PRs) } func runLink(cfg *config.Config, opts *linkOptions, args []string) error { @@ -98,6 +99,12 @@ func runLink(cfg *config.Config, opts *linkOptions, args []string) error { return err } + // Phase 2b: Validate that all found PRs are eligible to be added to a stack. + // Only open/draft PRs without auto-merge enabled are allowed. + if err := validatePREligibility(cfg, found); err != nil { + return err + } + // Phase 3: Pre-validate the stack — check that adding these PRs won't // conflict with existing stacks before creating any new PRs. // Also fetches stacks for reuse in the upsert phase. @@ -238,6 +245,7 @@ func findExistingPR(cfg *config.Config, client github.ClientOps, arg string) (*r branch: pr.HeadRefName, prNumber: pr.Number, prURL: pr.URL, + pr: pr, }, nil } // PR doesn't exist — fall through to branch name lookup @@ -255,12 +263,47 @@ func findExistingPR(cfg *config.Config, client github.ClientOps, arg string) (*r branch: arg, prNumber: pr.Number, prURL: pr.URL, + pr: pr, }, nil } return nil, nil // needs PR creation } +// validatePREligibility checks that all found PRs are eligible to be added +// to a stack. Only open or draft PRs without auto-merge enabled are allowed. +// Merged, closed, queued, and auto-merge-enabled PRs are rejected. +// Reports all invalid PRs at once before returning. +func validatePREligibility(cfg *config.Config, found []*resolvedArg) error { + invalid := 0 + for _, r := range found { + if r == nil || r.pr == nil { + continue + } + pr := r.pr + reason := "" + switch { + case pr.State == "MERGED": + reason = "it has been merged" + case pr.State == "CLOSED": + reason = "it is closed" + case pr.IsQueued(): + reason = "it is queued for merge" + case pr.IsAutoMergeEnabled(): + reason = "it has auto-merge enabled" + } + if reason != "" { + cfg.Errorf("PR %s cannot be added to a stack: %s", + cfg.PRLink(r.prNumber, r.prURL), reason) + invalid++ + } + } + if invalid > 0 { + return ErrInvalidArgs + } + return nil +} + // listStacksSafe fetches all stacks, handling the 404 "not enabled" case. func listStacksSafe(cfg *config.Config, client github.ClientOps) ([]github.RemoteStack, error) { stacks, err := client.ListStacks() diff --git a/cmd/link_test.go b/cmd/link_test.go index 4514de2..80f3970 100644 --- a/cmd/link_test.go +++ b/cmd/link_test.go @@ -298,6 +298,265 @@ func TestLink_Create422(t *testing.T) { assert.Contains(t, output, "must form a stack") } +// --- PR eligibility tests --- + +func TestLink_RejectsMergedPR(t *testing.T) { + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(n int) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: n, + State: "MERGED", + HeadRefName: fmt.Sprintf("branch-%d", n), + BaseRefName: "main", + URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n), + Merged: true, + }, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"10", "20"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + output := string(errOut) + + assert.ErrorIs(t, err, ErrInvalidArgs) + assert.Contains(t, output, "cannot be added to a stack") + assert.Contains(t, output, "merged") +} + +func TestLink_RejectsClosedPR(t *testing.T) { + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(n int) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: n, + State: "CLOSED", + HeadRefName: fmt.Sprintf("branch-%d", n), + BaseRefName: "main", + URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n), + }, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"10", "20"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + output := string(errOut) + + assert.ErrorIs(t, err, ErrInvalidArgs) + assert.Contains(t, output, "cannot be added to a stack") + assert.Contains(t, output, "closed") +} + +func TestLink_RejectsQueuedPR(t *testing.T) { + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(n int) (*github.PullRequest, error) { + pr := &github.PullRequest{ + Number: n, + State: "OPEN", + HeadRefName: fmt.Sprintf("branch-%d", n), + BaseRefName: "main", + URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n), + } + if n == 20 { + pr.MergeQueueEntry = &github.MergeQueueEntry{ID: "MQE_123"} + } + return pr, nil + }, + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{}, nil + }, + CreateStackFn: func([]int) (int, error) { + t.Fatal("CreateStack should not be called for ineligible PRs") + return 0, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"10", "20"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + output := string(errOut) + + assert.ErrorIs(t, err, ErrInvalidArgs) + assert.Contains(t, output, "cannot be added to a stack") + assert.Contains(t, output, "queued for merge") +} + +func TestLink_RejectsAutoMergeEnabledPR(t *testing.T) { + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(n int) (*github.PullRequest, error) { + pr := &github.PullRequest{ + Number: n, + State: "OPEN", + HeadRefName: fmt.Sprintf("branch-%d", n), + BaseRefName: "main", + URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n), + } + if n == 10 { + pr.AutoMergeRequest = &github.AutoMergeRequest{EnabledAt: "2024-01-01T00:00:00Z"} + } + return pr, nil + }, + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{}, nil + }, + CreateStackFn: func([]int) (int, error) { + t.Fatal("CreateStack should not be called for ineligible PRs") + return 0, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"10", "20"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + output := string(errOut) + + assert.ErrorIs(t, err, ErrInvalidArgs) + assert.Contains(t, output, "cannot be added to a stack") + assert.Contains(t, output, "auto-merge") +} + +func TestLink_RejectsQueuedPR_ByBranch(t *testing.T) { + restore := git.SetOps(newLinkGitMock("feature-a", "feature-b")) + defer restore() + + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + pr := &github.PullRequest{ + Number: 10, + HeadRefName: branch, + BaseRefName: "main", + URL: "https://github.com/o/r/pull/10", + } + if branch == "feature-b" { + pr.Number = 20 + pr.URL = "https://github.com/o/r/pull/20" + pr.MergeQueueEntry = &github.MergeQueueEntry{ID: "MQE_456"} + } + return pr, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"feature-a", "feature-b"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + output := string(errOut) + + assert.ErrorIs(t, err, ErrInvalidArgs) + assert.Contains(t, output, "cannot be added to a stack") + assert.Contains(t, output, "queued for merge") +} + +func TestLink_RejectsAutoMergePR_ByBranch(t *testing.T) { + restore := git.SetOps(newLinkGitMock("feature-a", "feature-b")) + defer restore() + + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + if branch == "feature-a" { + return &github.PullRequest{ + Number: 10, + HeadRefName: branch, + BaseRefName: "main", + URL: "https://github.com/o/r/pull/10", + AutoMergeRequest: &github.AutoMergeRequest{EnabledAt: "2024-01-01T00:00:00Z"}, + }, nil + } + return &github.PullRequest{ + Number: 20, + HeadRefName: branch, + BaseRefName: "main", + URL: "https://github.com/o/r/pull/20", + }, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"feature-a", "feature-b"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + output := string(errOut) + + assert.ErrorIs(t, err, ErrInvalidArgs) + assert.Contains(t, output, "cannot be added to a stack") + assert.Contains(t, output, "auto-merge") +} + +func TestLink_ReportsMultipleIneligiblePRs(t *testing.T) { + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(n int) (*github.PullRequest, error) { + pr := &github.PullRequest{ + Number: n, + State: "OPEN", + HeadRefName: fmt.Sprintf("branch-%d", n), + BaseRefName: "main", + URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n), + } + switch n { + case 10: + pr.State = "MERGED" + pr.Merged = true + case 20: + pr.MergeQueueEntry = &github.MergeQueueEntry{ID: "MQE_789"} + case 30: + pr.AutoMergeRequest = &github.AutoMergeRequest{EnabledAt: "2024-01-01T00:00:00Z"} + } + return pr, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"10", "20", "30"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + output := string(errOut) + + assert.ErrorIs(t, err, ErrInvalidArgs) + // All three invalid PRs should be reported + assert.Contains(t, output, "merged") + assert.Contains(t, output, "queued for merge") + assert.Contains(t, output, "auto-merge") +} + // --- Branch name tests --- func TestLink_BranchNames_AllHavePRs(t *testing.T) { diff --git a/internal/github/github.go b/internal/github/github.go index 4df668e..7e3784d 100644 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -16,17 +16,25 @@ type MergeQueueEntry struct { ID string `graphql:"id"` } +// AutoMergeRequest represents an auto-merge configuration on a PR. +// When the GraphQL field autoMergeRequest is null (auto-merge not enabled), +// the pointer will be nil. +type AutoMergeRequest struct { + EnabledAt string `graphql:"enabledAt"` +} + // PullRequest represents a GitHub pull request. type PullRequest struct { - ID string `graphql:"id"` - Number int `graphql:"number"` - State string `graphql:"state"` - URL string `graphql:"url"` - HeadRefName string `graphql:"headRefName"` - BaseRefName string `graphql:"baseRefName"` - IsDraft bool `graphql:"isDraft"` - Merged bool `graphql:"merged"` - MergeQueueEntry *MergeQueueEntry `graphql:"mergeQueueEntry"` + ID string `graphql:"id"` + Number int `graphql:"number"` + State string `graphql:"state"` + URL string `graphql:"url"` + HeadRefName string `graphql:"headRefName"` + BaseRefName string `graphql:"baseRefName"` + IsDraft bool `graphql:"isDraft"` + Merged bool `graphql:"merged"` + MergeQueueEntry *MergeQueueEntry `graphql:"mergeQueueEntry"` + AutoMergeRequest *AutoMergeRequest `graphql:"autoMergeRequest"` } // IsQueued reports whether the pull request is currently in a merge queue. @@ -34,6 +42,11 @@ func (pr *PullRequest) IsQueued() bool { return pr != nil && pr.MergeQueueEntry != nil && pr.MergeQueueEntry.ID != "" } +// IsAutoMergeEnabled reports whether the pull request has auto-merge enabled. +func (pr *PullRequest) IsAutoMergeEnabled() bool { + return pr != nil && pr.AutoMergeRequest != nil +} + // Client wraps GitHub API operations. type Client struct { gql *api.GraphQLClient @@ -85,12 +98,13 @@ func (c *Client) FindPRForBranch(branch string) (*PullRequest, error) { Repository struct { PullRequests struct { Nodes []struct { - ID string `graphql:"id"` - Number int `graphql:"number"` - URL string `graphql:"url"` - BaseRefName string `graphql:"baseRefName"` - IsDraft bool `graphql:"isDraft"` - MergeQueueEntry *MergeQueueEntry `graphql:"mergeQueueEntry"` + ID string `graphql:"id"` + Number int `graphql:"number"` + URL string `graphql:"url"` + BaseRefName string `graphql:"baseRefName"` + IsDraft bool `graphql:"isDraft"` + MergeQueueEntry *MergeQueueEntry `graphql:"mergeQueueEntry"` + AutoMergeRequest *AutoMergeRequest `graphql:"autoMergeRequest"` } } `graphql:"pullRequests(headRefName: $head, states: [OPEN], first: 1)"` } `graphql:"repository(owner: $owner, name: $name)"` @@ -113,12 +127,13 @@ func (c *Client) FindPRForBranch(branch string) (*PullRequest, error) { n := nodes[0] return &PullRequest{ - ID: n.ID, - Number: n.Number, - URL: n.URL, - BaseRefName: n.BaseRefName, - IsDraft: n.IsDraft, - MergeQueueEntry: n.MergeQueueEntry, + ID: n.ID, + Number: n.Number, + URL: n.URL, + BaseRefName: n.BaseRefName, + IsDraft: n.IsDraft, + MergeQueueEntry: n.MergeQueueEntry, + AutoMergeRequest: n.AutoMergeRequest, }, nil } @@ -296,15 +311,16 @@ func (c *Client) FindPRByNumber(number int) (*PullRequest, error) { var query struct { Repository struct { PullRequest struct { - ID string `graphql:"id"` - Number int `graphql:"number"` - State string `graphql:"state"` - URL string `graphql:"url"` - HeadRefName string `graphql:"headRefName"` - BaseRefName string `graphql:"baseRefName"` - IsDraft bool `graphql:"isDraft"` - Merged bool `graphql:"merged"` - MergeQueueEntry *MergeQueueEntry `graphql:"mergeQueueEntry"` + ID string `graphql:"id"` + Number int `graphql:"number"` + State string `graphql:"state"` + URL string `graphql:"url"` + HeadRefName string `graphql:"headRefName"` + BaseRefName string `graphql:"baseRefName"` + IsDraft bool `graphql:"isDraft"` + Merged bool `graphql:"merged"` + MergeQueueEntry *MergeQueueEntry `graphql:"mergeQueueEntry"` + AutoMergeRequest *AutoMergeRequest `graphql:"autoMergeRequest"` } `graphql:"pullRequest(number: $number)"` } `graphql:"repository(owner: $owner, name: $name)"` } @@ -324,15 +340,16 @@ func (c *Client) FindPRByNumber(number int) (*PullRequest, error) { return nil, nil } return &PullRequest{ - ID: n.ID, - Number: n.Number, - State: n.State, - URL: n.URL, - HeadRefName: n.HeadRefName, - BaseRefName: n.BaseRefName, - IsDraft: n.IsDraft, - Merged: n.Merged, - MergeQueueEntry: n.MergeQueueEntry, + ID: n.ID, + Number: n.Number, + State: n.State, + URL: n.URL, + HeadRefName: n.HeadRefName, + BaseRefName: n.BaseRefName, + IsDraft: n.IsDraft, + Merged: n.Merged, + MergeQueueEntry: n.MergeQueueEntry, + AutoMergeRequest: n.AutoMergeRequest, }, nil } diff --git a/internal/github/github_test.go b/internal/github/github_test.go index 29814bc..a7e7822 100644 --- a/internal/github/github_test.go +++ b/internal/github/github_test.go @@ -48,6 +48,26 @@ func TestPullRequest_IsQueued(t *testing.T) { }) } +func TestPullRequest_IsAutoMergeEnabled(t *testing.T) { + t.Run("not enabled when AutoMergeRequest is nil", func(t *testing.T) { + pr := &PullRequest{Number: 1} + assert.False(t, pr.IsAutoMergeEnabled()) + }) + + t.Run("enabled when AutoMergeRequest is present", func(t *testing.T) { + pr := &PullRequest{ + Number: 1, + AutoMergeRequest: &AutoMergeRequest{EnabledAt: "2024-01-01T00:00:00Z"}, + } + assert.True(t, pr.IsAutoMergeEnabled()) + }) + + t.Run("nil receiver is safe", func(t *testing.T) { + var pr *PullRequest + assert.False(t, pr.IsAutoMergeEnabled()) + }) +} + func TestToGraphQLInt(t *testing.T) { t.Run("in range", func(t *testing.T) { got, err := toGraphQLInt(123)