Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 154 additions & 30 deletions cmd/spotify/auth/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,41 @@
package auth

import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"sync"
"time"

"github.com/drn/dots/cli/config"
"github.com/drn/dots/pkg/log"
"github.com/drn/dots/pkg/run"
jsoniter "github.com/json-iterator/go"
"github.com/manifoldco/promptui"
)

const spotifyTokenURL = "https://accounts.spotify.com/api/token"

// authTimeout bounds how long the CLI waits for the user to complete the
// browser consent flow before giving up.
const authTimeout = 2 * time.Minute

// httpClient bounds Spotify API calls to a reasonable interactive timeout so
// the CLI can't hang on a stalled connection.
var httpClient = &http.Client{Timeout: 10 * time.Second}

// FetchAccessToken - Returns a valid access token for the Spotify API.
// * If no cached access token or refresh token
// * Starts a local loopback server on the redirect URI's port
// * Opens browser to authorization URL
// * Accepts user input of authorization code
// * Captures the authorization code from the OAuth callback automatically
// * Exchanges authorization code for access token and refresh token
// * If access token is expired
// * Exchange refresh token for a new access token
Expand All @@ -36,8 +45,7 @@ func FetchAccessToken() string {
refreshToken := config.Read("spotify.refresh_token")

if accessToken == "" || refreshToken == "" {
authorize()
accessToken, refreshToken = exchangeAuthorizationCode(inputCode())
accessToken, refreshToken = exchangeAuthorizationCode(authorize())
config.Write("spotify.access_token", accessToken)
config.Write("spotify.refresh_token", refreshToken)
} else if refreshNeeded(accessToken) {
Expand All @@ -49,22 +57,159 @@ func FetchAccessToken() string {
return accessToken
}

func authorize() {
// authorize runs the OAuth authorization-code flow using a loopback redirect.
// SPOTIFY_REDIRECT_URI must be a loopback URL with an explicit port (e.g.
// http://127.0.0.1:8888/callback) that is also registered on the Spotify app.
// It starts a local HTTP server on that port, opens the browser to Spotify's
// consent screen, and blocks until Spotify redirects back with an
// authorization code, returning that code.
func authorize() string {
redirectURI := os.Getenv("SPOTIFY_REDIRECT_URI")
redirect, err := url.Parse(redirectURI)
if err != nil || redirect.Port() == "" || !isLoopback(redirect.Hostname()) {
log.Error(
"SPOTIFY_REDIRECT_URI must be a loopback URL with a port, "+
"e.g. http://127.0.0.1:8888/callback (got %q)", redirectURI,
)
os.Exit(1)
}

state, err := randomState()
if err != nil {
log.Error("could not generate OAuth state: %s", err)
os.Exit(1)
}

// Bind to the redirect's own host:port so the listener matches the address
// the browser will be redirected to.
addr := net.JoinHostPort(redirect.Hostname(), redirect.Port())
listener, err := net.Listen("tcp", addr)
if err != nil {
log.Error("could not start local server on %s: %s", addr, err)
os.Exit(1)
}

params := url.Values{
"response_type": {"code"},
"client_id": {os.Getenv("SPOTIFY_CLIENT_ID")},
"redirect_uri": {os.Getenv("SPOTIFY_REDIRECT_URI")},
"redirect_uri": {redirectURI},
"scope": {
strings.Join([]string{
"user-read-currently-playing",
"user-library-read",
"user-library-modify",
}, " "),
},
"state": {"spotify"},
"state": {state},
}
authURL := "https://accounts.spotify.com/authorize?" + params.Encode()
run.Execute(`open "` + authURL + `"`)

code, err := captureAuthCode(listener, callbackPath(redirect), state, func() {
fmt.Println("Opening browser to authorize Spotify…")
// Launch the browser without a shell so the auth URL is never
// interpreted by zsh.
if execErr := exec.Command("open", authURL).Start(); execErr != nil {
log.Warning("could not open browser automatically; visit:\n%s", authURL)
}
})
if err != nil {
log.Error("%s", err)
os.Exit(1)
}
return code
}

// isLoopback reports whether host is a loopback address Spotify will redirect
// back to and that we can safely bind a local server on.
func isLoopback(host string) bool {
if host == "localhost" {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}

// callbackPath returns the path the loopback server should listen on, defaulting
// to "/" when the redirect URI has no path (http.ServeMux rejects an empty
// pattern).
func callbackPath(redirect *url.URL) string {
if redirect.Path == "" {
return "/"
}
return redirect.Path
}

// captureAuthCode serves the OAuth callback on listener, runs open() to launch
// the browser, and blocks until Spotify redirects back. The wantState value
// guards against CSRF. It returns the authorization code, or an error if the
// callback reports failure, the state mismatches, or the timeout elapses.
func captureAuthCode(listener net.Listener, path, wantState string, open func()) (string, error) {
type result struct {
code string
err error
}
results := make(chan result, 1)

// once ensures only the first callback delivers a result; a browser retry
// or refresh that hits the handler again is answered without blocking on a
// channel send that no one is waiting to receive.
var once sync.Once
deliver := func(res result) {
once.Do(func() { results <- res })
}

mux := http.NewServeMux()
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
switch {
case query.Get("error") != "":
http.Error(w, "Spotify authorization failed.", http.StatusBadRequest)
deliver(result{err: fmt.Errorf("authorization denied: %s", query.Get("error"))})
case query.Get("state") != wantState:
http.Error(w, "State mismatch.", http.StatusBadRequest)
deliver(result{err: errors.New("state mismatch in OAuth callback")})
case query.Get("code") == "":
http.Error(w, "Missing authorization code.", http.StatusBadRequest)
deliver(result{err: errors.New("missing authorization code in OAuth callback")})
default:
io.WriteString(w, "Spotify authorization complete — you can close this tab.")
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
deliver(result{code: query.Get("code")})
}
})

server := &http.Server{Handler: mux}
go server.Serve(listener)
// Shut down gracefully so the handler's "you can close this tab" response
// finishes flushing to the browser before the connection is torn down.
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
server.Close()
}
}()

open()

select {
case res := <-results:
return res.code, res.err
case <-time.After(authTimeout):
return "", errors.New("timed out waiting for Spotify authorization")
}
}

