Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions cmd/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ var (
syncAbort bool
syncCherryPick bool
syncCrossWorktree bool
syncAll bool
// stdinReader allows tests to inject mock input for prompts
stdinReader io.Reader = os.Stdin
)
Expand Down Expand Up @@ -67,6 +68,9 @@ Stack branches are always pushed to 'origin'.`,
Example: ` # Sync all branches and update PRs
stack sync

# Sync the full stack including children below the current branch
stack sync --all

# Sync fetching base branches from upstream (fork workflow)
stack sync upstream

Expand Down Expand Up @@ -132,6 +136,7 @@ func init() {
syncCmd.Flags().BoolVarP(&syncAbort, "abort", "a", false, "Abort an interrupted sync and clean up state")
syncCmd.Flags().BoolVar(&syncCherryPick, "cherry-pick", false, "Rebuild polluted branches by cherry-picking unique commits (creates backup)")
syncCmd.Flags().BoolVar(&syncCrossWorktree, "cross-worktree", false, "Also sync branches checked out in other worktrees (uses git -C)")
syncCmd.Flags().BoolVar(&syncAll, "all", false, "Sync the full stack including children below the current branch")
}

func runSync(gitClient git.GitClient, githubClient github.GitHubClient, syncRemote string) error {
Expand Down Expand Up @@ -350,15 +355,14 @@ func runSync(gitClient git.GitClient, githubClient github.GitHubClient, syncRemo
}
}()

// While network operations run in background, do local work
// Get only branches in the current branch's stack
// While network operations run in background, determine which branches to sync
// Get the ancestor chain for the current branch
chain, err := stack.GetStackChain(gitClient, originalBranch)
if err != nil {
return fmt.Errorf("failed to get stack chain: %w", err)
}

