Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pkg/auth/kas.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
Expand All @@ -16,6 +17,14 @@ import (
"github.com/google/uuid"
)

func hostFromURL(rawURL string) string {
parsed, err := url.Parse(rawURL)
if err != nil || parsed == nil {
return ""
}
return parsed.Host
}

var _ OAuth = KasAuthenticator{}

const CredentialProviderKAS entity.CredentialProvider = "kas"
Expand Down Expand Up @@ -103,7 +112,7 @@ func (a KasAuthenticator) MakeLoginCall(id, email string) (LoginCallResponse, er
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return LoginCallResponse{}, breverrors.WrapAndTrace(err)
return LoginCallResponse{}, breverrors.WrapAndTrace(breverrors.WrapNetworkError(err, hostFromURL(a.BaseURL)))
}
defer resp.Body.Close() //nolint:errcheck // fine

Expand Down Expand Up @@ -219,6 +228,9 @@ func (a KasAuthenticator) retrieveIDToken(sessionKey, deviceID string) (string,

tokenResp, err := client.Do(tokenReq)
if err != nil {
if breverrors.IsNetworkError(err) {
return "", breverrors.WrapNetworkError(err, hostFromURL(a.BaseURL))
}
return "", fmt.Errorf("error sending token request: %v", err)
}
defer tokenResp.Body.Close() //nolint:errcheck // fine
Expand Down
102 changes: 102 additions & 0 deletions pkg/auth/kas_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package auth

import (
stderrors "errors"
"net/http"
"net/http/httptest"
"net/url"
"testing"

breverrors "github.com/brevdev/brev-cli/pkg/errors"
"github.com/stretchr/testify/assert"
)

func Test_hostFromURL(t *testing.T) {
cases := []struct {
in string
want string
}{
{"https://api.ngc.nvidia.com", "api.ngc.nvidia.com"},
{"https://api.ngc.nvidia.com/token", "api.ngc.nvidia.com"},
{"http://localhost:8080", "localhost:8080"},
{"", ""},
{"://nonsense", ""},
}
for _, tc := range cases {
t.Run(tc.in, func(t *testing.T) {
assert.Equal(t, tc.want, hostFromURL(tc.in))
})
}
}

// closedServerURL returns a URL that is guaranteed to fail to connect: an
// httptest.Server that has already been Close()d.
func closedServerURL(t *testing.T) string {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
server.Close()
return server.URL
}

// MakeLoginCall against an unreachable host should surface a friendly
// *NetworkError instead of the raw transport error, and the host should
// match the BaseURL we tried to reach.
func TestMakeLoginCall_NetworkError(t *testing.T) {
baseURL := closedServerURL(t)
auth := KasAuthenticator{BaseURL: baseURL}

_, err := auth.MakeLoginCall("device-id", "user@example.com")
if !assert.Error(t, err) {
return
}

var netErr *breverrors.NetworkError
if !assert.True(t, stderrors.As(err, &netErr), "expected *NetworkError, got %T: %v", err, err) {
return
}

expectedHost, _ := url.Parse(baseURL)
assert.Equal(t, expectedHost.Host, netErr.Host)
assert.Contains(t, netErr.Error(), "internet connection")
}

// retrieveIDToken against an unreachable host should surface a friendly
// *NetworkError. This is the path hit by `brev shell` when the cached
// access token has expired and api.ngc.nvidia.com is unreachable.
func TestRetrieveIDToken_NetworkError(t *testing.T) {
baseURL := closedServerURL(t)
auth := KasAuthenticator{BaseURL: baseURL}

_, err := auth.retrieveIDToken("session-key", "device-id")
if !assert.Error(t, err) {
return
}

var netErr *breverrors.NetworkError
if !assert.True(t, stderrors.As(err, &netErr), "expected *NetworkError, got %T: %v", err, err) {
return
}

expectedHost, _ := url.Parse(baseURL)
assert.Equal(t, expectedHost.Host, netErr.Host)
}

// retrieveIDToken against a server that returns HTTP 4xx/5xx (i.e. a
// non-network error) should NOT be classified as a network error.
func TestRetrieveIDToken_NonNetworkErrorUnchanged(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"requestStatus":{"statusCode":"INTERNAL_ERROR"}}`))
}))
defer server.Close()

auth := KasAuthenticator{BaseURL: server.URL}
_, err := auth.retrieveIDToken("session-key", "device-id")
if !assert.Error(t, err) {
return
}

var netErr *breverrors.NetworkError
assert.False(t, stderrors.As(err, &netErr), "HTTP error should not be classified as a network error: %v", err)
assert.Contains(t, err.Error(), "status code: 500")
}
5 changes: 5 additions & 0 deletions pkg/cmd/cmderrors/cmderrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ func DisplayAndHandleError(err error) {
case breverrors.ValidationError:
// do not report error
prettyErr = (t.Yellow(errors.Cause(err).Error()))
case *breverrors.NetworkError:
// network failure is a user-facing condition, not a bug — show
// a friendly message and skip Sentry reporting.
netErr, _ := errors.Cause(err).(*breverrors.NetworkError)
prettyErr = t.Yellow(netErr.Error()) + "\n" + t.Yellow(netErr.Directive())
case breverrors.WorkspaceNotRunning: // report error to track when this occurs, but don't print stacktrace to user unless in dev mode
er.ReportError(err)
prettyErr = (t.Yellow(errors.Cause(err).Error()))
Expand Down
67 changes: 67 additions & 0 deletions pkg/cmd/cmderrors/cmderrors_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,68 @@
package cmderrors

import (
"bytes"
"io"
"os"
"testing"

breverrors "github.com/brevdev/brev-cli/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// captureStderr runs fn while collecting everything written to os.Stderr.
func captureStderr(t *testing.T, fn func()) string {
t.Helper()
orig := os.Stderr
r, w, err := os.Pipe()
require.NoError(t, err)
os.Stderr = w

done := make(chan struct{})
var buf bytes.Buffer
go func() {
_, _ = io.Copy(&buf, r)
close(done)
}()

fn()

// Close the writer before reading buf so the io.Copy goroutine sees
// EOF and returns. Without this, the deferred close ran after the
// return statement and buf was always empty.
_ = w.Close()
<-done
_ = r.Close()
os.Stderr = orig

return buf.String()
}

// A NetworkError wrapped through WrapAndTrace should render as a short,
// friendly message — no stack trace, no "github.com/brevdev/..." lines.
func TestDisplayAndHandleError_NetworkErrorIsFriendly(t *testing.T) {
netErr := &breverrors.NetworkError{Host: "api.ngc.nvidia.com"}
wrapped := breverrors.WrapAndTrace(breverrors.WrapAndTrace(netErr))

out := captureStderr(t, func() {
DisplayAndHandleError(wrapped)
})

assert.Contains(t, out, "api.ngc.nvidia.com")
assert.Contains(t, out, "internet connection")
// The hallmark of the old convoluted output: stack trace lines pointing
// at github.com/brevdev/brev-cli source paths. The friendly message
// must not include them.
assert.NotContains(t, out, "github.com/brevdev/brev-cli/pkg/auth")
assert.NotContains(t, out, "[error]")
}

// A non-network error should still render through the default red path.
func TestDisplayAndHandleError_PlainError(t *testing.T) {
err := breverrors.New("something else broke")
out := captureStderr(t, func() {
DisplayAndHandleError(err)
})
assert.Contains(t, out, "something else broke")
}
83 changes: 83 additions & 0 deletions pkg/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package errors

import (
"fmt"
"net"
"net/url"
"runtime"
"strconv"
"time"
Expand Down Expand Up @@ -130,6 +132,87 @@ func (d *DeclineToLoginError) Directive() string { return "log in to run this co

var NetworkErrorMessage = "possible internet connection problem"

// NetworkError is a user-facing error for transport-level failures (DNS
// lookup failures, dial timeouts, connection refused, etc.) when reaching a
// remote service. It hides the underlying stacktrace and produces a short,
// actionable message suitable for end users.
type NetworkError struct {
// Host is the host the CLI was trying to reach (e.g. "api.ngc.nvidia.com").
// Empty if no host could be derived from the original error.
Host string
// Cause is the underlying transport error.
Cause error
}

func (e *NetworkError) Error() string {
if e.Host != "" {
return fmt.Sprintf("Could not reach %s — check your internet connection and try again", e.Host)
}
return "Could not reach the network — check your internet connection and try again"
}

func (e *NetworkError) Directive() string {
if e.Host != "" {
return fmt.Sprintf("Verify you can resolve %s and that no firewall or proxy is blocking it. If the host is reachable, the service may be temporarily unavailable.", e.Host)
}
return "Verify your internet connection. If the network is healthy, the service may be temporarily unavailable."
}

func (e *NetworkError) Unwrap() error { return e.Cause }

// IsNetworkError reports whether err (or any error in its chain) is a
// transport-level network failure such as a DNS lookup failure, dial
// timeout, or connection refusal.
func IsNetworkError(err error) bool {
if err == nil {
return false
}
var dnsErr *net.DNSError
if stderrors.As(err, &dnsErr) {
return true
}
var opErr *net.OpError
if stderrors.As(err, &opErr) {
return true
}
var netErr net.Error
if stderrors.As(err, &netErr) {
return true
}
return false
}

// HostFromURLError returns the host from a *url.Error in err's chain, or ""
// if no URL is available. Useful when wrapping HTTP client errors.
func HostFromURLError(err error) string {
var urlErr *url.Error
if !stderrors.As(err, &urlErr) {
return ""
}
parsed, perr := url.Parse(urlErr.URL)
if perr != nil || parsed == nil {
return ""
}
return parsed.Host
}

// WrapNetworkError returns a *NetworkError wrapping err if err is a
// transport-level network failure; otherwise it returns err unchanged. The
// fallbackHost is used only if the host cannot be derived from err.
func WrapNetworkError(err error, fallbackHost string) error {
if err == nil {
return nil
}
if !IsNetworkError(err) {
return err
}
host := HostFromURLError(err)
if host == "" {
host = fallbackHost
}
return &NetworkError{Host: host, Cause: err}
}

type CredentialsFileNotFound struct{}

func (e *CredentialsFileNotFound) Directive() string {
Expand Down
Loading
Loading