diff --git a/cmd/spotify/auth/root.go b/cmd/spotify/auth/root.go index b6745edc..373e4618 100644 --- a/cmd/spotify/auth/root.go +++ b/cmd/spotify/auth/root.go @@ -2,32 +2,41 @@ package auth import ( + "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "io" + "net" "net/http" "net/url" "os" + "os/exec" "strings" + "sync" "time" "github.com/drn/dots/cli/config" "github.com/drn/dots/pkg/log" - "github.com/drn/dots/pkg/run" jsoniter "github.com/json-iterator/go" - "github.com/manifoldco/promptui" ) const spotifyTokenURL = "https://accounts.spotify.com/api/token" +// authTimeout bounds how long the CLI waits for the user to complete the +// browser consent flow before giving up. +const authTimeout = 2 * time.Minute + // httpClient bounds Spotify API calls to a reasonable interactive timeout so // the CLI can't hang on a stalled connection. var httpClient = &http.Client{Timeout: 10 * time.Second} // FetchAccessToken - Returns a valid access token for the Spotify API. // * If no cached access token or refresh token +// * Starts a local loopback server on the redirect URI's port // * Opens browser to authorization URL -// * Accepts user input of authorization code +// * Captures the authorization code from the OAuth callback automatically // * Exchanges authorization code for access token and refresh token // * If access token is expired // * Exchange refresh token for a new access token @@ -36,8 +45,7 @@ func FetchAccessToken() string { refreshToken := config.Read("spotify.refresh_token") if accessToken == "" || refreshToken == "" { - authorize() - accessToken, refreshToken = exchangeAuthorizationCode(inputCode()) + accessToken, refreshToken = exchangeAuthorizationCode(authorize()) config.Write("spotify.access_token", accessToken) config.Write("spotify.refresh_token", refreshToken) } else if refreshNeeded(accessToken) { @@ -49,11 +57,42 @@ func FetchAccessToken() string { return accessToken } -func authorize() { +// authorize runs the OAuth authorization-code flow using a loopback redirect. +// SPOTIFY_REDIRECT_URI must be a loopback URL with an explicit port (e.g. +// http://127.0.0.1:8888/callback) that is also registered on the Spotify app. +// It starts a local HTTP server on that port, opens the browser to Spotify's +// consent screen, and blocks until Spotify redirects back with an +// authorization code, returning that code. +func authorize() string { + redirectURI := os.Getenv("SPOTIFY_REDIRECT_URI") + redirect, err := url.Parse(redirectURI) + if err != nil || redirect.Port() == "" || !isLoopback(redirect.Hostname()) { + log.Error( + "SPOTIFY_REDIRECT_URI must be a loopback URL with a port, "+ + "e.g. http://127.0.0.1:8888/callback (got %q)", redirectURI, + ) + os.Exit(1) + } + + state, err := randomState() + if err != nil { + log.Error("could not generate OAuth state: %s", err) + os.Exit(1) + } + + // Bind to the redirect's own host:port so the listener matches the address + // the browser will be redirected to. + addr := net.JoinHostPort(redirect.Hostname(), redirect.Port()) + listener, err := net.Listen("tcp", addr) + if err != nil { + log.Error("could not start local server on %s: %s", addr, err) + os.Exit(1) + } + params := url.Values{ "response_type": {"code"}, "client_id": {os.Getenv("SPOTIFY_CLIENT_ID")}, - "redirect_uri": {os.Getenv("SPOTIFY_REDIRECT_URI")}, + "redirect_uri": {redirectURI}, "scope": { strings.Join([]string{ "user-read-currently-playing", @@ -61,10 +100,116 @@ func authorize() { "user-library-modify", }, " "), }, - "state": {"spotify"}, + "state": {state}, } authURL := "https://accounts.spotify.com/authorize?" + params.Encode() - run.Execute(`open "` + authURL + `"`) + + code, err := captureAuthCode(listener, callbackPath(redirect), state, func() { + fmt.Println("Opening browser to authorize Spotify…") + // Launch the browser without a shell so the auth URL is never + // interpreted by zsh. + if execErr := exec.Command("open", authURL).Start(); execErr != nil { + log.Warning("could not open browser automatically; visit:\n%s", authURL) + } + }) + if err != nil { + log.Error("%s", err) + os.Exit(1) + } + return code +} + +// isLoopback reports whether host is a loopback address Spotify will redirect +// back to and that we can safely bind a local server on. +func isLoopback(host string) bool { + if host == "localhost" { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} + +// callbackPath returns the path the loopback server should listen on, defaulting +// to "/" when the redirect URI has no path (http.ServeMux rejects an empty +// pattern). +func callbackPath(redirect *url.URL) string { + if redirect.Path == "" { + return "/" + } + return redirect.Path +} + +// captureAuthCode serves the OAuth callback on listener, runs open() to launch +// the browser, and blocks until Spotify redirects back. The wantState value +// guards against CSRF. It returns the authorization code, or an error if the +// callback reports failure, the state mismatches, or the timeout elapses. +func captureAuthCode(listener net.Listener, path, wantState string, open func()) (string, error) { + type result struct { + code string + err error + } + results := make(chan result, 1) + + // once ensures only the first callback delivers a result; a browser retry + // or refresh that hits the handler again is answered without blocking on a + // channel send that no one is waiting to receive. + var once sync.Once + deliver := func(res result) { + once.Do(func() { results <- res }) + } + + mux := http.NewServeMux() + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + switch { + case query.Get("error") != "": + http.Error(w, "Spotify authorization failed.", http.StatusBadRequest) + deliver(result{err: fmt.Errorf("authorization denied: %s", query.Get("error"))}) + case query.Get("state") != wantState: + http.Error(w, "State mismatch.", http.StatusBadRequest) + deliver(result{err: errors.New("state mismatch in OAuth callback")}) + case query.Get("code") == "": + http.Error(w, "Missing authorization code.", http.StatusBadRequest) + deliver(result{err: errors.New("missing authorization code in OAuth callback")}) + default: + io.WriteString(w, "Spotify authorization complete — you can close this tab.") + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + deliver(result{code: query.Get("code")}) + } + }) + + server := &http.Server{Handler: mux} + go server.Serve(listener) + // Shut down gracefully so the handler's "you can close this tab" response + // finishes flushing to the browser before the connection is torn down. + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + server.Close() + } + }() + + open() + + select { + case res := <-results: + return res.code, res.err + case <-time.After(authTimeout): + return "", errors.New("timed out waiting for Spotify authorization") + } +} + +// randomState returns a cryptographically random hex string used as the OAuth +// state parameter to defend against CSRF on the callback. +func randomState() (string, error) { + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil } func refreshNeeded(accessToken string) bool { @@ -145,27 +290,6 @@ func SendRequest( return data, response.StatusCode } -func inputCode() string { - prompt := promptui.Prompt{ - Label: "Authorization code", - Validate: validateInput, - } - - value, err := prompt.Run() - if err != nil { - log.Error("%s", err) - os.Exit(1) - } - return value -} - -func validateInput(input string) error { - if strings.TrimSpace(input) == "" { - return errors.New("must not be blank") - } - return nil -} - // HandleRequestError exits if the error is non-nil. func HandleRequestError(err error) { if err == nil { diff --git a/cmd/spotify/auth/root_test.go b/cmd/spotify/auth/root_test.go index b89b42b6..b2f8f6ac 100644 --- a/cmd/spotify/auth/root_test.go +++ b/cmd/spotify/auth/root_test.go @@ -2,6 +2,7 @@ package auth import ( "io" + "net" "net/http" "net/http/httptest" "net/url" @@ -9,6 +10,125 @@ import ( "testing" ) +// callback fires an HTTP GET against the loopback server to simulate Spotify +// redirecting back with the given query string. +func callback(t *testing.T, addr, query string) func() { + t.Helper() + return func() { + go func() { + resp, err := http.Get("http://" + addr + "/callback?" + query) + if err == nil { + resp.Body.Close() + } + }() + } +} + +func loopbackListener(t *testing.T) (net.Listener, string) { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + return listener, listener.Addr().String() +} + +func TestCaptureAuthCode_Success(t *testing.T) { + listener, addr := loopbackListener(t) + + code, err := captureAuthCode(listener, "/callback", "state-123", + callback(t, addr, "code=auth-code&state=state-123")) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if code != "auth-code" { + t.Errorf("code = %q, want auth-code", code) + } +} + +func TestCaptureAuthCode_StateMismatch(t *testing.T) { + listener, addr := loopbackListener(t) + + _, err := captureAuthCode(listener, "/callback", "want", + callback(t, addr, "code=auth-code&state=evil")) + + if err == nil { + t.Fatal("expected error on state mismatch, got nil") + } +} + +func TestCaptureAuthCode_AuthDenied(t *testing.T) { + listener, addr := loopbackListener(t) + + _, err := captureAuthCode(listener, "/callback", "want", + callback(t, addr, "error=access_denied&state=want")) + + if err == nil { + t.Fatal("expected error when callback reports failure, got nil") + } + if !strings.Contains(err.Error(), "access_denied") { + t.Errorf("error = %q, want it to mention access_denied", err) + } +} + +func TestCaptureAuthCode_MissingCode(t *testing.T) { + listener, addr := loopbackListener(t) + + _, err := captureAuthCode(listener, "/callback", "want", + callback(t, addr, "state=want")) + + if err == nil { + t.Fatal("expected error when code is missing, got nil") + } +} + +func TestCallbackPath_DefaultsToRoot(t *testing.T) { + noPath, _ := url.Parse("http://127.0.0.1:8888") + if got := callbackPath(noPath); got != "/" { + t.Errorf("callbackPath(no path) = %q, want /", got) + } + + withPath, _ := url.Parse("http://127.0.0.1:8888/callback") + if got := callbackPath(withPath); got != "/callback" { + t.Errorf("callbackPath(/callback) = %q, want /callback", got) + } +} + +func TestIsLoopback(t *testing.T) { + cases := map[string]bool{ + "127.0.0.1": true, + "localhost": true, + "::1": true, + "0.0.0.0": false, + "evil.com": false, + "10.0.0.5": false, + "": false, + } + for host, want := range cases { + if got := isLoopback(host); got != want { + t.Errorf("isLoopback(%q) = %v, want %v", host, got, want) + } + } +} + +func TestRandomState_UniqueAndHex(t *testing.T) { + a, err := randomState() + if err != nil { + t.Fatalf("randomState: %v", err) + } + b, err := randomState() + if err != nil { + t.Fatalf("randomState: %v", err) + } + if a == b { + t.Errorf("randomState produced identical values: %q", a) + } + if len(a) != 32 { + t.Errorf("len = %d, want 32 hex chars", len(a)) + } +} + func TestSendRequest_GET_NoQueryNoBody(t *testing.T) { var ( gotMethod string diff --git a/cmd/spotify/root.go b/cmd/spotify/root.go index 813fbdfc..39f7de75 100644 --- a/cmd/spotify/root.go +++ b/cmd/spotify/root.go @@ -3,7 +3,10 @@ // The following ENV variables must exist: // SPOTIFY_CLIENT_ID // SPOTIFY_CLIENT_SECRET -// SPOTIFY_REDIRECT_URI +// SPOTIFY_REDIRECT_URI - a loopback URL with an explicit port, e.g. +// http://127.0.0.1:8888/callback. The CLI starts a local server on that +// port to capture the OAuth callback, so the same URL must be registered +// as a redirect URI on the Spotify app. // This package provides a Spotify CLI to toggle, save, and remove the current // song