if len(chain) == 0 {
// Wait for parallel operations before returning
wg.Wait()
if fetchErr != nil {
return fmt.Errorf("failed to fetch: %w", fetchErr)
Expand All @@ -374,21 +378,41 @@ func runSync(gitClient git.GitClient, githubClient github.GitHubClient, syncRemo
return nil
}

// Build set of branches in current stack
chainSet := make(map[string]bool)
for _, b := range chain {
chainSet[b] = true
// Determine which branches to sync
var branchSet map[string]bool

if syncAll {
// Sync the full tree: find the root stack branch and get all descendants
root := chain[0]
if root == baseBranch && len(chain) > 1 {
root = chain[1]
}

descendants, err := stack.GetDescendants(gitClient, root)
if err != nil {
return fmt.Errorf("failed to get stack tree: %w", err)
}

branchSet = make(map[string]bool)
for _, d := range descendants {
branchSet[d] = true
}
} else {
// Default: sync only the ancestor chain up to the current branch
branchSet = make(map[string]bool)
for _, b := range chain {
branchSet[b] = true
}
}

// Get all stack branches and filter to current stack only
allStackBranches, err := stack.GetStackBranches(gitClient)
if err != nil {
return fmt.Errorf("failed to get stack branches: %w", err)
}

var stackBranches []stack.StackBranch
for _, b := range allStackBranches {
if chainSet[b.Name] {
if branchSet[b.Name] {
stackBranches = append(stackBranches, b)
}
}
Expand All @@ -400,32 +424,28 @@ func runSync(gitClient git.GitClient, githubClient github.GitHubClient, syncRemo
existingBranchNames[b.Name] = true
}

// Walk the chain and add missing branches
for i, branchName := range chain {
if branchName == baseBranch {
continue // Skip base branch
continue
}
if existingBranchNames[branchName] {
continue // Already in stackBranches
continue
}

// Infer parent from chain (previous branch in the chain)
var inferredParent string
if i > 0 {
inferredParent = chain[i-1]
} else {
inferredParent = baseBranch
}

// Check if branch exists locally before adding
if gitClient.BranchExists(branchName) {
stackBranches = append(stackBranches, stack.StackBranch{
Name: branchName,
Parent: inferredParent,
})
existingBranchNames[branchName] = true

// Configure stackparent so future syncs work correctly
configKey := fmt.Sprintf("branch.%s.stackparent", branchName)
if err := gitClient.SetConfig(configKey, inferredParent); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to set stackparent for %s: %v\n", branchName, err)
Expand Down
8 changes: 8 additions & 0 deletions decision-log.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

Architectural and design decisions for Stackinator.

## 2026-05-20 — Add `--all` flag to `stack sync`

**Decision**: Allow `stack sync --all` to sync the full stack, not just the ancestor chain.

**Context**: The default `stack sync` only processes the ancestor chain from the base branch up to the current branch. Children below the current branch are not synced. Users had to check out a leaf branch or run sync multiple times to propagate changes through the full stack.

**Resolution**: Added `--all` flag that finds the root of the current branch's stack (first stack branch in the ancestor chain), then uses `GetDescendants` (BFS through children map) to collect all branches in the stack. The per-branch processing loop is unchanged — topological sort ensures correct ordering, and each branch rebases onto its configured parent.

## 2026-05-07 — Add `--cross-worktree` flag to `stack sync`

**Decision**: Allow `stack sync` to rebase branches checked out in other worktrees via `git -C <path>`, gated behind `--cross-worktree`.
Expand Down
30 changes: 30 additions & 0 deletions internal/stack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,36 @@ func GetChildrenOf(gitClient git.GitClient, branch string) ([]StackBranch, error
return children, nil
}

// GetDescendants returns the specified branch and all its descendants in the stack
func GetDescendants(gitClient git.GitClient, branch string) ([]string, error) {
allBranches, err := GetStackBranches(gitClient)
if err != nil {
return nil, err
}

childrenMap := make(map[string][]string)
for _, b := range allBranches {
childrenMap[b.Parent] = append(childrenMap[b.Parent], b.Name)
}

var result []string
queue := []string{branch}
visited := make(map[string]bool)

for len(queue) > 0 {
current := queue[0]
queue = queue[1:]
if visited[current] {
continue
}
visited[current] = true
result = append(result, current)
queue = append(queue, childrenMap[current]...)
}

return result, nil
}

// GetStackChain returns the chain from the base to the specified branch
func GetStackChain(gitClient git.GitClient, branch string) ([]string, error) {
// Get all parents at once for efficiency
Expand Down
75 changes: 75 additions & 0 deletions internal/stack/stack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,81 @@ func TestGetStackBranches(t *testing.T) {
}
}

func TestGetDescendants(t *testing.T) {
testutil.SetupTest()
defer testutil.TeardownTest()

tests := []struct {
name string
branch string
stackParents map[string]string
expectedDescendants []string
expectError bool
}{
{
name: "branch with children and grandchildren",
branch: "feature-a",
stackParents: map[string]string{
"feature-a": "main",
"feature-b": "feature-a",
"feature-c": "feature-b",
"other": "main",
},
expectedDescendants: []string{"feature-a", "feature-b", "feature-c"},
expectError: false,
},
{
name: "leaf branch with no children",
branch: "feature-c",
stackParents: map[string]string{
"feature-a": "main",
"feature-b": "feature-a",
"feature-c": "feature-b",
},
expectedDescendants: []string{"feature-c"},
expectError: false,
},
{
name: "branch with multiple children",
branch: "feature-a",
stackParents: map[string]string{
"feature-a": "main",
"feature-b": "feature-a",
"feature-c": "feature-a",
},
expectedDescendants: []string{"feature-a", "feature-b", "feature-c"},
expectError: false,
},
{
name: "branch not in any stack",
branch: "untracked",
stackParents: map[string]string{
"feature-a": "main",
},
expectedDescendants: []string{"untracked"},
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockGit := new(testutil.MockGitClient)
mockGit.On("GetAllStackParents").Return(tt.stackParents, nil)

descendants, err := GetDescendants(mockGit, tt.branch)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.ElementsMatch(t, tt.expectedDescendants, descendants)
}

mockGit.AssertExpectations(t)
})
}
}

func TestGetStackChain(t *testing.T) {
testutil.SetupTest()
defer testutil.TeardownTest()
Expand Down
Loading