Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions cmd/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ the new PRs (existing PRs are never removed).`,

// resolvedArg holds the result of resolving a single CLI argument to a PR.
type resolvedArg struct {
branch string // head branch name
prNumber int // PR number
prURL string // PR URL (for display)
created bool // true if we created this PR (skip base-fix re-fetch)
branch string // head branch name
prNumber int // PR number
prURL string // PR URL (for display)
created bool // true if we created this PR (skip base-fix re-fetch)
pr *github.PullRequest // full PR data (only set for existing PRs)
}

func runLink(cfg *config.Config, opts *linkOptions, args []string) error {
Expand All @@ -98,6 +99,12 @@ func runLink(cfg *config.Config, opts *linkOptions, args []string) error {
return err
}

// Phase 2b: Validate that all found PRs are eligible to be added to a stack.
// Only open/draft PRs without auto-merge enabled are allowed.
if err := validatePREligibility(cfg, found); err != nil {
return err
}

// Phase 3: Pre-validate the stack — check that adding these PRs won't
// conflict with existing stacks before creating any new PRs.
// Also fetches stacks for reuse in the upsert phase.
Expand Down Expand Up @@ -238,6 +245,7 @@ func findExistingPR(cfg *config.Config, client github.ClientOps, arg string) (*r
branch: pr.HeadRefName,
prNumber: pr.Number,
prURL: pr.URL,
pr: pr,
}, nil
}
// PR doesn't exist — fall through to branch name lookup
Expand All @@ -255,12 +263,47 @@ func findExistingPR(cfg *config.Config, client github.ClientOps, arg string) (*r
branch: arg,
prNumber: pr.Number,
prURL: pr.URL,
pr: pr,
}, nil
}

return nil, nil // needs PR creation
}

// validatePREligibility checks that all found PRs are eligible to be added
// to a stack. Only open or draft PRs without auto-merge enabled are allowed.
// Merged, closed, queued, and auto-merge-enabled PRs are rejected.
// Reports all invalid PRs at once before returning.
func validatePREligibility(cfg *config.Config, found []*resolvedArg) error {
invalid := 0
for _, r := range found {
if r == nil || r.pr == nil {
continue
}
pr := r.pr
reason := ""
switch {
case pr.State == "MERGED":
reason = "it has been merged"
case pr.State == "CLOSED":
reason = "it is closed"
case pr.IsQueued():
reason = "it is queued for merge"
case pr.IsAutoMergeEnabled():
reason = "it has auto-merge enabled"
}
if reason != "" {
cfg.Errorf("PR %s cannot be added to a stack: %s",
cfg.PRLink(r.prNumber, r.prURL), reason)
invalid++
}
}
if invalid > 0 {
return ErrInvalidArgs
}
return nil
}

// listStacksSafe fetches all stacks, handling the 404 "not enabled" case.
func listStacksSafe(cfg *config.Config, client github.ClientOps) ([]github.RemoteStack, error) {
stacks, err := client.ListStacks()
Expand Down
259 changes: 259 additions & 0 deletions cmd/link_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,265 @@ func TestLink_Create422(t *testing.T) {
assert.Contains(t, output, "must form a stack")
}

// --- PR eligibility tests ---

func TestLink_RejectsMergedPR(t *testing.T) {
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(n int) (*github.PullRequest, error) {
return &github.PullRequest{
Number: n,
State: "MERGED",
HeadRefName: fmt.Sprintf("branch-%d", n),
BaseRefName: "main",
URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n),
Merged: true,
}, nil
},
}

cmd := LinkCmd(cfg)
cmd.SetArgs([]string{"10", "20"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)
output := string(errOut)

assert.ErrorIs(t, err, ErrInvalidArgs)
assert.Contains(t, output, "cannot be added to a stack")
assert.Contains(t, output, "merged")
}

func TestLink_RejectsClosedPR(t *testing.T) {
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(n int) (*github.PullRequest, error) {
return &github.PullRequest{
Number: n,
State: "CLOSED",
HeadRefName: fmt.Sprintf("branch-%d", n),
BaseRefName: "main",
URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n),
}, nil
},
}

cmd := LinkCmd(cfg)
cmd.SetArgs([]string{"10", "20"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)
output := string(errOut)

assert.ErrorIs(t, err, ErrInvalidArgs)
assert.Contains(t, output, "cannot be added to a stack")
assert.Contains(t, output, "closed")
}

func TestLink_RejectsQueuedPR(t *testing.T) {
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(n int) (*github.PullRequest, error) {
pr := &github.PullRequest{
Number: n,
State: "OPEN",
HeadRefName: fmt.Sprintf("branch-%d", n),
BaseRefName: "main",
URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n),
}
if n == 20 {
pr.MergeQueueEntry = &github.MergeQueueEntry{ID: "MQE_123"}
}
return pr, nil
},
ListStacksFn: func() ([]github.RemoteStack, error) {
return []github.RemoteStack{}, nil
},
CreateStackFn: func([]int) (int, error) {
t.Fatal("CreateStack should not be called for ineligible PRs")
return 0, nil
},
}

cmd := LinkCmd(cfg)
cmd.SetArgs([]string{"10", "20"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)
output := string(errOut)

assert.ErrorIs(t, err, ErrInvalidArgs)
assert.Contains(t, output, "cannot be added to a stack")
assert.Contains(t, output, "queued for merge")
}

func TestLink_RejectsAutoMergeEnabledPR(t *testing.T) {
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(n int) (*github.PullRequest, error) {
pr := &github.PullRequest{
Number: n,
State: "OPEN",
HeadRefName: fmt.Sprintf("branch-%d", n),
BaseRefName: "main",
URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n),
}
if n == 10 {
pr.AutoMergeRequest = &github.AutoMergeRequest{EnabledAt: "2024-01-01T00:00:00Z"}
}
return pr, nil
},
ListStacksFn: func() ([]github.RemoteStack, error) {
return []github.RemoteStack{}, nil
},
CreateStackFn: func([]int) (int, error) {
t.Fatal("CreateStack should not be called for ineligible PRs")
return 0, nil
},
}

cmd := LinkCmd(cfg)
cmd.SetArgs([]string{"10", "20"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)
output := string(errOut)

assert.ErrorIs(t, err, ErrInvalidArgs)
assert.Contains(t, output, "cannot be added to a stack")
assert.Contains(t, output, "auto-merge")
}

func TestLink_RejectsQueuedPR_ByBranch(t *testing.T) {
restore := git.SetOps(newLinkGitMock("feature-a", "feature-b"))
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRForBranchFn: func(branch string) (*github.PullRequest, error) {
pr := &github.PullRequest{
Number: 10,
HeadRefName: branch,
BaseRefName: "main",
URL: "https://github.com/o/r/pull/10",
}
if branch == "feature-b" {
pr.Number = 20
pr.URL = "https://github.com/o/r/pull/20"
pr.MergeQueueEntry = &github.MergeQueueEntry{ID: "MQE_456"}
}
return pr, nil
},
}

cmd := LinkCmd(cfg)
cmd.SetArgs([]string{"feature-a", "feature-b"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)
output := string(errOut)

assert.ErrorIs(t, err, ErrInvalidArgs)
assert.Contains(t, output, "cannot be added to a stack")
assert.Contains(t, output, "queued for merge")
}

func TestLink_RejectsAutoMergePR_ByBranch(t *testing.T) {
restore := git.SetOps(newLinkGitMock("feature-a", "feature-b"))
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRForBranchFn: func(branch string) (*github.PullRequest, error) {
if branch == "feature-a" {
return &github.PullRequest{
Number: 10,
HeadRefName: branch,
BaseRefName: "main",
URL: "https://github.com/o/r/pull/10",
AutoMergeRequest: &github.AutoMergeRequest{EnabledAt: "2024-01-01T00:00:00Z"},
}, nil
}
return &github.PullRequest{
Number: 20,
HeadRefName: branch,
BaseRefName: "main",
URL: "https://github.com/o/r/pull/20",
}, nil
},
}

cmd := LinkCmd(cfg)
cmd.SetArgs([]string{"feature-a", "feature-b"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)
output := string(errOut)

assert.ErrorIs(t, err, ErrInvalidArgs)
assert.Contains(t, output, "cannot be added to a stack")
assert.Contains(t, output, "auto-merge")
}

func TestLink_ReportsMultipleIneligiblePRs(t *testing.T) {
cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRByNumberFn: func(n int) (*github.PullRequest, error) {
pr := &github.PullRequest{
Number: n,
State: "OPEN",
HeadRefName: fmt.Sprintf("branch-%d", n),
BaseRefName: "main",
URL: fmt.Sprintf("https://github.com/o/r/pull/%d", n),
}
switch n {
case 10:
pr.State = "MERGED"
pr.Merged = true
case 20:
pr.MergeQueueEntry = &github.MergeQueueEntry{ID: "MQE_789"}
case 30:
pr.AutoMergeRequest = &github.AutoMergeRequest{EnabledAt: "2024-01-01T00:00:00Z"}
}
return pr, nil
},
}

cmd := LinkCmd(cfg)
cmd.SetArgs([]string{"10", "20", "30"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)
output := string(errOut)

assert.ErrorIs(t, err, ErrInvalidArgs)
// All three invalid PRs should be reported
assert.Contains(t, output, "merged")
assert.Contains(t, output, "queued for merge")
assert.Contains(t, output, "auto-merge")
}

// --- Branch name tests ---

func TestLink_BranchNames_AllHavePRs(t *testing.T) {
Expand Down
Loading
Loading