// randomState returns a cryptographically random hex string used as the OAuth
// state parameter to defend against CSRF on the callback.
func randomState() (string, error) {
buf := make([]byte, 16)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}

func refreshNeeded(accessToken string) bool {
Expand Down Expand Up @@ -145,27 +290,6 @@ func SendRequest(
return data, response.StatusCode
}

func inputCode() string {
prompt := promptui.Prompt{
Label: "Authorization code",
Validate: validateInput,
}

value, err := prompt.Run()
if err != nil {
log.Error("%s", err)
os.Exit(1)
}
return value
}

func validateInput(input string) error {
if strings.TrimSpace(input) == "" {
return errors.New("must not be blank")
}
return nil
}

// HandleRequestError exits if the error is non-nil.
func HandleRequestError(err error) {
if err == nil {
Expand Down
120 changes: 120 additions & 0 deletions cmd/spotify/auth/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,133 @@ package auth

import (
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)

// callback fires an HTTP GET against the loopback server to simulate Spotify
// redirecting back with the given query string.
func callback(t *testing.T, addr, query string) func() {
t.Helper()
return func() {
go func() {
resp, err := http.Get("http://" + addr + "/callback?" + query)
if err == nil {
resp.Body.Close()
}
}()
}
}

func loopbackListener(t *testing.T) (net.Listener, string) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
return listener, listener.Addr().String()
}

func TestCaptureAuthCode_Success(t *testing.T) {
listener, addr := loopbackListener(t)

code, err := captureAuthCode(listener, "/callback", "state-123",
callback(t, addr, "code=auth-code&state=state-123"))

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != "auth-code" {
t.Errorf("code = %q, want auth-code", code)
}
}

func TestCaptureAuthCode_StateMismatch(t *testing.T) {
listener, addr := loopbackListener(t)

_, err := captureAuthCode(listener, "/callback", "want",
callback(t, addr, "code=auth-code&state=evil"))

if err == nil {
t.Fatal("expected error on state mismatch, got nil")
}
}

func TestCaptureAuthCode_AuthDenied(t *testing.T) {
listener, addr := loopbackListener(t)

_, err := captureAuthCode(listener, "/callback", "want",
callback(t, addr, "error=access_denied&state=want"))

if err == nil {
t.Fatal("expected error when callback reports failure, got nil")
}
if !strings.Contains(err.Error(), "access_denied") {
t.Errorf("error = %q, want it to mention access_denied", err)
}
}

func TestCaptureAuthCode_MissingCode(t *testing.T) {
listener, addr := loopbackListener(t)

_, err := captureAuthCode(listener, "/callback", "want",
callback(t, addr, "state=want"))

if err == nil {
t.Fatal("expected error when code is missing, got nil")
}
}

func TestCallbackPath_DefaultsToRoot(t *testing.T) {
noPath, _ := url.Parse("http://127.0.0.1:8888")
if got := callbackPath(noPath); got != "/" {
t.Errorf("callbackPath(no path) = %q, want /", got)
}

withPath, _ := url.Parse("http://127.0.0.1:8888/callback")
if got := callbackPath(withPath); got != "/callback" {
t.Errorf("callbackPath(/callback) = %q, want /callback", got)
}
}

func TestIsLoopback(t *testing.T) {
cases := map[string]bool{
"127.0.0.1": true,
"localhost": true,
"::1": true,
"0.0.0.0": false,
"evil.com": false,
"10.0.0.5": false,
"": false,
}
for host, want := range cases {
if got := isLoopback(host); got != want {
t.Errorf("isLoopback(%q) = %v, want %v", host, got, want)
}
}
}

func TestRandomState_UniqueAndHex(t *testing.T) {
a, err := randomState()
if err != nil {
t.Fatalf("randomState: %v", err)
}
b, err := randomState()
if err != nil {
t.Fatalf("randomState: %v", err)
}
if a == b {
t.Errorf("randomState produced identical values: %q", a)
}
if len(a) != 32 {
t.Errorf("len = %d, want 32 hex chars", len(a))
}
}

func TestSendRequest_GET_NoQueryNoBody(t *testing.T) {
var (
gotMethod string
Expand Down
5 changes: 4 additions & 1 deletion cmd/spotify/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
// The following ENV variables must exist:
// SPOTIFY_CLIENT_ID
// SPOTIFY_CLIENT_SECRET
// SPOTIFY_REDIRECT_URI
// SPOTIFY_REDIRECT_URI - a loopback URL with an explicit port, e.g.
// http://127.0.0.1:8888/callback. The CLI starts a local server on that
// port to capture the OAuth callback, so the same URL must be registered
// as a redirect URI on the Spotify app.

// This package provides a Spotify CLI to toggle, save, and remove the current
// song
Expand Down
Loading