diff --git a/cmd/checkout.go b/cmd/checkout.go index fab8188..0ceed6a 100644 --- a/cmd/checkout.go +++ b/cmd/checkout.go @@ -482,12 +482,9 @@ func importRemoteStack( } // Ensure trunk exists locally - if !git.BranchExists(trunk) { - remoteTrunk := remote + "/" + trunk - if err := git.CreateBranch(trunk, remoteTrunk); err != nil { - cfg.Errorf("could not create trunk branch %s from %s: %v", trunk, remoteTrunk, err) - return nil, ErrSilent - } + if err := ensureLocalTrunk(cfg, trunk, remote); err != nil { + cfg.Errorf("%s", err) + return nil, ErrSilent } // Create local branches for each PR's head branch. diff --git a/cmd/modify.go b/cmd/modify.go index 8da6e0a..3d992bb 100644 --- a/cmd/modify.go +++ b/cmd/modify.go @@ -1,6 +1,7 @@ package cmd import ( + "errors" "fmt" "strings" @@ -297,6 +298,22 @@ func checkModifyPreconditions(cfg *config.Config) (*loadStackResult, error) { return nil, ErrSilent } + // Ensure trunk branch exists locally (it may be absent if the user + // renamed their initial branch before starting the stack). + if !git.BranchExists(s.Trunk.Branch) { + remote, err := pickRemote(cfg, result.CurrentBranch, "") + if err != nil { + if !errors.Is(err, errInterrupt) { + cfg.Errorf("failed to resolve remote: %s", err) + } + return nil, ErrSilent + } + if err := ensureLocalTrunk(cfg, s.Trunk.Branch, remote); err != nil { + cfg.Errorf("%s", err) + return nil, ErrSilent + } + } + // Show loading indicator while syncing PRs fmt.Fprintf(cfg.Err, "Loading stack...") diff --git a/cmd/rebase.go b/cmd/rebase.go index c18d20a..f1acea5 100644 --- a/cmd/rebase.go +++ b/cmd/rebase.go @@ -130,6 +130,12 @@ func runRebase(cfg *config.Config, opts *rebaseOptions) error { cfg.Successf("Fetched %s", remote) } + // Ensure trunk exists locally before fast-forward or cascade rebase. + if err := ensureLocalTrunk(cfg, s.Trunk.Branch, remote); err != nil { + cfg.Errorf("%s", err) + return ErrSilent + } + // Fast-forward trunk so the cascade rebase targets the latest upstream. fastForwardTrunk(cfg, s.Trunk.Branch, remote, currentBranch) diff --git a/cmd/trunk.go b/cmd/trunk.go index 90c66b2..4e6fc63 100644 --- a/cmd/trunk.go +++ b/cmd/trunk.go @@ -42,6 +42,21 @@ func runTrunk(cfg *config.Config) error { return nil } + // Ensure trunk exists locally before checkout. + if !git.BranchExists(trunk) { + remote, err := pickRemote(cfg, currentBranch, "") + if err != nil { + if !errors.Is(err, errInterrupt) { + cfg.Errorf("failed to resolve remote: %s", err) + } + return ErrSilent + } + if err := ensureLocalTrunk(cfg, trunk, remote); err != nil { + cfg.Errorf("%s", err) + return ErrSilent + } + } + if err := git.CheckoutBranch(trunk); err != nil { return err } diff --git a/cmd/trunk_test.go b/cmd/trunk_test.go index 5851bed..13228da 100644 --- a/cmd/trunk_test.go +++ b/cmd/trunk_test.go @@ -212,3 +212,53 @@ func TestTrunk_RejectsArgs(t *testing.T) { assert.Error(t, err, "should reject positional arguments") } + +func TestTrunk_MissingLocallyCreatedFromRemote(t *testing.T) { + s := stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{{Branch: "b1"}, {Branch: "b2"}}, + } + + var checkedOut []string + var createdBranch string + tmpDir := t.TempDir() + writeStackFile(t, tmpDir, s) + + mock := &git.MockOps{ + GitDirFn: func() (string, error) { return tmpDir, nil }, + CurrentBranchFn: func() (string, error) { return "b1", nil }, + BranchExistsFn: func(name string) bool { + // trunk does not exist locally + return name != "main" + }, + ResolveRemoteFn: func(branch string) (string, error) { + return "origin", nil + }, + FetchBranchesFn: func(remote string, branches []string) error { + return nil + }, + CreateBranchFn: func(name, base string) error { + createdBranch = name + return nil + }, + CheckoutBranchFn: func(name string) error { + checkedOut = append(checkedOut, name) + return nil + }, + } + restore := git.SetOps(mock) + defer restore() + + cfg, outR, errR := config.NewTestConfig() + cmd := TrunkCmd(cfg) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + output := readCfgOutput(cfg, outR, errR) + + assert.NoError(t, err) + assert.Equal(t, "main", createdBranch, "should create trunk from remote") + assert.Equal(t, []string{"main"}, checkedOut) + assert.Contains(t, output, "Created local trunk branch main from origin/main") +} diff --git a/cmd/utils.go b/cmd/utils.go index 77c2a96..1d9c71e 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -693,11 +693,34 @@ func resolveOriginalRefs(s *stack.Stack) (map[string]string, error) { return originalRefs, nil } +// ensureLocalTrunk ensures the trunk branch exists locally. If it does not, +// it fetches the branch from the remote and creates a local tracking branch. +// This handles the case where a user started their stack after renaming their +// initial branch (e.g. `git branch -m newbranch`), leaving no local trunk. +func ensureLocalTrunk(cfg *config.Config, trunk, remote string) error { + if git.BranchExists(trunk) { + return nil + } + + if err := git.FetchBranches(remote, []string{trunk}); err != nil { + return fmt.Errorf("could not fetch trunk branch %s from %s: %w", trunk, remote, err) + } + + remoteTrunk := remote + "/" + trunk + if err := git.CreateBranch(trunk, remoteTrunk); err != nil { + return fmt.Errorf("could not create local trunk branch %s from %s: %w", trunk, remoteTrunk, err) + } + + cfg.Successf("Created local trunk branch %s from %s", trunk, remoteTrunk) + return nil +} + // fastForwardTrunk fast-forwards the trunk branch to match its remote tracking // branch. Returns true if trunk was updated. func fastForwardTrunk(cfg *config.Config, trunk, remote, currentBranch string) bool { // If the local trunk branch doesn't exist, there's nothing to - // fast-forward. The remote tracking ref is sufficient for rebasing. + // fast-forward. Callers should use ensureLocalTrunk beforehand if + // they need trunk to be resolvable as a local ref. if !git.BranchExists(trunk) { return false } diff --git a/cmd/utils_test.go b/cmd/utils_test.go index 412f7e1..49ce5d9 100644 --- a/cmd/utils_test.go +++ b/cmd/utils_test.go @@ -725,3 +725,88 @@ func TestWarnStacksUnavailableOrPAT_ShowsNotEnabledForOAuth(t *testing.T) { assert.Contains(t, output, "Stacked PRs are not enabled for this repository") assert.NotContains(t, output, "Personal access tokens") } + +func TestEnsureLocalTrunk_AlreadyExists(t *testing.T) { + mock := &git.MockOps{ + BranchExistsFn: func(name string) bool { + return name == "main" + }, + } + restore := git.SetOps(mock) + defer restore() + + cfg, _, _ := config.NewTestConfig() + err := ensureLocalTrunk(cfg, "main", "origin") + assert.NoError(t, err) +} + +func TestEnsureLocalTrunk_FetchesAndCreates(t *testing.T) { + var fetchedBranches []string + var createdBranch, createdBase string + + mock := &git.MockOps{ + BranchExistsFn: func(name string) bool { + return false + }, + FetchBranchesFn: func(remote string, branches []string) error { + fetchedBranches = branches + return nil + }, + CreateBranchFn: func(name, base string) error { + createdBranch = name + createdBase = base + return nil + }, + } + restore := git.SetOps(mock) + defer restore() + + cfg, _, _ := config.NewTestConfig() + err := ensureLocalTrunk(cfg, "main", "origin") + + assert.NoError(t, err) + assert.Equal(t, []string{"main"}, fetchedBranches) + assert.Equal(t, "main", createdBranch) + assert.Equal(t, "origin/main", createdBase) +} + +func TestEnsureLocalTrunk_FetchFails(t *testing.T) { + mock := &git.MockOps{ + BranchExistsFn: func(name string) bool { + return false + }, + FetchBranchesFn: func(remote string, branches []string) error { + return fmt.Errorf("network error") + }, + } + restore := git.SetOps(mock) + defer restore() + + cfg, _, _ := config.NewTestConfig() + err := ensureLocalTrunk(cfg, "main", "origin") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not fetch trunk branch main from origin") +} + +func TestEnsureLocalTrunk_CreateFails(t *testing.T) { + mock := &git.MockOps{ + BranchExistsFn: func(name string) bool { + return false + }, + FetchBranchesFn: func(remote string, branches []string) error { + return nil + }, + CreateBranchFn: func(name, base string) error { + return fmt.Errorf("ref not found") + }, + } + restore := git.SetOps(mock) + defer restore() + + cfg, _, _ := config.NewTestConfig() + err := ensureLocalTrunk(cfg, "main", "origin") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not create local trunk branch main") +}