diff --git a/cmd/push.go b/cmd/push.go index 2d6d088..8b35c88 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -2,9 +2,7 @@ package cmd import ( "errors" - "fmt" - "github.com/cli/go-gh/v2/pkg/prompter" "github.com/github/gh-stack/internal/config" "github.com/github/gh-stack/internal/git" "github.com/github/gh-stack/internal/modify" @@ -138,40 +136,3 @@ func runPush(cfg *config.Config, opts *pushOptions) error { } return nil } - -// pickRemote determines which remote to push to. If remoteOverride is -// non-empty, it is returned directly. Otherwise it delegates to -// git.ResolveRemote for config-based resolution and remote listing. -// If multiple remotes exist with no configured default, the user is -// prompted to select one interactively. -func pickRemote(cfg *config.Config, branch, remoteOverride string) (string, error) { - if remoteOverride != "" { - return remoteOverride, nil - } - - remote, err := git.ResolveRemote(branch) - if err == nil { - return remote, nil - } - - var multi *git.ErrMultipleRemotes - if !errors.As(err, &multi) { - return "", err - } - - if !cfg.IsInteractive() { - return "", fmt.Errorf("multiple remotes configured; set remote.pushDefault or use an interactive terminal") - } - - p := prompter.New(cfg.In, cfg.Out, cfg.Err) - selected, promptErr := p.Select("Multiple remotes found. Which remote should be used?", "", multi.Remotes) - if promptErr != nil { - if isInterruptError(promptErr) { - clearSelectPrompt(cfg, len(multi.Remotes)) - printInterrupt(cfg) - return "", errInterrupt - } - return "", fmt.Errorf("remote selection: %w", promptErr) - } - return multi.Remotes[selected], nil -} diff --git a/cmd/push_test.go b/cmd/push_test.go index d44da1a..059b003 100644 --- a/cmd/push_test.go +++ b/cmd/push_test.go @@ -301,3 +301,106 @@ func TestPush_DoesNotCreatePRs(t *testing.T) { assert.NoError(t, err) assert.False(t, createPRCalled, "push should not create PRs") } + +func TestPickRemote_SavesWhenConfirmed(t *testing.T) { + savedRemote := "" + restore := git.SetOps(&git.MockOps{ + ResolveRemoteFn: func(string) (string, error) { + return "", &git.ErrMultipleRemotes{Remotes: []string{"origin", "upstream"}} + }, + SaveRemoteFn: func(r string) error { + savedRemote = r + return nil + }, + }) + defer restore() + + cfg, outR, errR := config.NewTestConfig() + cfg.ForceInteractive = true + cfg.SelectFn = func(prompt, defaultValue string, options []string) (int, error) { + return 1, nil // select "upstream" + } + cfg.ConfirmFn = func(prompt string, defaultValue bool) (bool, error) { + assert.Contains(t, prompt, "upstream") + assert.True(t, defaultValue) + return true, nil + } + + remote, err := pickRemote(cfg, "my-branch", "") + output := collectOutput(cfg, outR, errR) + + assert.NoError(t, err) + assert.Equal(t, "upstream", remote) + assert.Equal(t, "upstream", savedRemote) + assert.Contains(t, output, "Saved") + assert.Contains(t, output, "git config gh-stack.remote") + assert.Contains(t, output, "git config --unset gh-stack.remote") +} + +func TestPickRemote_SkipsSaveWhenDeclined(t *testing.T) { + saveCalled := false + restore := git.SetOps(&git.MockOps{ + ResolveRemoteFn: func(string) (string, error) { + return "", &git.ErrMultipleRemotes{Remotes: []string{"origin", "upstream"}} + }, + SaveRemoteFn: func(string) error { + saveCalled = true + return nil + }, + }) + defer restore() + + cfg, outR, errR := config.NewTestConfig() + cfg.ForceInteractive = true + cfg.SelectFn = func(prompt, defaultValue string, options []string) (int, error) { + return 0, nil // select "origin" + } + cfg.ConfirmFn = func(prompt string, defaultValue bool) (bool, error) { + return false, nil + } + + remote, err := pickRemote(cfg, "my-branch", "") + output := collectOutput(cfg, outR, errR) + + assert.NoError(t, err) + assert.Equal(t, "origin", remote) + assert.False(t, saveCalled, "SaveRemote should not be called when user declines") + assert.NotContains(t, output, "Saved") +} + +func TestPickRemote_SkipsPromptWhenSingleRemote(t *testing.T) { + restore := git.SetOps(&git.MockOps{ + ResolveRemoteFn: func(string) (string, error) { + return "origin", nil + }, + }) + defer restore() + + cfg, outR, errR := config.NewTestConfig() + + remote, err := pickRemote(cfg, "my-branch", "") + collectOutput(cfg, outR, errR) + + assert.NoError(t, err) + assert.Equal(t, "origin", remote) +} + +func TestPickRemote_OverrideTakesPrecedence(t *testing.T) { + resolveCalled := false + restore := git.SetOps(&git.MockOps{ + ResolveRemoteFn: func(string) (string, error) { + resolveCalled = true + return "", fmt.Errorf("should not be called") + }, + }) + defer restore() + + cfg, outR, errR := config.NewTestConfig() + + remote, err := pickRemote(cfg, "my-branch", "custom") + collectOutput(cfg, outR, errR) + + assert.NoError(t, err) + assert.Equal(t, "custom", remote) + assert.False(t, resolveCalled, "ResolveRemote should not be called when override is provided") +} diff --git a/cmd/utils.go b/cmd/utils.go index 1d9c71e..2cdd816 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -1033,3 +1033,90 @@ func ensureRerere(cfg *config.Config) error { } return nil } + +// pickRemote determines which remote to use. If remoteOverride is +// non-empty, it is returned directly. Otherwise it delegates to +// git.ResolveRemote for config-based resolution and remote listing. +// If multiple remotes exist with no configured default, the user is +// prompted to select one interactively and offered the option to save +// the choice via gh-stack.remote git config. +func pickRemote(cfg *config.Config, branch, remoteOverride string) (string, error) { + if remoteOverride != "" { + return remoteOverride, nil + } + + remote, err := git.ResolveRemote(branch) + if err == nil { + return remote, nil + } + + var multi *git.ErrMultipleRemotes + if !errors.As(err, &multi) { + return "", err + } + + if !cfg.IsInteractive() { + return "", fmt.Errorf("multiple remotes configured; set remote.pushDefault or use an interactive terminal") + } + + p := prompter.New(cfg.In, cfg.Out, cfg.Err) + selectFn := func(prompt, def string, opts []string) (int, error) { + if cfg.SelectFn != nil { + return cfg.SelectFn(prompt, def, opts) + } + return p.Select(prompt, def, opts) + } + + selected, promptErr := selectFn("Multiple remotes found. Which remote should be used?", "", multi.Remotes) + if promptErr != nil { + if isInterruptError(promptErr) { + if cfg.SelectFn == nil { + clearSelectPrompt(cfg, len(multi.Remotes)) + } + printInterrupt(cfg) + return "", errInterrupt + } + return "", fmt.Errorf("remote selection: %w", promptErr) + } + selectedRemote := multi.Remotes[selected] + + // Offer to save the selected remote for future operations. + save, confirmErr := confirmSaveRemote(cfg, selectedRemote) + if confirmErr != nil { + if errors.Is(confirmErr, errInterrupt) { + return "", errInterrupt + } + // Non-fatal: proceed with the selected remote even if the prompt fails. + return selectedRemote, nil + } + if save { + if saveErr := git.SaveRemote(selectedRemote); saveErr == nil { + cfg.Successf("Saved %q as the default remote for gh stack", selectedRemote) + cfg.Printf("To change later, run: %s", cfg.ColorCyan("git config gh-stack.remote ")) + cfg.Printf("To clear, run: %s", cfg.ColorCyan("git config --unset gh-stack.remote")) + } else { + cfg.Warningf("Could not save remote preference: %v", saveErr) + } + } + + return selectedRemote, nil +} + +// confirmSaveRemote asks the user whether to persist the selected remote +// for all future gh stack operations. Returns errInterrupt on Ctrl+C. +func confirmSaveRemote(cfg *config.Config, remote string) (bool, error) { + prompt := fmt.Sprintf("Save %q as the default remote for all gh stack operations?", remote) + if cfg.ConfirmFn != nil { + return cfg.ConfirmFn(prompt, true) + } + p := prompter.New(cfg.In, cfg.Out, cfg.Err) + ok, err := p.Confirm(prompt, true) + if err != nil { + if isInterruptError(err) { + printInterrupt(cfg) + return false, errInterrupt + } + return false, err + } + return ok, nil +} diff --git a/internal/git/git.go b/internal/git/git.go index d03f9c1..a848559 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -190,6 +190,22 @@ func SaveRerereDeclined() error { return ops.SaveRerereDeclined() } +// GetSavedRemote returns the remote saved via gh-stack.remote git config, +// or an error if none is configured. +func GetSavedRemote() (string, error) { + return ops.GetSavedRemote() +} + +// SaveRemote persists the given remote name to gh-stack.remote git config. +func SaveRemote(remote string) error { + return ops.SaveRemote(remote) +} + +// ClearRemote removes the gh-stack.remote git config entry. +func ClearRemote() error { + return ops.ClearRemote() +} + // RebaseOnto rebases a branch using the three-argument form: // // git rebase --onto diff --git a/internal/git/gitops.go b/internal/git/gitops.go index 214766f..bd4fe3e 100644 --- a/internal/git/gitops.go +++ b/internal/git/gitops.go @@ -37,6 +37,9 @@ type Ops interface { IsRerereEnabled() (bool, error) IsRerereDeclined() (bool, error) SaveRerereDeclined() error + GetSavedRemote() (string, error) + SaveRemote(remote string) error + ClearRemote() error RebaseOnto(newBase, oldBase, branch string, opts RebaseOpts) error RebaseContinue(opts RebaseOpts) error RebaseAbort() error @@ -197,9 +200,10 @@ func (d *defaultOps) Push(remote string, branches []string, force, atomic bool) // ResolveRemote determines the remote for pushing a branch. It checks git // config keys in priority order (branch..pushRemote, remote.pushDefault, -// branch..remote), then falls back to listing all remotes. If exactly -// one remote exists it is returned. If multiple exist, ErrMultipleRemotes is -// returned with the list attached. If none exist, a plain error is returned. +// branch..remote), then checks the gh-stack.remote saved preference, +// then falls back to listing all remotes. If exactly one remote exists it is +// returned. If multiple exist, ErrMultipleRemotes is returned with the list +// attached. If none exist, a plain error is returned. func (d *defaultOps) ResolveRemote(branch string) (string, error) { candidates := []string{ "branch." + branch + ".pushRemote", @@ -213,6 +217,11 @@ func (d *defaultOps) ResolveRemote(branch string) (string, error) { } } + // Check gh-stack saved remote preference. + if saved, err := d.GetSavedRemote(); err == nil && saved != "" { + return saved, nil + } + out, err := run("remote") if err != nil { return "", fmt.Errorf("could not list remotes: %w", err) @@ -268,6 +277,22 @@ func (d *defaultOps) SaveRerereDeclined() error { return runSilent("config", "gh-stack.rerere-declined", "true") } +func (d *defaultOps) GetSavedRemote() (string, error) { + out, err := run("config", "--get", "gh-stack.remote") + if err != nil { + return "", err + } + return out, nil +} + +func (d *defaultOps) SaveRemote(remote string) error { + return runSilent("config", "gh-stack.remote", remote) +} + +func (d *defaultOps) ClearRemote() error { + return runSilent("config", "--unset", "gh-stack.remote") +} + func (d *defaultOps) RebaseOnto(newBase, oldBase, branch string, opts RebaseOpts) error { args := []string{"rebase"} if opts.CommitterDateIsAuthorDate { diff --git a/internal/git/gitops_test.go b/internal/git/gitops_test.go index d2c171b..24ca680 100644 --- a/internal/git/gitops_test.go +++ b/internal/git/gitops_test.go @@ -414,3 +414,76 @@ func TestSplitCommitMessage(t *testing.T) { }) } } + +// --------------------------------------------------------------------------- +// Integration tests for saved remote (gh-stack.remote) +// --------------------------------------------------------------------------- + +func TestIntegration_ResolveRemote_UsesSavedRemote(t *testing.T) { + _, cloneDir := setupBareAndClone(t) + restoreDir := withGitDir(t, cloneDir) + defer restoreDir() + + // Create a branch without upstream tracking. + gitExec(t, cloneDir, "checkout", "-b", "feature") + + // Add a second remote so there are multiple. + gitExec(t, cloneDir, "remote", "add", "upstream", cloneDir) + + // Without saved remote, multiple remotes should return ErrMultipleRemotes. + _, err := ResolveRemote("feature") + var multi *ErrMultipleRemotes + require.ErrorAs(t, err, &multi) + + // Save a remote preference. + gitExec(t, cloneDir, "config", "gh-stack.remote", "upstream") + + // Now ResolveRemote should return the saved remote. + remote, err := ResolveRemote("feature") + require.NoError(t, err) + assert.Equal(t, "upstream", remote) +} + +func TestIntegration_ResolveRemote_GitPushConfigTakesPrecedence(t *testing.T) { + _, cloneDir := setupBareAndClone(t) + restoreDir := withGitDir(t, cloneDir) + defer restoreDir() + + // Add a second remote. + gitExec(t, cloneDir, "remote", "add", "upstream", cloneDir) + + // Save gh-stack.remote to "upstream". + gitExec(t, cloneDir, "config", "gh-stack.remote", "upstream") + + // Set standard git push config to "origin" — this should take precedence. + gitExec(t, cloneDir, "config", "remote.pushDefault", "origin") + + remote, err := ResolveRemote("main") + require.NoError(t, err) + assert.Equal(t, "origin", remote) +} + +func TestIntegration_SaveAndGetRemote(t *testing.T) { + _, cloneDir := setupBareAndClone(t) + restoreDir := withGitDir(t, cloneDir) + defer restoreDir() + + // Initially no saved remote. + _, err := GetSavedRemote() + require.Error(t, err) + + // Save a remote. + require.NoError(t, SaveRemote("upstream")) + + // Should be retrievable. + saved, err := GetSavedRemote() + require.NoError(t, err) + assert.Equal(t, "upstream", saved) + + // Clear it. + require.NoError(t, ClearRemote()) + + // Should be gone. + _, err = GetSavedRemote() + require.Error(t, err) +} diff --git a/internal/git/mock_ops.go b/internal/git/mock_ops.go index 785e01f..29a6213 100644 --- a/internal/git/mock_ops.go +++ b/internal/git/mock_ops.go @@ -1,5 +1,7 @@ package git +import "fmt" + // MockOps is a test double for git operations. // Each field is an optional function that, when set, handles the corresponding // Ops method call. When nil, a reasonable default is returned. @@ -20,6 +22,9 @@ type MockOps struct { IsRerereEnabledFn func() (bool, error) IsRerereDeclinedFn func() (bool, error) SaveRerereDeclinedFn func() error + GetSavedRemoteFn func() (string, error) + SaveRemoteFn func(string) error + ClearRemoteFn func() error RebaseOntoFn func(string, string, string, RebaseOpts) error RebaseContinueFn func(RebaseOpts) error RebaseAbortFn func() error @@ -167,6 +172,27 @@ func (m *MockOps) SaveRerereDeclined() error { return nil } +func (m *MockOps) GetSavedRemote() (string, error) { + if m.GetSavedRemoteFn != nil { + return m.GetSavedRemoteFn() + } + return "", fmt.Errorf("not set") +} + +func (m *MockOps) SaveRemote(remote string) error { + if m.SaveRemoteFn != nil { + return m.SaveRemoteFn(remote) + } + return nil +} + +func (m *MockOps) ClearRemote() error { + if m.ClearRemoteFn != nil { + return m.ClearRemoteFn() + } + return nil +} + func (m *MockOps) RebaseOnto(newBase, oldBase, branch string, opts RebaseOpts) error { if m.RebaseOntoFn != nil { return m.RebaseOntoFn(newBase, oldBase, branch, opts)