From 004a382c191d09facb845b6d0e204727ec9aaa55 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:23:46 +0200 Subject: [PATCH 1/7] removed unused func --- internal/oauthdevice/device_flow_test.go | 509 +++++++++++++++++++++++ 1 file changed, 509 insertions(+) create mode 100644 internal/oauthdevice/device_flow_test.go diff --git a/internal/oauthdevice/device_flow_test.go b/internal/oauthdevice/device_flow_test.go new file mode 100644 index 0000000000..02b3923d88 --- /dev/null +++ b/internal/oauthdevice/device_flow_test.go @@ -0,0 +1,509 @@ +package oauthdevice + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +const ( + testDeviceAuthPath = "/device/code" + testTokenPath = "/token" +) + +type testServerOptions struct { + handlers map[string]http.HandlerFunc + wellKnownFunc func(w http.ResponseWriter, r *http.Request) +} + +func newTestServer(t *testing.T, opts testServerOptions) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case wellKnownPath: + if opts.wellKnownFunc != nil { + opts.wellKnownFunc(w, r) + } else { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(OIDCConfiguration{ + Issuer: "http://" + r.Host, + DeviceAuthorizationEndpoint: "http://" + r.Host + testDeviceAuthPath, + TokenEndpoint: "http://" + r.Host + testTokenPath, + }) + } + default: + if handler, ok := opts.handlers[r.URL.Path]; ok { + handler(w, r) + } else { + t.Errorf("unexpected path: %s", r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + } + } + })) +} + +func TestDiscover_Success(t *testing.T) { + server := newTestServer(t, testServerOptions{}) + defer server.Close() + + client := NewClient() + config, err := client.Discover(context.Background(), server.URL) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + + if config.DeviceAuthorizationEndpoint != server.URL+testDeviceAuthPath { + t.Errorf("DeviceAuthorizationEndpoint = %q, want %q", config.DeviceAuthorizationEndpoint, server.URL+testDeviceAuthPath) + } + if config.TokenEndpoint != server.URL+testTokenPath { + t.Errorf("TokenEndpoint = %q, want %q", config.TokenEndpoint, server.URL+testTokenPath) + } +} + +func TestDiscover_Caching(t *testing.T) { + var callCount int32 + server := newTestServer(t, testServerOptions{ + wellKnownFunc: func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(OIDCConfiguration{ + DeviceAuthorizationEndpoint: "http://example.com/device", + TokenEndpoint: "http://example.com/token", + }) + }, + }) + defer server.Close() + + client := NewClient() + + // Populate the cache + _, err := client.Discover(context.Background(), server.URL) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + + // Second call should use cache + _, err = client.Discover(context.Background(), server.URL) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + + if atomic.LoadInt32(&callCount) != 1 { + t.Errorf("callCount = %d, want 1 (second call should use cache)", callCount) + } +} + +func TestDiscover_Error(t *testing.T) { + server := newTestServer(t, testServerOptions{ + wellKnownFunc: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Discover(context.Background(), server.URL) + if err == nil { + t.Fatal("Discover() expected error, got nil") + } + + if !strings.Contains(err.Error(), "404") { + t.Errorf("error = %q, want to contain '404'", err.Error()) + } +} + +func TestStart_Success(t *testing.T) { + wantResponse := DeviceAuthResponse{ + DeviceCode: "test-device-code", + UserCode: "ABCD-1234", + VerificationURI: "https://example.com/device", + VerificationURIComplete: "https://example.com/device?user_code=ABCD-1234", + ExpiresIn: 1800, + Interval: 5, + } + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testDeviceAuthPath: func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("unexpected method: %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + if got := r.FormValue("client_id"); got != ClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(wantResponse) + }, + }, + }) + defer server.Close() + + client := NewClient() + resp, err := client.Start(context.Background(), server.URL, nil) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + if resp.DeviceCode != wantResponse.DeviceCode { + t.Errorf("DeviceCode = %q, want %q", resp.DeviceCode, wantResponse.DeviceCode) + } + if resp.UserCode != wantResponse.UserCode { + t.Errorf("UserCode = %q, want %q", resp.UserCode, wantResponse.UserCode) + } + if resp.VerificationURI != wantResponse.VerificationURI { + t.Errorf("VerificationURI = %q, want %q", resp.VerificationURI, wantResponse.VerificationURI) + } + if resp.VerificationURIComplete != wantResponse.VerificationURIComplete { + t.Errorf("VerificationURIComplete = %q, want %q", resp.VerificationURIComplete, wantResponse.VerificationURIComplete) + } + if resp.ExpiresIn != wantResponse.ExpiresIn { + t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, wantResponse.ExpiresIn) + } + if resp.Interval != wantResponse.Interval { + t.Errorf("Interval = %d, want %d", resp.Interval, wantResponse.Interval) + } +} + +func TestStart_WithScopes(t *testing.T) { + var receivedScope string + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testDeviceAuthPath: func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + receivedScope = r.FormValue("scope") + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(DeviceAuthResponse{ + DeviceCode: "test-device-code", + UserCode: "ABCD-1234", + VerificationURI: "https://example.com/device", + ExpiresIn: 1800, + Interval: 5, + }) + }, + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Start(context.Background(), server.URL, []string{"read", "write"}) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + if receivedScope != "read write" { + t.Errorf("scope = %q, want %q", receivedScope, "read write") + } +} + +func TestStart_Error(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testDeviceAuthPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "invalid_client", + ErrorDescription: "Unknown client", + }) + }, + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Start(context.Background(), server.URL, nil) + if err == nil { + t.Fatal("Start() expected error, got nil") + } + + wantErr := "device auth failed: invalid_client: Unknown client" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestStart_NoDeviceEndpoint(t *testing.T) { + server := newTestServer(t, testServerOptions{ + wellKnownFunc: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(OIDCConfiguration{ + TokenEndpoint: "http://example.com/token", + }) + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Start(context.Background(), server.URL, nil) + if err == nil { + t.Fatal("Start() expected error, got nil") + } + + if !strings.Contains(err.Error(), "device authorization endpoint not found") { + t.Errorf("error = %q, want to contain 'device authorization endpoint not found'", err.Error()) + } +} + +func TestPoll_Success(t *testing.T) { + wantToken := TokenResponse{ + AccessToken: "test-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "read write", + } + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("unexpected method: %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + if got := r.FormValue("client_id"); got != ClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + } + if got := r.FormValue("grant_type"); got != GrantTypeDeviceCode { + t.Errorf("unexpected grant_type: got %q", got) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(wantToken) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err != nil { + t.Fatalf("Poll() error = %v", err) + } + + if resp.AccessToken != wantToken.AccessToken { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, wantToken.AccessToken) + } + if resp.TokenType != wantToken.TokenType { + t.Errorf("TokenType = %q, want %q", resp.TokenType, wantToken.TokenType) + } +} + +func TestPoll_AuthorizationPending(t *testing.T) { + var callCount int32 + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + + w.Header().Set("Content-Type", "application/json") + + if count < 3 { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "authorization_pending", + ErrorDescription: "The user has not yet completed authorization", + }) + return + } + + json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err != nil { + t.Fatalf("Poll() error = %v", err) + } + + if resp.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test-access-token") + } + + if atomic.LoadInt32(&callCount) != 3 { + t.Errorf("callCount = %d, want 3", callCount) + } +} + +func TestPoll_SlowDown(t *testing.T) { + var callCount int32 + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + + w.Header().Set("Content-Type", "application/json") + + if count == 1 { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "slow_down", + }) + return + } + + json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err != nil { + t.Fatalf("Poll() error = %v", err) + } + + if resp.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test-access-token") + } + + if atomic.LoadInt32(&callCount) != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } +} + +func TestPoll_ExpiredToken(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "expired_token", + ErrorDescription: "The device code has expired", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + wantErr := "device code expired" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestPoll_AccessDenied(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "access_denied", + ErrorDescription: "The user denied the request", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + wantErr := "authorization was denied by the user" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestPoll_Timeout(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "authorization_pending", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 0) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + wantErr := "device code expired" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestPoll_ContextCancellation(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "authorization_pending", + }) + }, + }, + }) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + client := NewClient().(*httpClient) + _, err := client.Poll(ctx, server.URL, "test-device-code", 10*time.Millisecond, 3600) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + if err != context.Canceled && !strings.Contains(err.Error(), "context canceled") { + t.Errorf("error = %v, want context.Canceled or wrapped context canceled error", err) + } +} From 9f9d2404ad613571fb755dccca8bcbe16548aebb Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:25:07 +0200 Subject: [PATCH 2/7] update usage --- cmd/src/login.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmd/src/login.go b/cmd/src/login.go index ab5a097c71..bc5af1e3a0 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -28,6 +28,10 @@ Examples: Authenticate to Sourcegraph.com: $ src login https://sourcegraph.com + + Use OAuth device flow to authenticate: + + $ src login --device-flow https://sourcegraph.com ` flagSet := flag.NewFlagSet("login", flag.ExitOnError) From bb43d20f10ac1512e69a9ea0070d69e68b1ce3ff Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:31:04 +0200 Subject: [PATCH 3/7] spelling --- internal/oauthdevice/device_flow.go | 304 ++++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 internal/oauthdevice/device_flow.go diff --git a/internal/oauthdevice/device_flow.go b/internal/oauthdevice/device_flow.go new file mode 100644 index 0000000000..851779cc3a --- /dev/null +++ b/internal/oauthdevice/device_flow.go @@ -0,0 +1,304 @@ +// Package oauthdevice implements the OAuth 2.0 Device Authorization Grant (RFC 8628) +// for authenticating with Sourcegraph instances. +package oauthdevice + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +const ( + ClientID = "sgo_cid_sourcegraph-cli" + + wellKnownPath = "/.well-known/openid-configuration" + + GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code" + + ScopeOpenID string = "openid" + ScopeProfile string = "profile" + ScopeEmail string = "email" + ScopeOfflineAccess string = "offline_access" + ScopeUserAll string = "user:all" +) + +var defaultScopes = []string{ScopeEmail, ScopeOfflineAccess, ScopeOpenID, ScopeProfile, ScopeUserAll} + +// OIDCConfiguration represents the relevant fields from the OpenID Connect +// Discovery document at /.well-known/openid-configuration +type OIDCConfiguration struct { + Issuer string `json:"issuer,omitempty"` + TokenEndpoint string `json:"token_endpoint,omitempty"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"` +} + +type DeviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type ErrorResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +type Client interface { + Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) + Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) + Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) +} + +type httpClient struct { + client *http.Client + // cached OIDC configuration per endpoint + configCache map[string]*OIDCConfiguration +} + +func NewClient() Client { + return &httpClient{ + client: &http.Client{ + Timeout: 30 * time.Second, + }, + configCache: make(map[string]*OIDCConfiguration), + } +} + +func NewClientWithHTTPClient(c *http.Client) Client { + return &httpClient{ + client: c, + configCache: make(map[string]*OIDCConfiguration), + } +} + +// Discover fetches the openid-configuration which contains all the routes a client should +// use for authorization, device flows, tokens etc. +// +// Before making any requests, the configCache is checked and if there is a cache hit, the +// cached config is returned. +func (c *httpClient) Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) { + endpoint = strings.TrimRight(endpoint, "/") + + if config, ok := c.configCache[endpoint]; ok { + return config, nil + } + + reqURL := endpoint + wellKnownPath + + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + return nil, errors.Wrap(err, "creating discovery request") + } + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "discovery request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading discovery response") + } + + if resp.StatusCode != http.StatusOK { + return nil, errors.Newf("discovery failed with status %d: %s", resp.StatusCode, string(body)) + } + + var config OIDCConfiguration + if err := json.Unmarshal(body, &config); err != nil { + return nil, errors.Wrap(err, "parsing discovery response") + } + + c.configCache[endpoint] = &config + + return &config, nil +} + +// Start starts the OAuth device flow with the given endpoint. If no scopes are given the default scopes are used. +// +// Default Scopes: "openid" "profile" "email" "offline_access" "user:all" +func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) { + endpoint = strings.TrimRight(endpoint, "/") + + // Discover OIDC configuration + config, err := c.Discover(ctx, endpoint) + if err != nil { + return nil, errors.Wrap(err, "OIDC discovery failed") + } + + if config.DeviceAuthorizationEndpoint == "" { + return nil, errors.New("device authorization endpoint not found in OIDC configuration; the server may not support device flow") + } + + data := url.Values{} + data.Set("client_id", ClientID) + if len(scopes) > 0 { + data.Set("scope", strings.Join(scopes, " ")) + } else { + data.Set("scope", strings.Join(defaultScopes, " ")) + } + + req, err := http.NewRequestWithContext(ctx, "POST", config.DeviceAuthorizationEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, errors.Wrap(err, "creating device auth request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "device auth request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading device auth response") + } + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return nil, errors.Newf("device auth failed: %s: %s", errResp.Error, errResp.ErrorDescription) + } + return nil, errors.Newf("device auth failed with status %d: %s", resp.StatusCode, string(body)) + } + + var authResp DeviceAuthResponse + if err := json.Unmarshal(body, &authResp); err != nil { + return nil, errors.Wrap(err, "parsing device auth response") + } + + return &authResp, nil +} + +// Poll polls the OAuth token endpoint until the device has been authorized or not +// +// We poll as long as the authorization is pending. If the server tells us to slow down, we will wait 5 secs extra. +// +// Polling will stop when: +// - Device is authorized, and a token is returned +// - Device code has expried +// - User denied authorization +func (c *httpClient) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) { + endpoint = strings.TrimRight(endpoint, "/") + + // Discover OIDC configuration (should be cached from Start) + config, err := c.Discover(ctx, endpoint) + if err != nil { + return nil, errors.Wrap(err, "OIDC discovery failed") + } + + if config.TokenEndpoint == "" { + return nil, errors.New("token endpoint not found in OIDC configuration") + } + + deadline := time.Now().Add(time.Duration(expiresIn) * time.Second) + + for { + if time.Now().After(deadline) { + return nil, errors.New("device code expired") + } + + if !testing.Testing() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + } + + tokenResp, err := c.pollOnce(ctx, config.TokenEndpoint, deviceCode) + if err != nil { + var pollErr *PollError + if errors.As(err, &pollErr) { + switch pollErr.Code { + case "authorization_pending": + continue + case "slow_down": + interval += 5 * time.Second + continue + case "expired_token": + return nil, errors.New("device code expired") + case "access_denied": + return nil, errors.New("authorization was denied by the user") + } + } + return nil, err + } + + return tokenResp, nil + } +} + +type PollError struct { + Code string + Description string +} + +func (e *PollError) Error() string { + if e.Description != "" { + return fmt.Sprintf("%s: %s", e.Code, e.Description) + } + return e.Code +} + +func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) { + data := url.Values{} + data.Set("client_id", ClientID) + data.Set("device_code", deviceCode) + data.Set("grant_type", GrantTypeDeviceCode) + + req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, errors.Wrap(err, "creating token request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "token request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading token response") + } + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return nil, &PollError{Code: errResp.Error, Description: errResp.ErrorDescription} + } + return nil, errors.Newf("token request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "parsing token response") + } + + return &tokenResp, nil +} From 85dc6a42cbfe1457c0ca96a09561dd7a4d69ffd2 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:47:34 +0200 Subject: [PATCH 4/7] remove emoji --- cmd/src/login.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/cmd/src/login.go b/cmd/src/login.go index bc5af1e3a0..52d881cc58 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -126,6 +126,43 @@ func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg s } fmt.Fprintln(out) fmt.Fprintf(out, "✔️ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointArg) + + if p.useDeviceFlow { + fmt.Fprintln(out) + fmt.Fprintf(out, "To use this access token, set the following environment variables in your terminal:\n\n") + fmt.Fprintf(out, " export SRC_ENDPOINT=%s\n", endpointArg) + fmt.Fprintf(out, " export SRC_ACCESS_TOKEN=%s\n", cfg.AccessToken) + } + fmt.Fprintln(out) return nil } + +func runDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauthdevice.Client) (string, error) { + authResp, err := client.Start(ctx, endpoint, nil) + if err != nil { + return "", err + } + + fmt.Fprintln(out) + fmt.Fprintf(out, "To authenticate, visit %s and enter the code: %s\n", authResp.VerificationURI, authResp.UserCode) + if authResp.VerificationURIComplete != "" { + fmt.Fprintln(out) + fmt.Fprintf(out, "Alternatively, you can open: %s\n", authResp.VerificationURIComplete) + } + fmt.Fprintln(out) + fmt.Fprint(out, "Waiting for authorization...") + defer fmt.Fprintf(out, "DONE\n\n") + + interval := time.Duration(authResp.Interval) * time.Second + if interval <= 0 { + interval = 5 * time.Second + } + + tokenResp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) + if err != nil { + return "", err + } + + return tokenResp.AccessToken, nil +} From 3c7d64333aef513ce274256aad292c5751ab61fc Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:58:10 +0200 Subject: [PATCH 5/7] add refresh token to device response unmarshall --- internal/oauthdevice/device_flow.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/internal/oauthdevice/device_flow.go b/internal/oauthdevice/device_flow.go index 851779cc3a..5e3de8971c 100644 --- a/internal/oauthdevice/device_flow.go +++ b/internal/oauthdevice/device_flow.go @@ -50,10 +50,11 @@ type DeviceAuthResponse struct { } type TokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in,omitempty"` - Scope string `json:"scope,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` } type ErrorResponse struct { From 75bf1b4efc29e45b875033be6410d66c943b39f8 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:10:24 +0200 Subject: [PATCH 6/7] make NewClient take ClientID as param --- internal/oauthdevice/device_flow.go | 14 +++++---- internal/oauthdevice/device_flow_test.go | 36 ++++++++++++------------ 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/internal/oauthdevice/device_flow.go b/internal/oauthdevice/device_flow.go index 5e3de8971c..c278dd4ba3 100644 --- a/internal/oauthdevice/device_flow.go +++ b/internal/oauthdevice/device_flow.go @@ -17,8 +17,10 @@ import ( ) const ( - ClientID = "sgo_cid_sourcegraph-cli" + // DefaultClientID is a predefined Client ID built into Sourcegraph + DefaultClientID = "sgo_cid_sourcegraph-cli" + // wellKnownPath is the path on the sourcegraph server where clients can discover OAuth configuration wellKnownPath = "/.well-known/openid-configuration" GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code" @@ -69,13 +71,15 @@ type Client interface { } type httpClient struct { - client *http.Client + clientID string + client *http.Client // cached OIDC configuration per endpoint configCache map[string]*OIDCConfiguration } -func NewClient() Client { +func NewClient(clientID string) Client { return &httpClient{ + clientID: clientID, client: &http.Client{ Timeout: 30 * time.Second, }, @@ -152,7 +156,7 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string } data := url.Values{} - data.Set("client_id", ClientID) + data.Set("client_id", DefaultClientID) if len(scopes) > 0 { data.Set("scope", strings.Join(scopes, " ")) } else { @@ -266,7 +270,7 @@ func (e *PollError) Error() string { func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) { data := url.Values{} - data.Set("client_id", ClientID) + data.Set("client_id", DefaultClientID) data.Set("device_code", deviceCode) data.Set("grant_type", GrantTypeDeviceCode) diff --git a/internal/oauthdevice/device_flow_test.go b/internal/oauthdevice/device_flow_test.go index 02b3923d88..e60e1f9b1a 100644 --- a/internal/oauthdevice/device_flow_test.go +++ b/internal/oauthdevice/device_flow_test.go @@ -50,7 +50,7 @@ func TestDiscover_Success(t *testing.T) { server := newTestServer(t, testServerOptions{}) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) config, err := client.Discover(context.Background(), server.URL) if err != nil { t.Fatalf("Discover() error = %v", err) @@ -78,7 +78,7 @@ func TestDiscover_Caching(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) // Populate the cache _, err := client.Discover(context.Background(), server.URL) @@ -105,7 +105,7 @@ func TestDiscover_Error(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Discover(context.Background(), server.URL) if err == nil { t.Fatal("Discover() expected error, got nil") @@ -141,8 +141,8 @@ func TestStart_Success(t *testing.T) { return } - if got := r.FormValue("client_id"); got != ClientID { - t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + if got := r.FormValue("client_id"); got != DefaultClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, DefaultClientID) } w.Header().Set("Content-Type", "application/json") @@ -152,7 +152,7 @@ func TestStart_Success(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) resp, err := client.Start(context.Background(), server.URL, nil) if err != nil { t.Fatalf("Start() error = %v", err) @@ -204,7 +204,7 @@ func TestStart_WithScopes(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Start(context.Background(), server.URL, []string{"read", "write"}) if err != nil { t.Fatalf("Start() error = %v", err) @@ -230,7 +230,7 @@ func TestStart_Error(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Start(context.Background(), server.URL, nil) if err == nil { t.Fatal("Start() expected error, got nil") @@ -253,7 +253,7 @@ func TestStart_NoDeviceEndpoint(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Start(context.Background(), server.URL, nil) if err == nil { t.Fatal("Start() expected error, got nil") @@ -287,8 +287,8 @@ func TestPoll_Success(t *testing.T) { return } - if got := r.FormValue("client_id"); got != ClientID { - t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + if got := r.FormValue("client_id"); got != DefaultClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, DefaultClientID) } if got := r.FormValue("grant_type"); got != GrantTypeDeviceCode { t.Errorf("unexpected grant_type: got %q", got) @@ -301,7 +301,7 @@ func TestPoll_Success(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) @@ -343,7 +343,7 @@ func TestPoll_AuthorizationPending(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) @@ -385,7 +385,7 @@ func TestPoll_SlowDown(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) @@ -415,7 +415,7 @@ func TestPoll_ExpiredToken(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err == nil { t.Fatal("Poll() expected error, got nil") @@ -442,7 +442,7 @@ func TestPoll_AccessDenied(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err == nil { t.Fatal("Poll() expected error, got nil") @@ -468,7 +468,7 @@ func TestPoll_Timeout(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 0) if err == nil { t.Fatal("Poll() expected error, got nil") @@ -497,7 +497,7 @@ func TestPoll_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(ctx, server.URL, "test-device-code", 10*time.Millisecond, 3600) if err == nil { t.Fatal("Poll() expected error, got nil") From fd1668e02942b29d72e2337998189fb236b823e2 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:10:36 +0200 Subject: [PATCH 7/7] add flag to set client-id for device-flow --- cmd/src/login.go | 58 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/cmd/src/login.go b/cmd/src/login.go index 52d881cc58..e42632c3b0 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -7,9 +7,11 @@ import ( "io" "os" "strings" + "time" "github.com/sourcegraph/src-cli/internal/api" "github.com/sourcegraph/src-cli/internal/cmderrors" + "github.com/sourcegraph/src-cli/internal/oauthdevice" ) func init() { @@ -17,7 +19,7 @@ func init() { Usage: - src login SOURCEGRAPH_URL + src login [flags] SOURCEGRAPH_URL Examples: @@ -32,6 +34,11 @@ Examples: Use OAuth device flow to authenticate: $ src login --device-flow https://sourcegraph.com + + + Override the default client id used during device flow when authenticating: + + $ src login --device-flow https://sourcegraph.com --client-id sgo_my_own_client_id ` flagSet := flag.NewFlagSet("login", flag.ExitOnError) @@ -41,7 +48,9 @@ Examples: } var ( - apiFlags = api.NewFlags(flagSet) + apiFlags = api.NewFlags(flagSet) + useDeviceFlow = flagSet.Bool("device-flow", false, "Use OAuth device flow to obtain an access token interactively") + OAuthClientID = flagSet.String("client-id", oauthdevice.DefaultClientID, "Client ID to use with OAuth device flow. Will use the predefined src cli client ID if not specified.") ) handler := func(args []string) error { @@ -56,9 +65,21 @@ Examples: return cmderrors.Usage("expected exactly one argument: the Sourcegraph URL, or SRC_ENDPOINT to be set") } + if *OAuthClientID == "" { + return cmderrors.Usage("no value specified for client-id") + } + client := cfg.apiClient(apiFlags, io.Discard) - return loginCmd(context.Background(), cfg, client, endpoint, os.Stdout) + return loginCmd(context.Background(), loginParams{ + cfg: cfg, + client: client, + endpoint: endpoint, + out: os.Stdout, + useDeviceFlow: *useDeviceFlow, + apiFlags: apiFlags, + deviceFlowClient: oauthdevice.NewClient(*OAuthClientID), + }) } commands = append(commands, &command{ @@ -68,8 +89,21 @@ Examples: }) } -func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg string, out io.Writer) error { - endpointArg = cleanEndpoint(endpointArg) +type loginParams struct { + cfg *config + client api.Client + endpoint string + out io.Writer + useDeviceFlow bool + apiFlags *api.Flags + deviceFlowClient oauthdevice.Client +} + +func loginCmd(ctx context.Context, p loginParams) error { + endpointArg := cleanEndpoint(p.endpoint) + cfg := p.cfg + client := p.client + out := p.out printProblem := func(problem string) { fmt.Fprintf(out, "❌ Problem: %s\n", problem) @@ -90,7 +124,19 @@ func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg s noToken := cfg.AccessToken == "" endpointConflict := endpointArg != cfg.Endpoint - if noToken || endpointConflict { + + if p.useDeviceFlow { + token, err := runDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) + if err != nil { + printProblem(fmt.Sprintf("Device flow authentication failed: %s", err)) + fmt.Fprintln(out, createAccessTokenMessage) + return cmderrors.ExitCode1 + } + + cfg.AccessToken = token + cfg.Endpoint = endpointArg + client = cfg.apiClient(p.apiFlags, out) + } else if noToken || endpointConflict { fmt.Fprintln(out) switch { case noToken: