Skip to content

Commit 92db283

Browse files
feat(oauth): add stdio OAuth 2.1 login core library
Introduce internal/oauth, a self-contained library that performs the user-facing GitHub OAuth login the stdio server uses to obtain a token without a pre-provisioned PAT. It is independent of MCP: client concerns (elicitation) sit behind the Prompter interface so the flows are testable without a live session. What it provides: - Authorization-code + PKCE flow with a local loopback callback server, state/CSRF validation, and XSS-safe result pages. - Device-authorization flow as a fallback (headless, containers). - A Manager that selects the most secure available channel (browser auto-open -> URL elicitation -> last-resort user action), runs a single flow at a time, and exposes a refreshing token source. Both GitHub OAuth Apps and GitHub Apps are supported without special casing: the token is modeled as an x/oauth2 refreshing TokenSource, so expiring GitHub App user tokens are renewed transparently (the gap that made a stored-token approach silently die after ~8h). When a client lacks secure URL elicitation and the flow falls back to a tool-response message, the message advises the user that their agent/CLI/ IDE does not appear to support URL elicitation and suggests requesting it for improved security. Tests exercise real protocol behavior against an httptest GitHub stand-in: PKCE challenge/verifier, GitHub App refresh-on-expiry, device polling, URL elicitation, declined prompts, the last-resort action with advisory, and single-flight concurrency. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 308ae5b commit 92db283

13 files changed

Lines changed: 1481 additions & 1 deletion

File tree

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ require (
1919
github.com/spf13/viper v1.21.0
2020
github.com/stretchr/testify v1.11.1
2121
github.com/yosida95/uritemplate/v3 v3.0.2
22+
golang.org/x/oauth2 v0.35.0
2223
)
2324

