diff --git a/pkg/auth/kas.go b/pkg/auth/kas.go
index 4af3240dc..64a5d7931 100644
--- a/pkg/auth/kas.go
+++ b/pkg/auth/kas.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"net/url"
 	"os"
 	"strings"
 	"time"
@@ -16,6 +17,14 @@ import (
 	"github.com/google/uuid"
 )
 
+func hostFromURL(rawURL string) string {
+	parsed, err := url.Parse(rawURL)
+	if err != nil || parsed == nil {
+		return ""
+	}
+	return parsed.Host // Host keeps any ":port" suffix
+}
+
 var _ OAuth = KasAuthenticator{}
 
 const CredentialProviderKAS entity.CredentialProvider = "kas"
@@ -103,7 +112,7 @@ func (a KasAuthenticator) MakeLoginCall(id, email string) (LoginCallResponse, er
 	client := &http.Client{}
 	resp, err := client.Do(req)
 	if err != nil {
-		return LoginCallResponse{}, breverrors.WrapAndTrace(err)
+		return LoginCallResponse{}, breverrors.WrapAndTrace(breverrors.WrapNetworkError(err, hostFromURL(a.BaseURL)))
 	}
 	defer resp.Body.Close() //nolint:errcheck // fine
 
@@ -219,6 +228,9 @@ func (a KasAuthenticator) retrieveIDToken(sessionKey, deviceID string) (string,
 
 	tokenResp, err := client.Do(tokenReq)
 	if err != nil {
+		if breverrors.IsNetworkError(err) {
+			return "", breverrors.WrapNetworkError(err, hostFromURL(a.BaseURL))
+		}
 		return "", fmt.Errorf("error sending token request: %v", err)
 	}
 	defer tokenResp.Body.Close() //nolint:errcheck // fine
diff --git a/pkg/auth/kas_test.go b/pkg/auth/kas_test.go
new file mode 100644
index 000000000..1f470d1e0
--- /dev/null
+++ b/pkg/auth/kas_test.go
@@ -0,0 +1,102 @@
+package auth
+
+import (
+	stderrors "errors"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+
+	breverrors "github.com/brevdev/brev-cli/pkg/errors"
+	"github.com/stretchr/testify/assert"
+)
+
+func Test_hostFromURL(t *testing.T) {
+	cases := []struct {
+		in   string
+		want string
+	}{
+		{"https://api.ngc.nvidia.com", "api.ngc.nvidia.com"},
+		{"https://api.ngc.nvidia.com/token", "api.ngc.nvidia.com"},
+		{"http://localhost:8080", "localhost:8080"},
+		{"", ""},            // empty input parses cleanly but carries no host
+		{"://nonsense", ""}, // nothing before "://": url.Parse reports an error
+	}
+	for _, tc := range cases {
+		t.Run(tc.in, func(t *testing.T) {
+			assert.Equal(t, tc.want, hostFromURL(tc.in))
+		})
+	}
+}
+
+// closedServerURL returns the URL of an httptest.Server that has already been
+// Close()d, so a request to it is expected to fail at the transport level.
+func closedServerURL(t *testing.T) string {
+	t.Helper()
+	server := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
+	server.Close()
+	return server.URL
+}
+
+// MakeLoginCall against an unreachable host should surface a friendly
+// *NetworkError instead of the raw transport error, and the host should
+// match the BaseURL we tried to reach.
+func TestMakeLoginCall_NetworkError(t *testing.T) {
+	baseURL := closedServerURL(t)
+	auth := KasAuthenticator{BaseURL: baseURL}
+
+	_, err := auth.MakeLoginCall("device-id", "user@example.com")
+	if !assert.Error(t, err) {
+		return
+	}
+
+	var netErr *breverrors.NetworkError
+	if !assert.True(t, stderrors.As(err, &netErr), "expected *NetworkError, got %T: %v", err, err) {
+		return
+	}
+
+	expectedHost, _ := url.Parse(baseURL)
+	assert.Equal(t, expectedHost.Host, netErr.Host)
+	assert.Contains(t, netErr.Error(), "internet connection")
+}
+
+// retrieveIDToken against an unreachable host should surface a friendly
+// *NetworkError. This is the path hit by `brev shell` when the cached
+// access token has expired and api.ngc.nvidia.com is unreachable.
+func TestRetrieveIDToken_NetworkError(t *testing.T) {
+	baseURL := closedServerURL(t)
+	auth := KasAuthenticator{BaseURL: baseURL}
+
+	_, err := auth.retrieveIDToken("session-key", "device-id")
+	if !assert.Error(t, err) {
+		return
+	}
+
+	var netErr *breverrors.NetworkError
+	if !assert.True(t, stderrors.As(err, &netErr), "expected *NetworkError, got %T: %v", err, err) {
+		return
+	}
+
+	expectedHost, _ := url.Parse(baseURL)
+	assert.Equal(t, expectedHost.Host, netErr.Host)
+}
+
+// retrieveIDToken against a server that returns HTTP 4xx/5xx (i.e. a
+// non-network error) should NOT be classified as a network error.
+func TestRetrieveIDToken_NonNetworkErrorUnchanged(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+		_, _ = w.Write([]byte(`{"requestStatus":{"statusCode":"INTERNAL_ERROR"}}`))
+	}))
+	defer server.Close()
+
+	auth := KasAuthenticator{BaseURL: server.URL}
+	_, err := auth.retrieveIDToken("session-key", "device-id")
+	if !assert.Error(t, err) {
+		return
+	}
+
+	var netErr *breverrors.NetworkError
+	assert.False(t, stderrors.As(err, &netErr), "HTTP error should not be classified as a network error: %v", err)
+	assert.Contains(t, err.Error(), "status code: 500")
+}
diff --git a/pkg/cmd/cmderrors/cmderrors.go b/pkg/cmd/cmderrors/cmderrors.go
index 4b07989a8..4922da428 100644
--- a/pkg/cmd/cmderrors/cmderrors.go
+++ b/pkg/cmd/cmderrors/cmderrors.go
@@ -33,6 +33,11 @@ func DisplayAndHandleError(err error) {
 		case breverrors.ValidationError:
 			// do not report error
 			prettyErr = (t.Yellow(errors.Cause(err).Error()))
+		case *breverrors.NetworkError:
+			// network failure is a user-facing condition, not a bug — show
+			// a friendly message and skip Sentry reporting.
+			netErr, _ := errors.Cause(err).(*breverrors.NetworkError)
+			prettyErr = t.Yellow(netErr.Error()) + "\n" + t.Yellow(netErr.Directive())
 		case breverrors.WorkspaceNotRunning: // report error to track when this occurs, but don't print stacktrace to user unless in dev mode
 			er.ReportError(err)
 			prettyErr = (t.Yellow(errors.Cause(err).Error()))
diff --git a/pkg/cmd/cmderrors/cmderrors_test.go b/pkg/cmd/cmderrors/cmderrors_test.go
index fdce70204..914fb98d6 100644
--- a/pkg/cmd/cmderrors/cmderrors_test.go
+++ b/pkg/cmd/cmderrors/cmderrors_test.go
@@ -1 +1,76 @@
 package cmderrors
+
+import (
+	"bytes"
+	"io"
+	"os"
+	"testing"
+
+	breverrors "github.com/brevdev/brev-cli/pkg/errors"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// captureStderr runs fn with os.Stderr redirected into a pipe and returns
+// everything fn wrote to it. os.Stderr is put back and both pipe ends are
+// closed when captureStderr returns, including when fn panics or ends the
+// test early, so one failing test cannot swallow the stderr output of the
+// tests that run after it.
+func captureStderr(t *testing.T, fn func()) string {
+	t.Helper()
+	orig := os.Stderr
+	r, w, err := os.Pipe()
+	require.NoError(t, err)
+	os.Stderr = w
+	defer func() {
+		// Closing w a second time after the explicit Close below only
+		// returns an error, which is ignored; it matters on the panic path.
+		os.Stderr = orig
+		_ = w.Close()
+		_ = r.Close()
+	}()
+
+	done := make(chan struct{})
+	var buf bytes.Buffer
+	go func() {
+		defer close(done)
+		_, _ = io.Copy(&buf, r)
+	}()
+
+	fn()
+
+	// Close the write end so the io.Copy goroutine sees EOF, then wait for it
+	// to finish before reading buf.
+	_ = w.Close()
+	<-done
+
+	return buf.String()
+}
+
+// A NetworkError wrapped through WrapAndTrace should render as a short,
+// friendly message — no stack trace, no "github.com/brevdev/..." lines.
+func TestDisplayAndHandleError_NetworkErrorIsFriendly(t *testing.T) {
+	netErr := &breverrors.NetworkError{Host: "api.ngc.nvidia.com"}
+	wrapped := breverrors.WrapAndTrace(breverrors.WrapAndTrace(netErr))
+
+	out := captureStderr(t, func() {
+		DisplayAndHandleError(wrapped)
+	})
+
+	assert.Contains(t, out, "api.ngc.nvidia.com")
+	assert.Contains(t, out, "internet connection")
+	// The hallmark of the old convoluted output: stack trace lines pointing
+	// at github.com/brevdev/brev-cli source paths. The friendly message
+	// must not include them.
+	assert.NotContains(t, out, "github.com/brevdev/brev-cli/pkg/auth")
+	assert.NotContains(t, out, "[error]")
+}
+
+// A non-network error should still render through the default red path.
+func TestDisplayAndHandleError_PlainError(t *testing.T) {
+	err := breverrors.New("something else broke")
+	out := captureStderr(t, func() {
+		DisplayAndHandleError(err)
+	})
+	assert.Contains(t, out, "something else broke")
+}
diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go
index 236a0f7aa..8c42ecc37 100644
--- a/pkg/errors/errors.go
+++ b/pkg/errors/errors.go
@@ -2,6 +2,8 @@ package errors
 
 import (
 	"fmt"
+	"net"
+	"net/url"
 	"runtime"
 	"strconv"
 	"time"
@@ -130,6 +132,105 @@ func (d *DeclineToLoginError) Directive() string { return "log in to run this co
 
 var NetworkErrorMessage = "possible internet connection problem"
 
+// NetworkError is a user-facing error for transport-level failures (DNS
+// lookup failures, dial timeouts, connection refused, etc.) when reaching a
+// remote service. It hides the underlying stacktrace and produces a short,
+// actionable message suitable for end users.
+type NetworkError struct {
+	// Host is the host the CLI was trying to reach (e.g. "api.ngc.nvidia.com").
+	// Empty if no host could be derived from the original error.
+	Host string
+	// Cause is the underlying transport error.
+	Cause error
+}
+
+// Error returns the one-line, user-facing summary. It names Host when one is
+// known and never includes the text of Cause.
+func (e *NetworkError) Error() string {
+	if e.Host != "" {
+		return fmt.Sprintf("Could not reach %s — check your internet connection and try again", e.Host)
+	}
+	return "Could not reach the network — check your internet connection and try again"
+}
+
+// Directive returns the follow-up hint to show under Error: what the user
+// should check next.
+func (e *NetworkError) Directive() string {
+	if e.Host != "" {
+		return fmt.Sprintf("Verify you can resolve %s and that no firewall or proxy is blocking it. If the host is reachable, the service may be temporarily unavailable.", e.Host)
+	}
+	return "Verify your internet connection. If the network is healthy, the service may be temporarily unavailable."
+}
+
+// Unwrap exposes Cause so errors.Is and errors.As can see the transport error.
+func (e *NetworkError) Unwrap() error { return e.Cause }
+
+// IsNetworkError reports whether err (or any error in its chain) is a
+// transport-level network failure such as a DNS lookup failure, dial
+// timeout, or connection refusal.
+//
+// *url.Error, which net/http wraps around every client failure, itself
+// satisfies net.Error, so it is never treated as evidence on its own: only
+// the error it wraps is inspected. Without this every http.Client error
+// (unsupported scheme, redirect limit, cancelled request, ...) would be
+// reported to the user as a connectivity problem.
+func IsNetworkError(err error) bool {
+	if err == nil {
+		return false
+	}
+	var dnsErr *net.DNSError
+	if stderrors.As(err, &dnsErr) {
+		return true
+	}
+	var opErr *net.OpError
+	if stderrors.As(err, &opErr) {
+		return true
+	}
+	// Step past url.Error wrappers before the generic net.Error probe.
+	for {
+		var urlErr *url.Error
+		if !stderrors.As(err, &urlErr) {
+			break
+		}
+		err = urlErr.Err
+	}
+	// Catches the remaining transport errors, e.g. timeouts that are neither
+	// a DNSError nor an OpError.
+	var netErr net.Error
+	return stderrors.As(err, &netErr)
+}
+
+// HostFromURLError returns the host from a *url.Error in err's chain, or ""
+// if no URL is available. Useful when wrapping HTTP client errors.
+func HostFromURLError(err error) string {
+	var urlErr *url.Error
+	if !stderrors.As(err, &urlErr) {
+		return ""
+	}
+	parsed, perr := url.Parse(urlErr.URL)
+	if perr != nil || parsed == nil {
+		return ""
+	}
+	return parsed.Host
+}
+
+// WrapNetworkError returns a *NetworkError wrapping err if err is a
+// transport-level network failure; otherwise it returns err unchanged. The
+// fallbackHost is used only if the host cannot be derived from err.
+func WrapNetworkError(err error, fallbackHost string) error {
+	if err == nil {
+		return nil
+	}
+	if !IsNetworkError(err) {
+		return err
+	}
+	host := HostFromURLError(err) // prefer the URL recorded in the error itself
+	if host == "" {
+		host = fallbackHost
+	}
+	return &NetworkError{Host: host, Cause: err}
+}
+
 type CredentialsFileNotFound struct{}
 
 func (e *CredentialsFileNotFound) Directive() string {
diff --git a/pkg/errors/errors_test.go b/pkg/errors/errors_test.go
index 8cd020077..e7012c128 100644
--- a/pkg/errors/errors_test.go
+++ b/pkg/errors/errors_test.go
@@ -1,7 +1,12 @@
 package errors
 
 import (
+	stderrors "errors"
 	"fmt"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
 	"testing"
 
 	pkgerrors "github.com/pkg/errors"
@@ -234,3 +239,131 @@ func Test_combine(t *testing.T) {
 
 	assert.Equal(t, "my error 1", cerr.Error())
 }
+
+func Test_IsNetworkError(t *testing.T) {
+	cases := []struct {
+		name string
+		err  error
+		want bool
+	}{
+		{"nil", nil, false},
+		{"plain error", New("oops"), false},
+		{"dns error", &net.DNSError{Name: "example.invalid", Err: "no such host"}, true},
+		{"op error", &net.OpError{Op: "dial", Net: "tcp", Err: stderrors.New("connection refused")}, true},
+		{"url error wrapping op error", &url.Error{
+			Op:  "Get",
+			URL: "https://api.ngc.nvidia.com/token",
+			Err: &net.OpError{Op: "dial", Net: "tcp", Err: stderrors.New("connection refused")},
+		}, true},
+		{"wrapped dns error", fmt.Errorf("outer: %w", &net.DNSError{Name: "example.invalid", Err: "no such host"}), true},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.want, IsNetworkError(tc.err))
+		})
+	}
+}
+
+func Test_HostFromURLError(t *testing.T) {
+	urlErr := &url.Error{Op: "Get", URL: "https://api.ngc.nvidia.com/token", Err: stderrors.New("boom")}
+	assert.Equal(t, "api.ngc.nvidia.com", HostFromURLError(urlErr))
+
+	// Wrapped url.Error is also handled.
+	wrapped := fmt.Errorf("outer: %w", urlErr)
+	assert.Equal(t, "api.ngc.nvidia.com", HostFromURLError(wrapped))
+
+	// Non-url.Error returns empty.
+	assert.Equal(t, "", HostFromURLError(stderrors.New("plain")))
+	assert.Equal(t, "", HostFromURLError(nil))
+}
+
+func Test_WrapNetworkError(t *testing.T) {
+	t.Run("nil", func(t *testing.T) {
+		assert.Nil(t, WrapNetworkError(nil, "host"))
+	})
+
+	t.Run("non-network error passes through", func(t *testing.T) {
+		orig := New("not a network error")
+		result := WrapNetworkError(orig, "host")
+		assert.Equal(t, orig, result)
+		var netErr *NetworkError
+		assert.False(t, stderrors.As(result, &netErr), "non-network error should not be wrapped")
+	})
+
+	t.Run("network error gets wrapped with host from url.Error", func(t *testing.T) {
+		inner := &url.Error{
+			Op:  "Get",
+			URL: "https://api.ngc.nvidia.com/token",
+			Err: &net.OpError{Op: "dial", Net: "tcp", Err: stderrors.New("connection refused")},
+		}
+		wrapped := WrapNetworkError(inner, "fallback.example.com")
+
+		var netErr *NetworkError
+		if assert.True(t, stderrors.As(wrapped, &netErr)) {
+			assert.Equal(t, "api.ngc.nvidia.com", netErr.Host)
+			assert.Same(t, inner, netErr.Cause)
+		}
+	})
+
+	t.Run("network error uses fallback host when no url.Error in chain", func(t *testing.T) {
+		inner := &net.OpError{Op: "dial", Net: "tcp", Err: stderrors.New("connection refused")}
+		wrapped := WrapNetworkError(inner, "fallback.example.com")
+
+		var netErr *NetworkError
+		if assert.True(t, stderrors.As(wrapped, &netErr)) {
+			assert.Equal(t, "fallback.example.com", netErr.Host)
+		}
+	})
+}
+
+func Test_NetworkError_Messages(t *testing.T) {
+	withHost := &NetworkError{Host: "api.ngc.nvidia.com"}
+	assert.Contains(t, withHost.Error(), "api.ngc.nvidia.com")
+	assert.Contains(t, withHost.Error(), "internet connection")
+	assert.Contains(t, withHost.Directive(), "api.ngc.nvidia.com")
+
+	withoutHost := &NetworkError{}
+	assert.Contains(t, withoutHost.Error(), "internet connection")
+	assert.NotEmpty(t, withoutHost.Directive())
+}
+
+func Test_NetworkError_UnwrapToCause(t *testing.T) {
+	cause := stderrors.New("boom")
+	netErr := &NetworkError{Host: "h", Cause: cause}
+	assert.Same(t, cause, netErr.Unwrap())
+	assert.True(t, stderrors.Is(netErr, cause))
+}
+
+// Verifies the integration path the CLI exercises: a real http.Client.Do
+// against a closed listener returns an error that IsNetworkError detects
+// and WrapNetworkError converts into a *NetworkError with the right host.
+func Test_WrapNetworkError_RealHTTPClient(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
+	server.Close()
+
+	resp, err := http.Get(server.URL) //nolint:noctx,bodyclose // test
+	if resp != nil {
+		_ = resp.Body.Close()
+	}
+	if !assert.Error(t, err) {
+		return
+	}
+	assert.True(t, IsNetworkError(err), "expected closed-server error to be classified as network error: %v", err)
+
+	wrapped := WrapNetworkError(err, "")
+	var netErr *NetworkError
+	if assert.True(t, stderrors.As(wrapped, &netErr)) {
+		// Host should be derived from the url.Error embedded in err.
+		expectedHost, _ := url.Parse(server.URL)
+		assert.Equal(t, expectedHost.Host, netErr.Host)
+	}
+}
+
+// pkgerrors.Cause must walk through WrapAndTrace layers and stop at the
+// NetworkError so the cmd-level error renderer can detect it.
+func Test_NetworkError_SurvivesWrapAndTrace(t *testing.T) {
+	netErr := &NetworkError{Host: "api.ngc.nvidia.com", Cause: stderrors.New("boom")}
+	wrapped := WrapAndTrace(WrapAndTrace(netErr))
+	cause := pkgerrors.Cause(wrapped)
+	assert.IsType(t, &NetworkError{}, cause)
+}
diff --git a/pkg/store/http.go b/pkg/store/http.go
index 60884f810..b53559979 100644
--- a/pkg/store/http.go
+++ b/pkg/store/http.go
@@ -37,6 +37,17 @@ func NewNoAuthHTTPClient(brevAPIURL string) *NoAuthHTTPClient {
 	return &NoAuthHTTPClient{restyClient}
 }
 
+// silentRestyLogger discards all log output from resty. We use it so that
+// transient request errors (e.g. network outages while resolving an access
+// token) don't dump multi-line stack traces to stderr alongside the
+// friendly error rendered by cmderrors. NewAuthHTTPClient installs it only
+// when its Debug option is off. TODO(review): confirm BREV_DEBUG_HTTP sets it.
+type silentRestyLogger struct{}
+
+func (silentRestyLogger) Errorf(string, ...interface{}) {}
+func (silentRestyLogger) Warnf(string, ...interface{}) {}
+func (silentRestyLogger) Debugf(string, ...interface{}) {}
+
 func NewRestyClient(brevAPIURL string) *resty.Client {
 	restyClient := resty.New()
 	restyClient.SetBaseURL(brevAPIURL)
@@ -158,6 +169,12 @@ func NewAuthHTTPClient(auth Auth, brevAPIURL string, options ...Option) *AuthHTT
 	}
 	restyClient := NewRestyClient(brevAPIURL)
 	restyClient.Debug = opts.Debug
+	if !opts.Debug {
+		// Silence resty's stderr WARN/ERROR lines for OnBeforeRequest
+		// failures. The cmderrors layer already renders a friendly,
+		// user-facing message; the multi-line resty trace is just noise.
+		restyClient.SetLogger(silentRestyLogger{})
+	}
 	restyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
 		token, err := auth.GetAccessToken()
 		if err != nil {