Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions cmd/checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,12 +482,9 @@ func importRemoteStack(
}

// Ensure trunk exists locally
if !git.BranchExists(trunk) {
remoteTrunk := remote + "/" + trunk
if err := git.CreateBranch(trunk, remoteTrunk); err != nil {
cfg.Errorf("could not create trunk branch %s from %s: %v", trunk, remoteTrunk, err)
return nil, ErrSilent
}
if err := ensureLocalTrunk(cfg, trunk, remote); err != nil {
cfg.Errorf("%s", err)
return nil, ErrSilent
}

// Create local branches for each PR's head branch.
Expand Down
17 changes: 17 additions & 0 deletions cmd/modify.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -297,6 +298,22 @@ func checkModifyPreconditions(cfg *config.Config) (*loadStackResult, error) {
return nil, ErrSilent
}

// Ensure trunk branch exists locally (it may be absent if the user
// renamed their initial branch before starting the stack).
if !git.BranchExists(s.Trunk.Branch) {
remote, err := pickRemote(cfg, result.CurrentBranch, "")
if err != nil {
if !errors.Is(err, errInterrupt) {
cfg.Errorf("failed to resolve remote: %s", err)
}
return nil, ErrSilent
}
if err := ensureLocalTrunk(cfg, s.Trunk.Branch, remote); err != nil {
cfg.Errorf("%s", err)
return nil, ErrSilent
}
}
Comment thread
Copilot marked this conversation as resolved.

// Show loading indicator while syncing PRs
fmt.Fprintf(cfg.Err, "Loading stack...")

Expand Down
6 changes: 6 additions & 0 deletions cmd/rebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ func runRebase(cfg *config.Config, opts *rebaseOptions) error {
cfg.Successf("Fetched %s", remote)
}

// Ensure trunk exists locally before fast-forward or cascade rebase.
if err := ensureLocalTrunk(cfg, s.Trunk.Branch, remote); err != nil {
cfg.Errorf("%s", err)
return ErrSilent
}

// Fast-forward trunk so the cascade rebase targets the latest upstream.
fastForwardTrunk(cfg, s.Trunk.Branch, remote, currentBranch)

Expand Down
15 changes: 15 additions & 0 deletions cmd/trunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ func runTrunk(cfg *config.Config) error {
return nil
}

// Ensure trunk exists locally before checkout.
if !git.BranchExists(trunk) {
remote, err := pickRemote(cfg, currentBranch, "")
if err != nil {
if !errors.Is(err, errInterrupt) {
cfg.Errorf("failed to resolve remote: %s", err)
}
return ErrSilent
}
if err := ensureLocalTrunk(cfg, trunk, remote); err != nil {
cfg.Errorf("%s", err)
return ErrSilent
}
}

if err := git.CheckoutBranch(trunk); err != nil {
return err
}
Expand Down
50 changes: 50 additions & 0 deletions cmd/trunk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,53 @@ func TestTrunk_RejectsArgs(t *testing.T) {

assert.Error(t, err, "should reject positional arguments")
}

func TestTrunk_MissingLocallyCreatedFromRemote(t *testing.T) {
s := stack.Stack{
Trunk: stack.BranchRef{Branch: "main"},
Branches: []stack.BranchRef{{Branch: "b1"}, {Branch: "b2"}},
}

var checkedOut []string
var createdBranch string
tmpDir := t.TempDir()
writeStackFile(t, tmpDir, s)

mock := &git.MockOps{
GitDirFn: func() (string, error) { return tmpDir, nil },
CurrentBranchFn: func() (string, error) { return "b1", nil },
BranchExistsFn: func(name string) bool {
// trunk does not exist locally
return name != "main"
},
ResolveRemoteFn: func(branch string) (string, error) {
return "origin", nil
},
FetchBranchesFn: func(remote string, branches []string) error {
return nil
},
CreateBranchFn: func(name, base string) error {
createdBranch = name
return nil
},
CheckoutBranchFn: func(name string) error {
checkedOut = append(checkedOut, name)
return nil
},
}
restore := git.SetOps(mock)
defer restore()

cfg, outR, errR := config.NewTestConfig()
cmd := TrunkCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

output := readCfgOutput(cfg, outR, errR)

assert.NoError(t, err)
assert.Equal(t, "main", createdBranch, "should create trunk from remote")
assert.Equal(t, []string{"main"}, checkedOut)
assert.Contains(t, output, "Created local trunk branch main from origin/main")
}
25 changes: 24 additions & 1 deletion cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,11 +693,34 @@ func resolveOriginalRefs(s *stack.Stack) (map[string]string, error) {
return originalRefs, nil
}

// ensureLocalTrunk ensures the trunk branch exists locally. If it does not,
// it fetches the branch from the remote and creates a local tracking branch.
// This handles the case where a user started their stack after renaming their
// initial branch (e.g. `git branch -m newbranch`), leaving no local trunk.
func ensureLocalTrunk(cfg *config.Config, trunk, remote string) error {
if git.BranchExists(trunk) {
return nil
}

if err := git.FetchBranches(remote, []string{trunk}); err != nil {
return fmt.Errorf("could not fetch trunk branch %s from %s: %w", trunk, remote, err)
}

remoteTrunk := remote + "/" + trunk
if err := git.CreateBranch(trunk, remoteTrunk); err != nil {
return fmt.Errorf("could not create local trunk branch %s from %s: %w", trunk, remoteTrunk, err)
}

cfg.Successf("Created local trunk branch %s from %s", trunk, remoteTrunk)
return nil
}

// fastForwardTrunk fast-forwards the trunk branch to match its remote tracking
// branch. Returns true if trunk was updated.
func fastForwardTrunk(cfg *config.Config, trunk, remote, currentBranch string) bool {
// If the local trunk branch doesn't exist, there's nothing to
// fast-forward. The remote tracking ref is sufficient for rebasing.
// fast-forward. Callers should use ensureLocalTrunk beforehand if
// they need trunk to be resolvable as a local ref.
if !git.BranchExists(trunk) {
return false
}
Expand Down
85 changes: 85 additions & 0 deletions cmd/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,88 @@ func TestWarnStacksUnavailableOrPAT_ShowsNotEnabledForOAuth(t *testing.T) {
assert.Contains(t, output, "Stacked PRs are not enabled for this repository")
assert.NotContains(t, output, "Personal access tokens")
}

func TestEnsureLocalTrunk_AlreadyExists(t *testing.T) {
mock := &git.MockOps{
BranchExistsFn: func(name string) bool {
return name == "main"
},
}
restore := git.SetOps(mock)
defer restore()

cfg, _, _ := config.NewTestConfig()
err := ensureLocalTrunk(cfg, "main", "origin")
assert.NoError(t, err)
}

func TestEnsureLocalTrunk_FetchesAndCreates(t *testing.T) {
var fetchedBranches []string
var createdBranch, createdBase string

mock := &git.MockOps{
BranchExistsFn: func(name string) bool {
return false
},
FetchBranchesFn: func(remote string, branches []string) error {
fetchedBranches = branches
return nil
},
CreateBranchFn: func(name, base string) error {
createdBranch = name
createdBase = base
return nil
},
}
restore := git.SetOps(mock)
defer restore()

cfg, _, _ := config.NewTestConfig()
err := ensureLocalTrunk(cfg, "main", "origin")

assert.NoError(t, err)
assert.Equal(t, []string{"main"}, fetchedBranches)
assert.Equal(t, "main", createdBranch)
assert.Equal(t, "origin/main", createdBase)
}

func TestEnsureLocalTrunk_FetchFails(t *testing.T) {
mock := &git.MockOps{
BranchExistsFn: func(name string) bool {
return false
},
FetchBranchesFn: func(remote string, branches []string) error {
return fmt.Errorf("network error")
},
}
restore := git.SetOps(mock)
defer restore()

cfg, _, _ := config.NewTestConfig()
err := ensureLocalTrunk(cfg, "main", "origin")

assert.Error(t, err)
assert.Contains(t, err.Error(), "could not fetch trunk branch main from origin")
}

func TestEnsureLocalTrunk_CreateFails(t *testing.T) {
mock := &git.MockOps{
BranchExistsFn: func(name string) bool {
return false
},
FetchBranchesFn: func(remote string, branches []string) error {
return nil
},
CreateBranchFn: func(name, base string) error {
return fmt.Errorf("ref not found")
},
}
restore := git.SetOps(mock)
defer restore()

cfg, _, _ := config.NewTestConfig()
err := ensureLocalTrunk(cfg, "main", "origin")

assert.Error(t, err)
assert.Contains(t, err.Error(), "could not create local trunk branch main")
}