2425
require (
@@ -40,7 +41,6 @@ require (
4041
github.com/subosito/gotenv v1.6.0 // indirect
4142
go.yaml.in/yaml/v3 v3.0.4 // indirect
4243
golang.org/x/net v0.38.0 // indirect
43-
golang.org/x/oauth2 v0.35.0 // indirect
4444
golang.org/x/sys v0.41.0 // indirect
4545
golang.org/x/text v0.28.0 // indirect
4646
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect

internal/oauth/callback.go

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package oauth
2+
3+
import (
4+
"context"
5+
"embed"
6+
"fmt"
7+
"html/template"
8+
"net"
9+
"net/http"
10+
"time"
11+
)
12+
13+
//go:embed templates/*.html
14+
var templateFS embed.FS
15+
16+
var (
17+
errorTemplate = template.Must(template.ParseFS(templateFS, "templates/error.html"))
18+
successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html"))
19+
)
20+
21+
// callbackResult is delivered by the callback server once the browser redirect
22+
// arrives. Exactly one of code or err is set.
23+
type callbackResult struct {
24+
code string
25+
err error
26+
}
27+
28+
// callbackServer is a short-lived local HTTP server that captures the
29+
// authorization code from the OAuth redirect.
30+
type callbackServer struct {
31+
server *http.Server
32+
listener net.Listener
33+
redirect string
34+
results chan callbackResult
35+
}
36+
37+
// listenCallback binds the local callback listener.
38+
//
39+
// A random port (port == 0) binds to 127.0.0.1 only: the redirect target is
40+
// loopback and never reachable off-host. A fixed port binds to all interfaces
41+
// because Docker's published-port DNAT delivers traffic to the container's eth0
42+
// rather than to loopback; exposure is still constrained by the host-side
43+
// publish (e.g. -p 127.0.0.1:8085:8085).
44+
func listenCallback(port int) (net.Listener, error) {
45+
host := "127.0.0.1"
46+
if port > 0 {
47+
host = "0.0.0.0"
48+
}
49+
addr := fmt.Sprintf("%s:%d", host, port)
50+
listener, err := net.Listen("tcp", addr)
51+
if err != nil {
52+
return nil, fmt.Errorf("starting callback listener on %s: %w", addr, err)
53+
}
54+
return listener, nil
55+
}
56+
57+
// newCallbackServer starts a callback server on listener that validates state
58+
// and reports the result on a buffered channel. The redirect URI always uses
59+
// localhost so it matches the value registered on the OAuth/GitHub App.
60+
func newCallbackServer(listener net.Listener, expectedState string) *callbackServer {
61+
cs := &callbackServer{
62+
server: &http.Server{ReadHeaderTimeout: 10 * time.Second}, // ReadHeaderTimeout guards against Slowloris.
63+
listener: listener,
64+
redirect: fmt.Sprintf("http://localhost:%d/callback", listener.Addr().(*net.TCPAddr).Port),
65+
results: make(chan callbackResult, 1),
66+
}
67+
cs.server.Handler = cs.handler(expectedState)
68+
69+
go func() {
70+
if err := cs.server.Serve(listener); err != nil && err != http.ErrServerClosed {
71+
cs.report(callbackResult{err: fmt.Errorf("callback server: %w", err)})
72+
}
73+
}()
74+
75+
return cs
76+
}
77+
78+
// handler renders the callback endpoint. It reports the outcome exactly once and
79+
// always shows the user a friendly page.
80+
func (cs *callbackServer) handler(expectedState string) http.Handler {
81+
mux := http.NewServeMux()
82+
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
83+
q := r.URL.Query()
84+
85+
if errCode := q.Get("error"); errCode != "" {
86+
msg := errCode
87+
if desc := q.Get("error_description"); desc != "" {
88+
msg = fmt.Sprintf("%s: %s", errCode, desc)
89+
}
90+
cs.report(callbackResult{err: fmt.Errorf("authorization failed: %s", msg)})
91+
renderError(w, msg)
92+
return
93+
}
94+
95+
if q.Get("state") != expectedState {
96+
cs.report(callbackResult{err: fmt.Errorf("state mismatch (possible CSRF)")})
97+
renderError(w, "state mismatch")
98+
return
99+
}
100+
101+
code := q.Get("code")
102+
if code == "" {
103+
cs.report(callbackResult{err: fmt.Errorf("no authorization code in callback")})
104+
renderError(w, "no authorization code received")
105+
return
106+
}
107+
108+
cs.report(callbackResult{code: code})
109+
renderSuccess(w)
110+
})
111+
return mux
112+
}
113+
114+
// report delivers the first outcome and drops later ones (the channel is
115+
// buffered for one; subsequent redirect retries must not block the handler).
116+
func (cs *callbackServer) report(res callbackResult) {
117+
select {
118+
case cs.results <- res:
119+
default:
120+
}
121+
}
122+
123+
// wait blocks for the callback outcome or ctx cancellation, then shuts the
124+
// server down. It is safe to call once per server.
125+
func (cs *callbackServer) wait(ctx context.Context) (string, error) {
126+
defer cs.close()
127+
select {
128+
case res := <-cs.results:
129+
return res.code, res.err
130+
case <-ctx.Done():
131+
return "", ctx.Err()
132+
}
133+
}
134+
135+
func (cs *callbackServer) close() {
136+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
137+
defer cancel()
138+
_ = cs.server.Shutdown(shutdownCtx)
139+
_ = cs.listener.Close()
140+
}
141+
142+
func renderSuccess(w http.ResponseWriter) {
143+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
144+
if err := successTemplate.Execute(w, nil); err != nil {
145+
http.Error(w, "internal error", http.StatusInternalServerError)
146+
}
147+
}
148+
149+
// renderError shows the failure page. html/template auto-escapes msg, so a
150+
// hostile error_description cannot inject markup.
151+
func renderError(w http.ResponseWriter, msg string) {
152+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
153+
if err := errorTemplate.Execute(w, struct{ ErrorMessage string }{ErrorMessage: msg}); err != nil {
154+
http.Error(w, "internal error", http.StatusInternalServerError)
155+
}
156+
}

internal/oauth/callback_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package oauth
2+
3+
import (
4+
"net"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
// serveCallback drives the callback handler with the given query string and
14+
// returns the recorded response and the single reported result.
15+
func serveCallback(t *testing.T, expectedState, query string) (*httptest.ResponseRecorder, callbackResult) {
16+
t.Helper()
17+
cs := &callbackServer{results: make(chan callbackResult, 1)}
18+
rec := httptest.NewRecorder()
19+
req := httptest.NewRequest(http.MethodGet, "/callback?"+query, nil)
20+
21+
cs.handler(expectedState).ServeHTTP(rec, req)
22+
23+
select {
24+
case res := <-cs.results:
25+
return rec, res
26+
default:
27+
t.Fatal("handler did not report a result")
28+
return nil, callbackResult{}
29+
}
30+
}
31+
32+
func TestCallbackHandlerSuccess(t *testing.T) {
33+
rec, res := serveCallback(t, "state123", "code=the-code&state=state123")
34+
35+
require.NoError(t, res.err)
36+
assert.Equal(t, "the-code", res.code)
37+
assert.Equal(t, http.StatusOK, rec.Code)
38+
assert.Contains(t, rec.Body.String(), "Authorization Successful")
39+
}
40+
41+
func TestCallbackHandlerStateMismatch(t *testing.T) {
42+
rec, res := serveCallback(t, "expected", "code=the-code&state=attacker")
43+
44+
require.Error(t, res.err)
45+
assert.Empty(t, res.code)
46+
assert.Contains(t, res.err.Error(), "state mismatch")
47+
assert.Contains(t, rec.Body.String(), "state mismatch")
48+
}
49+
50+
func TestCallbackHandlerMissingCode(t *testing.T) {
51+
_, res := serveCallback(t, "state123", "state=state123")
52+
53+
require.Error(t, res.err)
54+
assert.Contains(t, res.err.Error(), "no authorization code")
55+
}
56+
57+
func TestCallbackHandlerOAuthError(t *testing.T) {
58+
_, res := serveCallback(t, "state123", "error=access_denied&error_description=user+said+no")
59+
60+
require.Error(t, res.err)
61+
assert.Contains(t, res.err.Error(), "access_denied")
62+
assert.Contains(t, res.err.Error(), "user said no")
63+
}
64+
65+
func TestCallbackHandlerEscapesError(t *testing.T) {
66+
rec, _ := serveCallback(t, "state123", "error=evil&error_description=%3Cscript%3Ealert(1)%3C%2Fscript%3E")
67+
68+
body := rec.Body.String()
69+
assert.NotContains(t, body, "<script>", "error message must be HTML-escaped")
70+
assert.Contains(t, body, "&lt;script&gt;")
71+
}
72+
73+
func TestListenCallbackRandomPortIsLoopback(t *testing.T) {
74+
listener, err := listenCallback(0)
75+
require.NoError(t, err)
76+
defer listener.Close()
77+
78+
addr, ok := listener.Addr().(*net.TCPAddr)
79+
require.True(t, ok)
80+
assert.True(t, addr.IP.IsLoopback(), "random port must bind loopback only, got %s", addr.IP)
81+
assert.NotZero(t, addr.Port)
82+
}

internal/oauth/env.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package oauth
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"os"
7+
"os/exec"
8+
"runtime"
9+
"strings"
10+
)
11+
12+
// openBrowser tries to open url in the user's default browser. It returns an
13+
// error when no browser can plausibly be launched so the caller can fall back
14+
// to elicitation. On Linux it treats a headless session (no display server) as
15+
// unopenable, which is the common case for SSH and containers.
16+
func openBrowser(url string) error {
17+
var cmd *exec.Cmd
18+
switch runtime.GOOS {
19+
case "linux":
20+
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
21+
return fmt.Errorf("no display server detected")
22+
}
23+
cmd = exec.Command("xdg-open", url)
24+
case "darwin":
25+
cmd = exec.Command("open", url)
26+
case "windows":
27+
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
28+
default:
29+
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
30+
}
31+
32+
cmd.Stdout = io.Discard
33+
cmd.Stderr = io.Discard
34+
return cmd.Start()
35+
}
36+
37+
// isRunningInDocker reports whether the process is running inside a Docker (or
38+
// containerd) container. Detection relies on Linux-specific paths and is always
39+
// false elsewhere. It is used only to skip a PKCE flow that cannot work: a
40+
// random callback port inside a container cannot be reached from the host
41+
// browser, so we go straight to device flow in that case.
42+
func isRunningInDocker() bool {
43+
if runtime.GOOS != "linux" {
44+
return false
45+
}
46+
if _, err := os.Stat("/.dockerenv"); err == nil {
47+
return true
48+
}
49+
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
50+
s := string(data)
51+
if strings.Contains(s, "docker") || strings.Contains(s, "containerd") {
52+
return true
53+
}
54+
}
55+
return false
56+
}

0 commit comments

Comments
 (0)