Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 144 additions & 3 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package auth

import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strings"
"time"

"github.com/spf13/cobra"
"golang.org/x/term"
Expand All @@ -30,6 +32,23 @@ func NewHiddenLoginCommand(m *config.Manifest) *cobra.Command {
return cmd
}

type oauthDeviceStartResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
Interval int64 `json:"interval"`
ExpiresIn int64 `json:"expires_in"`
}

type oauthDeviceTokenResponse struct {
Status string `json:"status"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
User map[string]string `json:"user"`
}

func rootString(cmd *cobra.Command, name string) string {
v, _ := cmd.Root().PersistentFlags().GetString(name)
return v
Expand All @@ -40,9 +59,112 @@ func rootBool(cmd *cobra.Command, name string) bool {
return v
}

func oauthDeviceLogin(cmd *cobra.Command, m *config.Manifest, hostname string, provider string, insecure bool) (config.HostEntry, error) {
login := m.Auth.Login
if login == nil || login.Type != config.AuthLoginOAuthDevice {
return config.HostEntry{}, errors.New("auth.login with type oauth_device is required for --auth-type oauth")
}
body := map[string]string{"hostname": hostname}
provider = strings.TrimSpace(provider)
if provider != "" {
body["provider"] = provider
}
data, err := runtime.DoRaw(cmd.Context(), hostname, "POST", login.StartPath, body, runtime.ClientOptions{Insecure: insecure, Timeout: 10 * time.Second})
if err != nil {
return config.HostEntry{}, fmt.Errorf("start oauth login: %w", err)
}
var start oauthDeviceStartResponse
if err := json.Unmarshal(data, &start); err != nil {
return config.HostEntry{}, fmt.Errorf("decode oauth start response: %w", err)
}
if start.DeviceCode == "" {
return config.HostEntry{}, errors.New("oauth start response missing device_code")
}
verificationURL := start.VerificationURIComplete
if verificationURL == "" {
verificationURL = start.VerificationURI
}
if verificationURL == "" {
return config.HostEntry{}, errors.New("oauth start response missing verification_uri")
}
fmt.Fprintf(os.Stderr, "Open this URL to authenticate: %s\n", verificationURL)
if start.UserCode != "" {
fmt.Fprintf(os.Stderr, "Code: %s\n", start.UserCode)
}
token, err := pollOAuthDeviceToken(cmd, hostname, login.TokenPath, start, insecure)
if err != nil {
return config.HostEntry{}, err
}
entry := config.HostEntry{
AuthType: "bearer",
LoginType: config.AuthLoginOAuthDevice,
LoginProvider: provider,
OAuthToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
Insecure: insecure,
}
if token.ExpiresIn > 0 {
entry.OAuthExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix()
}
if token.User != nil {
entry.User = token.User["email"]
if entry.User == "" {
entry.User = token.User["name"]
}
}
return entry, nil
}

func pollOAuthDeviceToken(cmd *cobra.Command, hostname string, tokenPath string, start oauthDeviceStartResponse, insecure bool) (oauthDeviceTokenResponse, error) {
expiresIn := start.ExpiresIn
if expiresIn <= 0 {
expiresIn = 600
}
interval := start.Interval
if interval <= 0 {
interval = 5
}
deadline := time.Now().Add(time.Duration(expiresIn) * time.Second)
for {
if time.Now().After(deadline) {
return oauthDeviceTokenResponse{}, errors.New("oauth login expired")
}
data, err := runtime.DoRaw(cmd.Context(), hostname, "POST", tokenPath, map[string]string{
"device_code": start.DeviceCode,
}, runtime.ClientOptions{Insecure: insecure, Timeout: 10 * time.Second})
if err != nil {
return oauthDeviceTokenResponse{}, fmt.Errorf("poll oauth login: %w", err)
}
var token oauthDeviceTokenResponse
if err := json.Unmarshal(data, &token); err != nil {
return oauthDeviceTokenResponse{}, fmt.Errorf("decode oauth token response: %w", err)
}
if token.AccessToken != "" {
return token, nil
}
switch token.Status {
case "pending", "":
timer := time.NewTimer(time.Duration(interval) * time.Second)
select {
case <-cmd.Context().Done():
timer.Stop()
return oauthDeviceTokenResponse{}, cmd.Context().Err()
case <-timer.C:
}
case "denied":
return oauthDeviceTokenResponse{}, errors.New("oauth login denied")
case "expired":
return oauthDeviceTokenResponse{}, errors.New("oauth login expired")
default:
return oauthDeviceTokenResponse{}, fmt.Errorf("oauth login failed with status %q", token.Status)
}
}
}

func newLogin(m *config.Manifest) *cobra.Command {
var (
authType string
provider string
withToken bool
skipValidate bool
)
Expand Down Expand Up @@ -110,8 +232,17 @@ func newLogin(m *config.Manifest) *cobra.Command {
return err
}
entry.BasicPassword = pass
case "oauth":
if withToken {
return errors.New("--with-token cannot be used with --auth-type oauth")
}
var err error
entry, err = oauthDeviceLogin(cmd, m, hostname, provider, insecure)
if err != nil {
return err
}
default:
return fmt.Errorf("unknown auth type: %q (use bearer, apikey, or basic)", authType)
return fmt.Errorf("unknown auth type: %q (use bearer, apikey, basic, or oauth)", authType)
}

if !skipValidate {
Expand All @@ -126,7 +257,9 @@ func newLogin(m *config.Manifest) *cobra.Command {
}
return fmt.Errorf("credential validation failed against %s: %w", hostname, err)
}
entry.User = result.Username
if result.Username != "" {
entry.User = result.Username
}
if entry.User != "" {
fmt.Fprintf(os.Stderr, "✓ Authenticated as %s\n", entry.User)
}
Expand All @@ -144,7 +277,8 @@ func newLogin(m *config.Manifest) *cobra.Command {
return nil
},
}
cmd.Flags().StringVar(&authType, "auth-type", "", "Authentication type: bearer (default), apikey, basic")
cmd.Flags().StringVar(&authType, "auth-type", "", "Authentication type: bearer (default), apikey, basic, oauth")
cmd.Flags().StringVar(&provider, "provider", "", "OAuth provider hint passed to the service")
cmd.Flags().BoolVar(&withToken, "with-token", false, "Read token/key from stdin")
cmd.Flags().BoolVar(&skipValidate, "skip-validate", false, "Do not validate credentials against the server")
return cmd
Expand Down Expand Up @@ -225,6 +359,13 @@ func printStatus(hostname string, e config.HostEntry) {
}
credential := maskedCredential(e)
fmt.Fprintf(os.Stdout, "%s\n ✓ Logged in as %s\n ✓ Auth: %s\n ✓ Credential: %s\n", hostname, user, authLabel, credential)
if e.LoginType != "" {
loginLabel := e.LoginType
if e.LoginProvider != "" {
loginLabel += " (" + e.LoginProvider + ")"
}
fmt.Fprintf(os.Stdout, " ✓ Login: %s\n", loginLabel)
}
}

func maskedCredential(e config.HostEntry) string {
Expand Down
96 changes: 96 additions & 0 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package auth

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/spf13/cobra"

"github.com/lathe-cli/lathe/pkg/config"
)

func TestOAuthDeviceLoginSavesBearerHost(t *testing.T) {
var startCalled bool
var tokenCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/start":
startCalled = true
var body map[string]string
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Errorf("decode start body: %v", err)
}
if body["provider"] != "github" {
t.Errorf("provider = %q, want github", body["provider"])
}
if body["hostname"] == "" {
t.Error("hostname missing")
}
_ = json.NewEncoder(w).Encode(map[string]any{
"device_code": "device-1",
"user_code": "ABCD",
"verification_uri_complete": "https://example.com/device?code=ABCD",
"expires_in": 60,
})
case "/token":
tokenCalled = true
var body map[string]string
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Errorf("decode token body: %v", err)
}
if body["device_code"] != "device-1" {
t.Errorf("device_code = %q, want device-1", body["device_code"])
}
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "access-1",
"refresh_token": "refresh-1",
"expires_in": 3600,
"user": map[string]string{
"email": "octo@example.com",
},
})
case "/validate":
_ = json.NewEncoder(w).Encode(map[string]any{})
default:
http.NotFound(w, r)
}
}))
defer srv.Close()

m := &config.Manifest{
CLI: config.CLIInfo{Name: "demo", ConfigDir: "demo", ConfigDirEnv: "DEMO_CONFIG_DIR", HostEnv: "DEMO_HOST"},
Auth: config.AuthInfo{Login: &config.AuthLogin{
Type: config.AuthLoginOAuthDevice,
StartPath: "/start",
TokenPath: "/token",
}, Validate: &config.AuthValidate{Method: "GET", Path: "/validate"}},
}
config.Bind(m)
t.Setenv("DEMO_CONFIG_DIR", t.TempDir())

root := &cobra.Command{Use: "demo"}
root.PersistentFlags().String("hostname", srv.URL, "")
root.PersistentFlags().Bool("insecure", false, "")
root.AddCommand(NewCommand(m))
root.SetArgs([]string{"auth", "login", "--auth-type", "oauth", "--provider", "github"})

if err := root.Execute(); err != nil {
t.Fatalf("Execute: %v", err)
}
if !startCalled || !tokenCalled {
t.Fatalf("startCalled=%v tokenCalled=%v", startCalled, tokenCalled)
}
hosts, err := config.LoadHosts()
if err != nil {
t.Fatalf("LoadHosts: %v", err)
}
entry, ok := hosts.Get(srv.URL)
if !ok {
t.Fatal("host not saved")
}
if entry.AuthType != "bearer" || entry.LoginType != config.AuthLoginOAuthDevice || entry.LoginProvider != "github" || entry.OAuthToken != "access-1" || entry.OAuthRefreshToken != "refresh-1" || entry.User != "octo@example.com" || entry.OAuthExpiresAt == 0 {
t.Fatalf("entry = %+v", entry)
}
}
8 changes: 8 additions & 0 deletions internal/codegen/render/skill.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,11 @@ func renderSkillMD(manifest *config.Manifest, refs []moduleRef) string {
fmt.Fprintf(&b, "2. Inspect the exact command with `%s commands show <path...> --json` before executing an unfamiliar command.\n", cli)
fmt.Fprintf(&b, "3. If the command detail has `auth.required=true`, run `%s auth status --hostname <host>` before execution. Use `http.default_hostname` when present unless the user provides `--hostname` or `$%s`.\n", cli, manifest.CLI.HostEnv)
fmt.Fprintf(&b, "4. Execute only after flags, body, auth, HTTP path, and output hints are clear from `commands show`.\n\n")
if manifest.Auth.Login != nil && manifest.Auth.Login.Type == config.AuthLoginOAuthDevice {
b.WriteString("## Auth Login\n\n")
fmt.Fprintf(&b, "- Use `%s auth login --auth-type oauth --hostname <host> --provider <provider>` when the user needs browser-based OAuth login.\n", cli)
b.WriteString("- The saved host will use `auth_type: bearer`; OAuth is the login method, and the resulting API credential is a bearer token.\n\n")
}
b.WriteString("## General Commands\n\n")
fmt.Fprintf(&b, "- `%s commands --json`: full generated command catalog.\n", cli)
fmt.Fprintf(&b, "- `%s commands --include-hidden --json`: include hidden generated commands.\n", cli)
Expand Down Expand Up @@ -572,6 +577,9 @@ func renderCatalogReference(manifest *config.Manifest) string {
b.WriteString("Use `-o json` for machine-readable command output. Other supported formats are `table`, `yaml`, and `raw`.\n\n")
b.WriteString("## Auth\n\n")
fmt.Fprintf(&b, "If command detail returns `auth.required=true`, run `%s auth status --hostname <host>` before execution. Use `http.default_hostname` when present unless the user provides `--hostname` or `$%s`; if no matching host is logged in, stop and ask the user to authenticate.\n", cli, manifest.CLI.HostEnv)
if manifest.Auth.Login != nil && manifest.Auth.Login.Type == config.AuthLoginOAuthDevice {
fmt.Fprintf(&b, "For browser-based OAuth login, run `%s auth login --auth-type oauth --hostname <host> --provider <provider>`. `auth_type: bearer` in `hosts.yml` is expected after login because API requests use the issued bearer token.\n", cli)
}
return b.String()
}

Expand Down
20 changes: 12 additions & 8 deletions pkg/config/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ func NormalizeHostname(s string) string {

// HostEntry mirrors gh's per-host record in hosts.yml.
type HostEntry struct {
AuthType string `yaml:"auth_type,omitempty"`
User string `yaml:"user,omitempty"`
OAuthToken string `yaml:"oauth_token,omitempty"`
APIKey string `yaml:"api_key,omitempty"`
APIKeyHeader string `yaml:"api_key_header,omitempty"`
BasicUser string `yaml:"basic_user,omitempty"`
BasicPassword string `yaml:"basic_password,omitempty"`
Insecure bool `yaml:"insecure,omitempty"`
AuthType string `yaml:"auth_type,omitempty"`
LoginType string `yaml:"login_type,omitempty"`
LoginProvider string `yaml:"login_provider,omitempty"`
User string `yaml:"user,omitempty"`
OAuthToken string `yaml:"oauth_token,omitempty"`
OAuthRefreshToken string `yaml:"oauth_refresh_token,omitempty"`
OAuthExpiresAt int64 `yaml:"oauth_expires_at,omitempty"`
APIKey string `yaml:"api_key,omitempty"`
APIKeyHeader string `yaml:"api_key_header,omitempty"`
BasicUser string `yaml:"basic_user,omitempty"`
BasicPassword string `yaml:"basic_password,omitempty"`
Insecure bool `yaml:"insecure,omitempty"`
}

type Hosts struct {
Expand Down
38 changes: 38 additions & 0 deletions pkg/config/hosts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package config

import "testing"

func TestHostsRoundTripOAuthLoginFields(t *testing.T) {
m := &Manifest{CLI: CLIInfo{Name: "demo", ConfigDir: "demo", ConfigDirEnv: "DEMO_CONFIG_DIR", HostEnv: "DEMO_HOST"}}
Bind(m)
t.Setenv("DEMO_CONFIG_DIR", t.TempDir())

hosts, err := LoadHosts()
if err != nil {
t.Fatalf("LoadHosts: %v", err)
}
hosts.Set("https://api.example.com", HostEntry{
AuthType: "bearer",
LoginType: AuthLoginOAuthDevice,
LoginProvider: "github",
User: "octo@example.com",
OAuthToken: "access",
OAuthRefreshToken: "refresh",
OAuthExpiresAt: 1790000000,
})
if err := hosts.Save(); err != nil {
t.Fatalf("Save: %v", err)
}

loaded, err := LoadHosts()
if err != nil {
t.Fatalf("LoadHosts reload: %v", err)
}
entry, ok := loaded.Get("api.example.com")
if !ok {
t.Fatal("missing host")
}
if entry.AuthType != "bearer" || entry.LoginType != AuthLoginOAuthDevice || entry.LoginProvider != "github" || entry.OAuthToken != "access" || entry.OAuthRefreshToken != "refresh" || entry.OAuthExpiresAt != 1790000000 {
t.Fatalf("entry = %+v", entry)
}
}
Loading