From aeda0e124bd9472d0d0e72dd92f8c28085310047 Mon Sep 17 00:00:00 2001 From: samzong Date: Sat, 27 Jun 2026 19:16:43 -0400 Subject: [PATCH] feat(auth): add oauth device login and token refresh Signed-off-by: samzong --- internal/auth/auth.go | 147 ++++++++++++++++++++++++++++++- internal/auth/auth_test.go | 96 ++++++++++++++++++++ internal/codegen/render/skill.go | 8 ++ pkg/config/hosts.go | 20 +++-- pkg/config/hosts_test.go | 38 ++++++++ pkg/config/manifest.go | 27 ++++++ pkg/config/manifest_test.go | 61 +++++++++++++ pkg/runtime/client.go | 62 +++++++++---- pkg/runtime/client_test.go | 31 +++++++ pkg/runtime/ctx.go | 27 ++++-- pkg/runtime/ctx_test.go | 82 +++++++++++++++++ pkg/runtime/oauth.go | 107 ++++++++++++++++++++++ 12 files changed, 671 insertions(+), 35 deletions(-) create mode 100644 internal/auth/auth_test.go create mode 100644 pkg/config/hosts_test.go create mode 100644 pkg/runtime/oauth.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 7e6e18a..7dd54d9 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -2,11 +2,13 @@ package auth import ( "bufio" + "encoding/json" "errors" "fmt" "io" "os" "strings" + "time" "github.com/spf13/cobra" "golang.org/x/term" @@ -30,6 +32,23 @@ func NewHiddenLoginCommand(m *config.Manifest) *cobra.Command { return cmd } +type oauthDeviceStartResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + Interval int64 `json:"interval"` + ExpiresIn int64 `json:"expires_in"` +} + +type oauthDeviceTokenResponse struct { + Status string `json:"status"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + User map[string]string `json:"user"` +} + func rootString(cmd *cobra.Command, name string) string { v, _ := cmd.Root().PersistentFlags().GetString(name) return v @@ -40,9 +59,112 @@ func rootBool(cmd *cobra.Command, name string) bool { return v } +func oauthDeviceLogin(cmd *cobra.Command, m *config.Manifest, hostname string, provider string, insecure bool) (config.HostEntry, error) { + login := m.Auth.Login + if login == nil || login.Type != config.AuthLoginOAuthDevice { + return config.HostEntry{}, errors.New("auth.login with type oauth_device is required for --auth-type oauth") + } + body := map[string]string{"hostname": hostname} + provider = strings.TrimSpace(provider) + if provider != "" { + body["provider"] = provider + } + data, err := runtime.DoRaw(cmd.Context(), hostname, "POST", login.StartPath, body, runtime.ClientOptions{Insecure: insecure, Timeout: 10 * time.Second}) + if err != nil { + return config.HostEntry{}, fmt.Errorf("start oauth login: %w", err) + } + var start oauthDeviceStartResponse + if err := json.Unmarshal(data, &start); err != nil { + return config.HostEntry{}, fmt.Errorf("decode oauth start response: %w", err) + } + if start.DeviceCode == "" { + return config.HostEntry{}, errors.New("oauth start response missing device_code") + } + verificationURL := start.VerificationURIComplete + if verificationURL == "" { + verificationURL = start.VerificationURI + } + if verificationURL == "" { + return config.HostEntry{}, errors.New("oauth start response missing verification_uri") + } + fmt.Fprintf(os.Stderr, "Open this URL to authenticate: %s\n", verificationURL) + if start.UserCode != "" { + fmt.Fprintf(os.Stderr, "Code: %s\n", start.UserCode) + } + token, err := pollOAuthDeviceToken(cmd, hostname, login.TokenPath, start, insecure) + if err != nil { + return config.HostEntry{}, err + } + entry := config.HostEntry{ + AuthType: "bearer", + LoginType: config.AuthLoginOAuthDevice, + LoginProvider: provider, + OAuthToken: token.AccessToken, + OAuthRefreshToken: token.RefreshToken, + Insecure: insecure, + } + if token.ExpiresIn > 0 { + entry.OAuthExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() + } + if token.User != nil { + entry.User = token.User["email"] + if entry.User == "" { + entry.User = token.User["name"] + } + } + return entry, nil +} + +func pollOAuthDeviceToken(cmd *cobra.Command, hostname string, tokenPath string, start oauthDeviceStartResponse, insecure bool) (oauthDeviceTokenResponse, error) { + expiresIn := start.ExpiresIn + if expiresIn <= 0 { + expiresIn = 600 + } + interval := start.Interval + if interval <= 0 { + interval = 5 + } + deadline := time.Now().Add(time.Duration(expiresIn) * time.Second) + for { + if time.Now().After(deadline) { + return oauthDeviceTokenResponse{}, errors.New("oauth login expired") + } + data, err := runtime.DoRaw(cmd.Context(), hostname, "POST", tokenPath, map[string]string{ + "device_code": start.DeviceCode, + }, runtime.ClientOptions{Insecure: insecure, Timeout: 10 * time.Second}) + if err != nil { + return oauthDeviceTokenResponse{}, fmt.Errorf("poll oauth login: %w", err) + } + var token oauthDeviceTokenResponse + if err := json.Unmarshal(data, &token); err != nil { + return oauthDeviceTokenResponse{}, fmt.Errorf("decode oauth token response: %w", err) + } + if token.AccessToken != "" { + return token, nil + } + switch token.Status { + case "pending", "": + timer := time.NewTimer(time.Duration(interval) * time.Second) + select { + case <-cmd.Context().Done(): + timer.Stop() + return oauthDeviceTokenResponse{}, cmd.Context().Err() + case <-timer.C: + } + case "denied": + return oauthDeviceTokenResponse{}, errors.New("oauth login denied") + case "expired": + return oauthDeviceTokenResponse{}, errors.New("oauth login expired") + default: + return oauthDeviceTokenResponse{}, fmt.Errorf("oauth login failed with status %q", token.Status) + } + } +} + func newLogin(m *config.Manifest) *cobra.Command { var ( authType string + provider string withToken bool skipValidate bool ) @@ -110,8 +232,17 @@ func newLogin(m *config.Manifest) *cobra.Command { return err } entry.BasicPassword = pass + case "oauth": + if withToken { + return errors.New("--with-token cannot be used with --auth-type oauth") + } + var err error + entry, err = oauthDeviceLogin(cmd, m, hostname, provider, insecure) + if err != nil { + return err + } default: - return fmt.Errorf("unknown auth type: %q (use bearer, apikey, or basic)", authType) + return fmt.Errorf("unknown auth type: %q (use bearer, apikey, basic, or oauth)", authType) } if !skipValidate { @@ -126,7 +257,9 @@ func newLogin(m *config.Manifest) *cobra.Command { } return fmt.Errorf("credential validation failed against %s: %w", hostname, err) } - entry.User = result.Username + if result.Username != "" { + entry.User = result.Username + } if entry.User != "" { fmt.Fprintf(os.Stderr, "✓ Authenticated as %s\n", entry.User) } @@ -144,7 +277,8 @@ func newLogin(m *config.Manifest) *cobra.Command { return nil }, } - cmd.Flags().StringVar(&authType, "auth-type", "", "Authentication type: bearer (default), apikey, basic") + cmd.Flags().StringVar(&authType, "auth-type", "", "Authentication type: bearer (default), apikey, basic, oauth") + cmd.Flags().StringVar(&provider, "provider", "", "OAuth provider hint passed to the service") cmd.Flags().BoolVar(&withToken, "with-token", false, "Read token/key from stdin") cmd.Flags().BoolVar(&skipValidate, "skip-validate", false, "Do not validate credentials against the server") return cmd @@ -225,6 +359,13 @@ func printStatus(hostname string, e config.HostEntry) { } credential := maskedCredential(e) fmt.Fprintf(os.Stdout, "%s\n ✓ Logged in as %s\n ✓ Auth: %s\n ✓ Credential: %s\n", hostname, user, authLabel, credential) + if e.LoginType != "" { + loginLabel := e.LoginType + if e.LoginProvider != "" { + loginLabel += " (" + e.LoginProvider + ")" + } + fmt.Fprintf(os.Stdout, " ✓ Login: %s\n", loginLabel) + } } func maskedCredential(e config.HostEntry) string { diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..fc61ac1 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,96 @@ +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/spf13/cobra" + + "github.com/lathe-cli/lathe/pkg/config" +) + +func TestOAuthDeviceLoginSavesBearerHost(t *testing.T) { + var startCalled bool + var tokenCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/start": + startCalled = true + var body map[string]string + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode start body: %v", err) + } + if body["provider"] != "github" { + t.Errorf("provider = %q, want github", body["provider"]) + } + if body["hostname"] == "" { + t.Error("hostname missing") + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "device-1", + "user_code": "ABCD", + "verification_uri_complete": "https://example.com/device?code=ABCD", + "expires_in": 60, + }) + case "/token": + tokenCalled = true + var body map[string]string + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode token body: %v", err) + } + if body["device_code"] != "device-1" { + t.Errorf("device_code = %q, want device-1", body["device_code"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "access-1", + "refresh_token": "refresh-1", + "expires_in": 3600, + "user": map[string]string{ + "email": "octo@example.com", + }, + }) + case "/validate": + _ = json.NewEncoder(w).Encode(map[string]any{}) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + m := &config.Manifest{ + CLI: config.CLIInfo{Name: "demo", ConfigDir: "demo", ConfigDirEnv: "DEMO_CONFIG_DIR", HostEnv: "DEMO_HOST"}, + Auth: config.AuthInfo{Login: &config.AuthLogin{ + Type: config.AuthLoginOAuthDevice, + StartPath: "/start", + TokenPath: "/token", + }, Validate: &config.AuthValidate{Method: "GET", Path: "/validate"}}, + } + config.Bind(m) + t.Setenv("DEMO_CONFIG_DIR", t.TempDir()) + + root := &cobra.Command{Use: "demo"} + root.PersistentFlags().String("hostname", srv.URL, "") + root.PersistentFlags().Bool("insecure", false, "") + root.AddCommand(NewCommand(m)) + root.SetArgs([]string{"auth", "login", "--auth-type", "oauth", "--provider", "github"}) + + if err := root.Execute(); err != nil { + t.Fatalf("Execute: %v", err) + } + if !startCalled || !tokenCalled { + t.Fatalf("startCalled=%v tokenCalled=%v", startCalled, tokenCalled) + } + hosts, err := config.LoadHosts() + if err != nil { + t.Fatalf("LoadHosts: %v", err) + } + entry, ok := hosts.Get(srv.URL) + if !ok { + t.Fatal("host not saved") + } + if entry.AuthType != "bearer" || entry.LoginType != config.AuthLoginOAuthDevice || entry.LoginProvider != "github" || entry.OAuthToken != "access-1" || entry.OAuthRefreshToken != "refresh-1" || entry.User != "octo@example.com" || entry.OAuthExpiresAt == 0 { + t.Fatalf("entry = %+v", entry) + } +} diff --git a/internal/codegen/render/skill.go b/internal/codegen/render/skill.go index e27655b..899239c 100644 --- a/internal/codegen/render/skill.go +++ b/internal/codegen/render/skill.go @@ -496,6 +496,11 @@ func renderSkillMD(manifest *config.Manifest, refs []moduleRef) string { fmt.Fprintf(&b, "2. Inspect the exact command with `%s commands show --json` before executing an unfamiliar command.\n", cli) fmt.Fprintf(&b, "3. If the command detail has `auth.required=true`, run `%s auth status --hostname ` before execution. Use `http.default_hostname` when present unless the user provides `--hostname` or `$%s`.\n", cli, manifest.CLI.HostEnv) fmt.Fprintf(&b, "4. Execute only after flags, body, auth, HTTP path, and output hints are clear from `commands show`.\n\n") + if manifest.Auth.Login != nil && manifest.Auth.Login.Type == config.AuthLoginOAuthDevice { + b.WriteString("## Auth Login\n\n") + fmt.Fprintf(&b, "- Use `%s auth login --auth-type oauth --hostname --provider ` when the user needs browser-based OAuth login.\n", cli) + b.WriteString("- The saved host will use `auth_type: bearer`; OAuth is the login method, and the resulting API credential is a bearer token.\n\n") + } b.WriteString("## General Commands\n\n") fmt.Fprintf(&b, "- `%s commands --json`: full generated command catalog.\n", cli) fmt.Fprintf(&b, "- `%s commands --include-hidden --json`: include hidden generated commands.\n", cli) @@ -572,6 +577,9 @@ func renderCatalogReference(manifest *config.Manifest) string { b.WriteString("Use `-o json` for machine-readable command output. Other supported formats are `table`, `yaml`, and `raw`.\n\n") b.WriteString("## Auth\n\n") fmt.Fprintf(&b, "If command detail returns `auth.required=true`, run `%s auth status --hostname ` before execution. Use `http.default_hostname` when present unless the user provides `--hostname` or `$%s`; if no matching host is logged in, stop and ask the user to authenticate.\n", cli, manifest.CLI.HostEnv) + if manifest.Auth.Login != nil && manifest.Auth.Login.Type == config.AuthLoginOAuthDevice { + fmt.Fprintf(&b, "For browser-based OAuth login, run `%s auth login --auth-type oauth --hostname --provider `. `auth_type: bearer` in `hosts.yml` is expected after login because API requests use the issued bearer token.\n", cli) + } return b.String() } diff --git a/pkg/config/hosts.go b/pkg/config/hosts.go index e58730b..b4878e1 100644 --- a/pkg/config/hosts.go +++ b/pkg/config/hosts.go @@ -22,14 +22,18 @@ func NormalizeHostname(s string) string { // HostEntry mirrors gh's per-host record in hosts.yml. type HostEntry struct { - AuthType string `yaml:"auth_type,omitempty"` - User string `yaml:"user,omitempty"` - OAuthToken string `yaml:"oauth_token,omitempty"` - APIKey string `yaml:"api_key,omitempty"` - APIKeyHeader string `yaml:"api_key_header,omitempty"` - BasicUser string `yaml:"basic_user,omitempty"` - BasicPassword string `yaml:"basic_password,omitempty"` - Insecure bool `yaml:"insecure,omitempty"` + AuthType string `yaml:"auth_type,omitempty"` + LoginType string `yaml:"login_type,omitempty"` + LoginProvider string `yaml:"login_provider,omitempty"` + User string `yaml:"user,omitempty"` + OAuthToken string `yaml:"oauth_token,omitempty"` + OAuthRefreshToken string `yaml:"oauth_refresh_token,omitempty"` + OAuthExpiresAt int64 `yaml:"oauth_expires_at,omitempty"` + APIKey string `yaml:"api_key,omitempty"` + APIKeyHeader string `yaml:"api_key_header,omitempty"` + BasicUser string `yaml:"basic_user,omitempty"` + BasicPassword string `yaml:"basic_password,omitempty"` + Insecure bool `yaml:"insecure,omitempty"` } type Hosts struct { diff --git a/pkg/config/hosts_test.go b/pkg/config/hosts_test.go new file mode 100644 index 0000000..02ed413 --- /dev/null +++ b/pkg/config/hosts_test.go @@ -0,0 +1,38 @@ +package config + +import "testing" + +func TestHostsRoundTripOAuthLoginFields(t *testing.T) { + m := &Manifest{CLI: CLIInfo{Name: "demo", ConfigDir: "demo", ConfigDirEnv: "DEMO_CONFIG_DIR", HostEnv: "DEMO_HOST"}} + Bind(m) + t.Setenv("DEMO_CONFIG_DIR", t.TempDir()) + + hosts, err := LoadHosts() + if err != nil { + t.Fatalf("LoadHosts: %v", err) + } + hosts.Set("https://api.example.com", HostEntry{ + AuthType: "bearer", + LoginType: AuthLoginOAuthDevice, + LoginProvider: "github", + User: "octo@example.com", + OAuthToken: "access", + OAuthRefreshToken: "refresh", + OAuthExpiresAt: 1790000000, + }) + if err := hosts.Save(); err != nil { + t.Fatalf("Save: %v", err) + } + + loaded, err := LoadHosts() + if err != nil { + t.Fatalf("LoadHosts reload: %v", err) + } + entry, ok := loaded.Get("api.example.com") + if !ok { + t.Fatal("missing host") + } + if entry.AuthType != "bearer" || entry.LoginType != AuthLoginOAuthDevice || entry.LoginProvider != "github" || entry.OAuthToken != "access" || entry.OAuthRefreshToken != "refresh" || entry.OAuthExpiresAt != 1790000000 { + t.Fatalf("entry = %+v", entry) + } +} diff --git a/pkg/config/manifest.go b/pkg/config/manifest.go index 6d9c298..6bafd2b 100644 --- a/pkg/config/manifest.go +++ b/pkg/config/manifest.go @@ -25,12 +25,20 @@ type CLIInfo struct { type AuthInfo struct { Validate *AuthValidate `yaml:"validate,omitempty"` + Login *AuthLogin `yaml:"login,omitempty"` } type UpdateInfo struct { GitHub *GitHubUpdate `yaml:"github,omitempty"` } +type AuthLogin struct { + Type string `yaml:"type"` + StartPath string `yaml:"start_path"` + TokenPath string `yaml:"token_path"` + RefreshPath string `yaml:"refresh_path,omitempty"` +} + type GitHubUpdate struct { Owner string `yaml:"owner"` Repo string `yaml:"repo"` @@ -52,6 +60,7 @@ const ( CommandPathAuto = "auto" CommandPathFlat = "flat" CommandPathNamespaced = "namespaced" + AuthLoginOAuthDevice = "oauth_device" ) // Load parses raw cli.yaml bytes into a Manifest. The caller (typically main.go) @@ -78,6 +87,24 @@ func Load(bytes []byte) (*Manifest, error) { return nil, fmt.Errorf("update.github.owner, update.github.repo, and update.github.asset are required") } } + if m.Auth.Login != nil { + m.Auth.Login.Type = strings.ToLower(strings.TrimSpace(m.Auth.Login.Type)) + m.Auth.Login.StartPath = strings.TrimSpace(m.Auth.Login.StartPath) + m.Auth.Login.TokenPath = strings.TrimSpace(m.Auth.Login.TokenPath) + m.Auth.Login.RefreshPath = strings.TrimSpace(m.Auth.Login.RefreshPath) + if m.Auth.Login.Type != AuthLoginOAuthDevice { + return nil, fmt.Errorf("auth.login.type must be %q", AuthLoginOAuthDevice) + } + if m.Auth.Login.StartPath == "" || m.Auth.Login.TokenPath == "" { + return nil, fmt.Errorf("auth.login.start_path and auth.login.token_path are required") + } + if !strings.HasPrefix(m.Auth.Login.StartPath, "/") || !strings.HasPrefix(m.Auth.Login.TokenPath, "/") { + return nil, fmt.Errorf("auth.login.start_path and auth.login.token_path must start with /") + } + if m.Auth.Login.RefreshPath != "" && !strings.HasPrefix(m.Auth.Login.RefreshPath, "/") { + return nil, fmt.Errorf("auth.login.refresh_path must start with /") + } + } m.CLI.CommandPath = strings.ToLower(strings.TrimSpace(m.CLI.CommandPath)) if m.CLI.CommandPath == "" { m.CLI.CommandPath = CommandPathAuto diff --git a/pkg/config/manifest_test.go b/pkg/config/manifest_test.go index e62e648..7d8e44b 100644 --- a/pkg/config/manifest_test.go +++ b/pkg/config/manifest_test.go @@ -8,6 +8,11 @@ cli: name: demo short: "demo CLI" auth: + login: + type: oauth_device + start_path: /auth/cli/start + token_path: /auth/cli/token + refresh_path: /auth/cli/refresh validate: method: POST path: /whoami @@ -25,6 +30,12 @@ auth: if m.Auth.Validate == nil { t.Fatal("expected Auth.Validate non-nil") } + if m.Auth.Login == nil { + t.Fatal("expected Auth.Login non-nil") + } + if m.Auth.Login.Type != AuthLoginOAuthDevice || m.Auth.Login.StartPath != "/auth/cli/start" || m.Auth.Login.TokenPath != "/auth/cli/token" || m.Auth.Login.RefreshPath != "/auth/cli/refresh" { + t.Errorf("unexpected AuthLogin: %+v", m.Auth.Login) + } if m.Auth.Validate.Method != "POST" || m.Auth.Validate.Path != "/whoami" { t.Errorf("unexpected AuthValidate: %+v", m.Auth.Validate) } @@ -146,6 +157,56 @@ cli: } } +func TestLoad_AuthLoginValidation(t *testing.T) { + tests := []struct { + name string + yaml string + }{ + { + name: "unsupported type", + yaml: ` +cli: + name: demo +auth: + login: + type: github + start_path: /start + token_path: /token +`, + }, + { + name: "missing token path", + yaml: ` +cli: + name: demo +auth: + login: + type: oauth_device + start_path: /start +`, + }, + { + name: "relative path", + yaml: ` +cli: + name: demo +auth: + login: + type: oauth_device + start_path: start + token_path: /token +`, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if _, err := Load([]byte(tc.yaml)); err == nil { + t.Fatal("Load succeeded, want error") + } + }) + } +} + func TestBindActive_Panics(t *testing.T) { boundMu.Lock() bound = nil diff --git a/pkg/runtime/client.go b/pkg/runtime/client.go index 00ecdb9..7b44730 100644 --- a/pkg/runtime/client.go +++ b/pkg/runtime/client.go @@ -15,15 +15,16 @@ import ( ) type ClientOptions struct { - Auth Authenticator - Transport http.RoundTripper - Insecure bool - Timeout time.Duration - Headers map[string]string - Debug bool - MaxRetries int - UserAgent string - Accept string + Auth Authenticator + RefreshAuth func(context.Context) (Authenticator, error) + Transport http.RoundTripper + Insecure bool + Timeout time.Duration + Headers map[string]string + Debug bool + MaxRetries int + UserAgent string + Accept string } // BaseURL normalizes a user-facing hostname into an absolute URL base. @@ -84,26 +85,51 @@ func DoRawFull(ctx context.Context, hostname, method, path string, body any, opt } u := base + path - var reader io.Reader - contentType := "" + bodyBytes, contentType, err := encodeRequestBody(body) + if err != nil { + return nil, err + } + + result, err := doRawFullOnce(ctx, method, u, bodyBytes, contentType, opts) + if err == nil { + return result, nil + } + var he *HTTPError + if !errors.As(err, &he) || he.Status != http.StatusUnauthorized || opts.RefreshAuth == nil { + return nil, err + } + auth, refreshErr := opts.RefreshAuth(ctx) + if refreshErr != nil { + return nil, fmt.Errorf("refresh auth after 401: %w", refreshErr) + } + opts.Auth = auth + opts.RefreshAuth = nil + return doRawFullOnce(ctx, method, u, bodyBytes, contentType, opts) +} + +func encodeRequestBody(body any) ([]byte, string, error) { if body != nil { switch b := body.(type) { case []byte: - reader = bytes.NewReader(b) - contentType = "application/json" + return b, "application/json", nil case url.Values: - reader = strings.NewReader(b.Encode()) - contentType = "application/x-www-form-urlencoded" + return []byte(b.Encode()), "application/x-www-form-urlencoded", nil default: raw, err := json.Marshal(b) if err != nil { - return nil, fmt.Errorf("marshal request body: %w", err) + return nil, "", fmt.Errorf("marshal request body: %w", err) } - reader = bytes.NewReader(raw) - contentType = "application/json" + return raw, "application/json", nil } } + return nil, "", nil +} +func doRawFullOnce(ctx context.Context, method, u string, body []byte, contentType string, opts ClientOptions) (*RawResult, error) { + var reader io.Reader + if body != nil { + reader = bytes.NewReader(body) + } req, err := http.NewRequestWithContext(ctx, method, u, reader) if err != nil { return nil, err diff --git a/pkg/runtime/client_test.go b/pkg/runtime/client_test.go index 9c08147..b502a25 100644 --- a/pkg/runtime/client_test.go +++ b/pkg/runtime/client_test.go @@ -142,3 +142,34 @@ func TestDoRaw_EncodesFormBody(t *testing.T) { t.Errorf("body = %q, want %q", gotBody, form.Encode()) } } + +func TestDoRaw_RefreshesAuthAndRetriesOnceOn401(t *testing.T) { + var seen []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen = append(seen, r.Header.Get("Authorization")) + if len(seen) == 1 { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"expired"}`)) + return + } + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + data, err := DoRaw(context.Background(), srv.URL, "POST", "/x", map[string]string{"a": "b"}, ClientOptions{ + Auth: BearerAuth{Token: "old"}, + RefreshAuth: func(context.Context) (Authenticator, error) { + return BearerAuth{Token: "new"}, nil + }, + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatalf("DoRaw: %v", err) + } + if string(data) != `{"ok":true}` { + t.Fatalf("data = %s", data) + } + if len(seen) != 2 || seen[0] != "Bearer old" || seen[1] != "Bearer new" { + t.Fatalf("authorization sequence = %#v", seen) + } +} diff --git a/pkg/runtime/ctx.go b/pkg/runtime/ctx.go index 779dd68..87e3827 100644 --- a/pkg/runtime/ctx.go +++ b/pkg/runtime/ctx.go @@ -75,16 +75,24 @@ func loadHostOptions(cmd *cobra.Command, defaultHostname string) (string, Client if !ok { return "", ClientOptions{}, notAuthenticatedToHost(hostname) } + insecure := e.Insecure + if v, err := cmd.Root().PersistentFlags().GetBool("insecure"); err == nil && v { + insecure = true + } + e, err = refreshHostAuthIfNeeded(cmd.Context(), hostname, hosts, e, insecure) + if err != nil { + return "", ClientOptions{}, err + } auth, err := NewAuthFromHost(e) if err != nil { return "", ClientOptions{}, err } opts := ClientOptions{ Auth: auth, - Insecure: e.Insecure, + Insecure: insecure, } - if v, err := cmd.Root().PersistentFlags().GetBool("insecure"); err == nil && v { - opts.Insecure = true + if canRefreshHostAuth(e) { + opts.RefreshAuth = refreshAuthFunc(hostname, insecure, e.OAuthToken) } return hostname, opts, nil } @@ -110,16 +118,23 @@ func tryLoadHostOptions(cmd *cobra.Command, defaultHostname string) (string, Cli } return hostname, opts, nil } + insecure := e.Insecure + if v, err := cmd.Root().PersistentFlags().GetBool("insecure"); err == nil && v { + insecure = true + } + if refreshed, err := refreshHostAuthIfNeeded(cmd.Context(), hostname, hosts, e, insecure); err == nil { + e = refreshed + } auth, err := NewAuthFromHost(e) if err != nil { return hostname, ClientOptions{}, nil } opts := ClientOptions{ Auth: auth, - Insecure: e.Insecure, + Insecure: insecure, } - if v, err := cmd.Root().PersistentFlags().GetBool("insecure"); err == nil && v { - opts.Insecure = true + if canRefreshHostAuth(e) { + opts.RefreshAuth = refreshAuthFunc(hostname, insecure, e.OAuthToken) } return hostname, opts, nil } diff --git a/pkg/runtime/ctx_test.go b/pkg/runtime/ctx_test.go index 5b1dfd3..4e5faa4 100644 --- a/pkg/runtime/ctx_test.go +++ b/pkg/runtime/ctx_test.go @@ -1,9 +1,14 @@ package runtime import ( + "context" + "encoding/json" "errors" + "net/http" + "net/http/httptest" "strings" "testing" + "time" "github.com/spf13/cobra" @@ -47,3 +52,80 @@ func TestResolveHost_UsesBoundHostEnv(t *testing.T) { t.Errorf("want example.internal, got %q", got) } } + +func TestLoadHostOptionsRefreshesExpiredOAuthToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/refresh" { + http.NotFound(w, r) + return + } + var body map[string]string + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode body: %v", err) + } + if body["refresh_token"] != "refresh-old" { + t.Errorf("refresh_token = %q, want refresh-old", body["refresh_token"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "access-new", + "refresh_token": "refresh-new", + "expires_in": 3600, + }) + })) + defer srv.Close() + + config.Bind(&config.Manifest{ + CLI: config.CLIInfo{Name: "demo", ConfigDir: "demo", ConfigDirEnv: "DEMO_CONFIG_DIR", HostEnv: "DEMO_HOST"}, + Auth: config.AuthInfo{Login: &config.AuthLogin{ + Type: config.AuthLoginOAuthDevice, + StartPath: "/start", + TokenPath: "/token", + RefreshPath: "/refresh", + }}, + }) + t.Setenv("DEMO_CONFIG_DIR", t.TempDir()) + hosts, err := config.LoadHosts() + if err != nil { + t.Fatalf("LoadHosts: %v", err) + } + hosts.Set(srv.URL, config.HostEntry{ + AuthType: "bearer", + OAuthToken: "access-old", + OAuthRefreshToken: "refresh-old", + OAuthExpiresAt: time.Now().Add(-time.Hour).Unix(), + }) + if err := hosts.Save(); err != nil { + t.Fatalf("Save: %v", err) + } + + root := &cobra.Command{Use: "demo"} + root.SetContext(context.Background()) + root.PersistentFlags().String("hostname", srv.URL, "") + root.PersistentFlags().Bool("insecure", false, "") + + hostname, opts, err := loadHostOptions(root, "") + if err != nil { + t.Fatalf("loadHostOptions: %v", err) + } + if hostname != config.NormalizeHostname(srv.URL) { + t.Fatalf("hostname = %q", hostname) + } + auth, ok := opts.Auth.(BearerAuth) + if !ok || auth.Token != "access-new" { + t.Fatalf("auth = %#v", opts.Auth) + } + if opts.RefreshAuth == nil { + t.Fatal("RefreshAuth is nil") + } + reloaded, err := config.LoadHosts() + if err != nil { + t.Fatalf("LoadHosts reload: %v", err) + } + entry, ok := reloaded.Get(srv.URL) + if !ok { + t.Fatal("host missing") + } + if entry.OAuthToken != "access-new" || entry.OAuthRefreshToken != "refresh-new" || entry.OAuthExpiresAt <= time.Now().Unix() { + t.Fatalf("entry = %+v", entry) + } +} diff --git a/pkg/runtime/oauth.go b/pkg/runtime/oauth.go new file mode 100644 index 0000000..75560db --- /dev/null +++ b/pkg/runtime/oauth.go @@ -0,0 +1,107 @@ +package runtime + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/lathe-cli/lathe/pkg/config" +) + +const oauthRefreshSkew = 60 + +type oauthTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` +} + +func refreshAuthFunc(hostname string, insecure bool, rejectedToken string) func(context.Context) (Authenticator, error) { + return func(ctx context.Context) (Authenticator, error) { + hosts, err := config.LoadHosts() + if err != nil { + return nil, err + } + entry, ok := hosts.Get(hostname) + if !ok { + return nil, notAuthenticatedToHost(hostname) + } + if entry.OAuthToken != "" && entry.OAuthToken != rejectedToken { + return NewAuthFromHost(entry) + } + entry, err = refreshHostAuth(ctx, hostname, hosts, entry, insecure) + if err != nil { + return nil, err + } + return NewAuthFromHost(entry) + } +} + +func canRefreshHostAuth(entry config.HostEntry) bool { + login := config.Active().Auth.Login + return login != nil && login.RefreshPath != "" && entry.OAuthRefreshToken != "" +} + +func refreshHostAuthIfNeeded(ctx context.Context, hostname string, hosts *config.Hosts, entry config.HostEntry, insecure bool) (config.HostEntry, error) { + if !canRefreshHostAuth(entry) { + return entry, nil + } + if entry.OAuthExpiresAt == 0 || time.Now().Unix()+oauthRefreshSkew < entry.OAuthExpiresAt { + return entry, nil + } + return refreshHostAuth(ctx, hostname, hosts, entry, insecure) +} + +func refreshHostAuth(ctx context.Context, hostname string, hosts *config.Hosts, entry config.HostEntry, insecure bool) (config.HostEntry, error) { + login := config.Active().Auth.Login + if login == nil || login.RefreshPath == "" || entry.OAuthRefreshToken == "" { + return entry, fmt.Errorf("refresh token unavailable; run `%s auth login --auth-type oauth --hostname %s`", config.Active().CLI.Name, hostname) + } + data, err := DoRaw(ctx, hostname, "POST", login.RefreshPath, map[string]string{ + "refresh_token": entry.OAuthRefreshToken, + }, ClientOptions{Insecure: insecure, Timeout: 10 * time.Second}) + if err != nil { + if current, ok := refreshedByAnotherProcess(hostname, entry); ok { + return current, nil + } + return entry, err + } + var token oauthTokenResponse + if err := json.Unmarshal(data, &token); err != nil { + return entry, fmt.Errorf("decode refresh token response: %w", err) + } + if token.AccessToken == "" { + return entry, fmt.Errorf("refresh token response missing access_token") + } + entry.AuthType = "bearer" + entry.OAuthToken = token.AccessToken + if token.RefreshToken != "" { + entry.OAuthRefreshToken = token.RefreshToken + } + if token.ExpiresIn > 0 { + entry.OAuthExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() + } else { + entry.OAuthExpiresAt = 0 + } + hosts.Set(hostname, entry) + if err := hosts.Save(); err != nil { + return entry, err + } + return entry, nil +} + +func refreshedByAnotherProcess(hostname string, old config.HostEntry) (config.HostEntry, bool) { + hosts, err := config.LoadHosts() + if err != nil { + return config.HostEntry{}, false + } + current, ok := hosts.Get(hostname) + if !ok || current.AuthType != "bearer" || current.OAuthToken == "" { + return config.HostEntry{}, false + } + if current.OAuthToken == old.OAuthToken && current.OAuthRefreshToken == old.OAuthRefreshToken { + return config.HostEntry{}, false + } + return current, true +}