Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions internal/auth/device_flow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package auth

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)

const (
DefaultDeviceCodePath = "/v1/cli/device_codes"
DefaultPollPath = "/v1/cli/device_codes/poll"
)

type DeviceCodeResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}

type PollResponse struct {
Status string `json:"status"`
TestAPIKey string `json:"test_api_key,omitempty"`
LiveAPIKey string `json:"live_api_key,omitempty"`
OrganizationName string `json:"organization_name,omitempty"`
}

type DeviceFlowClient struct {
baseURL string
httpClient *http.Client
}

type DeviceFlowOption func(*DeviceFlowClient)

func WithBaseURL(u string) DeviceFlowOption {
return func(c *DeviceFlowClient) { c.baseURL = u }
}

func WithHTTPClient(hc *http.Client) DeviceFlowOption {
return func(c *DeviceFlowClient) { c.httpClient = hc }
}

func NewDeviceFlowClient(opts ...DeviceFlowOption) *DeviceFlowClient {
c := &DeviceFlowClient{
baseURL: "https://api.fintoc.com",
httpClient: &http.Client{Timeout: 15 * time.Second},
}
for _, opt := range opts {
opt(c)
}
return c
}

func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
url := c.baseURL + DefaultDeviceCodePath

req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "fintoc-cli")

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("requesting device code: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("device code request failed (%d): %s", resp.StatusCode, string(body))
}

var result DeviceCodeResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decoding device code response: %w", err)
}

return &result, nil
}

func (c *DeviceFlowClient) Poll(ctx context.Context, deviceCode string) (*PollResponse, error) {
url := c.baseURL + DefaultPollPath

payload, _ := json.Marshal(map[string]string{"device_code": deviceCode})

req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "fintoc-cli")

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("polling device code: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if strings.Contains(string(body), "slow_down") {
return &PollResponse{Status: "slow_down"}, nil
}
return nil, fmt.Errorf("poll request failed (%d): %s", resp.StatusCode, string(body))
}

var result PollResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decoding poll response: %w", err)
}

return &result, nil
}

type PollCallback func(attempt int, elapsed time.Duration)

func (c *DeviceFlowClient) PollUntilComplete(ctx context.Context, deviceCode string, interval int, expiresIn int, onPoll PollCallback) (*PollResponse, error) {
if interval < 1 {
interval = 5
}

deadline := time.After(time.Duration(expiresIn) * time.Second)
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()

start := time.Now()
attempt := 0

for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-deadline:
return nil, fmt.Errorf("device code expired after %ds", expiresIn)
case <-ticker.C:
attempt++
if onPoll != nil {
onPoll(attempt, time.Since(start))
}

result, err := c.Poll(ctx, deviceCode)
if err != nil {
if attempt < 3 {
continue
}
return nil, err
}

switch result.Status {
case "authorization_pending", "slow_down":
continue
case "complete":
return result, nil
default:
return nil, fmt.Errorf("unexpected poll status: %s", result.Status)
}
}
}
}
207 changes: 207 additions & 0 deletions internal/auth/device_flow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package auth

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)

func TestRequestDeviceCode_Success(t *testing.T) {
expected := DeviceCodeResponse{
DeviceCode: "abc123",
UserCode: "BCDF-GHJK",
VerificationURI: "https://app.fintoc.com/cli/authorize",
ExpiresIn: 900,
Interval: 5,
}

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != DefaultDeviceCodePath {
t.Errorf("expected path %s, got %s", DefaultDeviceCodePath, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(expected)
}))
defer srv.Close()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
resp, err := client.RequestDeviceCode(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if resp.DeviceCode != expected.DeviceCode {
t.Errorf("device_code: got %q, want %q", resp.DeviceCode, expected.DeviceCode)
}
if resp.UserCode != expected.UserCode {
t.Errorf("user_code: got %q, want %q", resp.UserCode, expected.UserCode)
}
if resp.VerificationURI != expected.VerificationURI {
t.Errorf("verification_uri: got %q, want %q", resp.VerificationURI, expected.VerificationURI)
}
if resp.Interval != expected.Interval {
t.Errorf("interval: got %d, want %d", resp.Interval, expected.Interval)
}
}

func TestRequestDeviceCode_ServerError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error": "internal"}`))
}))
defer srv.Close()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
_, err := client.RequestDeviceCode(context.Background())
if err == nil {
t.Fatal("expected error, got nil")
}
}

func TestPoll_AuthorizationPending(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(PollResponse{Status: "authorization_pending"})
}))
defer srv.Close()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
resp, err := client.Poll(context.Background(), "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Status != "authorization_pending" {
t.Errorf("status: got %q, want %q", resp.Status, "authorization_pending")
}
}

func TestPoll_Complete(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]string
json.NewDecoder(r.Body).Decode(&body)
if body["device_code"] != "abc123" {
t.Errorf("device_code: got %q, want %q", body["device_code"], "abc123")
}
json.NewEncoder(w).Encode(PollResponse{
Status: "complete",
TestAPIKey: "sk_test_xxx",
LiveAPIKey: "sk_live_xxx",
OrganizationName: "Acme Inc",
})
}))
defer srv.Close()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
resp, err := client.Poll(context.Background(), "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Status != "complete" {
t.Errorf("status: got %q, want %q", resp.Status, "complete")
}
if resp.TestAPIKey != "sk_test_xxx" {
t.Errorf("test_api_key: got %q, want %q", resp.TestAPIKey, "sk_test_xxx")
}
if resp.LiveAPIKey != "sk_live_xxx" {
t.Errorf("live_api_key: got %q, want %q", resp.LiveAPIKey, "sk_live_xxx")
}
if resp.OrganizationName != "Acme Inc" {
t.Errorf("organization_name: got %q, want %q", resp.OrganizationName, "Acme Inc")
}
}

func TestPollUntilComplete_Success(t *testing.T) {
var calls int32

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&calls, 1)
if n < 3 {
json.NewEncoder(w).Encode(PollResponse{Status: "authorization_pending"})
return
}
json.NewEncoder(w).Encode(PollResponse{
Status: "complete",
TestAPIKey: "sk_test_done",
OrganizationName: "Test Org",
})
}))
defer srv.Close()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
resp, err := client.PollUntilComplete(context.Background(), "abc", 1, 30, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Status != "complete" {
t.Errorf("status: got %q, want %q", resp.Status, "complete")
}
if resp.TestAPIKey != "sk_test_done" {
t.Errorf("test_api_key: got %q, want %q", resp.TestAPIKey, "sk_test_done")
}
}

func TestPollUntilComplete_ContextCanceled(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(PollResponse{Status: "authorization_pending"})
}))
defer srv.Close()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
_, err := client.PollUntilComplete(ctx, "abc", 1, 60, nil)
if err == nil {
t.Fatal("expected error from canceled context")
}
}

func TestPollUntilComplete_CallbackCalled(t *testing.T) {
var calls int32
var cbCount int

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&calls, 1)
if n < 2 {
json.NewEncoder(w).Encode(PollResponse{Status: "authorization_pending"})
return
}
json.NewEncoder(w).Encode(PollResponse{Status: "complete", TestAPIKey: "sk_test_cb"})
}))
defer srv.Close()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
_, err := client.PollUntilComplete(context.Background(), "abc", 1, 30, func(attempt int, elapsed time.Duration) {
cbCount++
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cbCount < 1 {
t.Errorf("callback was not called (count=%d)", cbCount)
}
}

func TestPoll_SendsDeviceCodeInBody(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]string
json.NewDecoder(r.Body).Decode(&body)
if body["device_code"] != "my_device_code" {
t.Errorf("device_code in body: got %q, want %q", body["device_code"], "my_device_code")
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Content-Type: got %q, want %q", r.Header.Get("Content-Type"), "application/json")
}
json.NewEncoder(w).Encode(PollResponse{Status: "authorization_pending"})
}))
defer srv.Close()

client := NewDeviceFlowClient(WithBaseURL(srv.URL))
client.Poll(context.Background(), "my_device_code")
}
Loading