diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go index 6a00880476..f09373eea4 100644 --- a/cmd/serve_cmd.go +++ b/cmd/serve_cmd.go @@ -197,7 +197,7 @@ func serve(ctx context.Context) { log.WithError(err).Fatal("http server listen failed") } err = httpSrv.Serve(listener) - if err == http.ErrServerClosed { + if errors.Is(err, http.ErrServerClosed) { log.Info("http server closed") } else if err != nil { log.WithError(err).Fatal("http server serve failed") diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go index c492c22cce..2c7f7ccfe4 100644 --- a/internal/api/apierrors/errorcode.go +++ b/internal/api/apierrors/errorcode.go @@ -2,6 +2,7 @@ package apierrors type ErrorCode = string +// All error codes the auth server returns are defined here. const ( // ErrorCodeUnknown should not be used directly, it only indicates a failure in the error handling system in such a way that an error code was not assigned properly. ErrorCodeUnknown ErrorCode = "unknown" diff --git a/internal/api/apierrors/errorcode_gen.go b/internal/api/apierrors/errorcode_gen.go new file mode 100644 index 0000000000..585d11a796 --- /dev/null +++ b/internal/api/apierrors/errorcode_gen.go @@ -0,0 +1,107 @@ +package apierrors + +//go:generate go test -run TestGenerate -args -generate +//go:generate go fmt + +var errorCodesMap = map[string]string{ + "anonymous_provider_disabled": "ErrorCodeAnonymousProviderDisabled", + "bad_code_verifier": "ErrorCodeBadCodeVerifier", + "bad_json": "ErrorCodeBadJSON", + "bad_jwt": "ErrorCodeBadJWT", + "bad_oauth_callback": "ErrorCodeBadOAuthCallback", + "bad_oauth_state": "ErrorCodeBadOAuthState", + "captcha_failed": "ErrorCodeCaptchaFailed", + "conflict": "ErrorCodeConflict", + "current_password_invalid": "ErrorCodeCurrentPasswordMismatch", + "current_password_required": "ErrorCodeCurrentPasswordRequired", + "custom_provider_not_found": "ErrorCodeCustomProviderNotFound", + "email_address_invalid": "ErrorCodeEmailAddressInvalid", + "email_address_not_authorized": "ErrorCodeEmailAddressNotAuthorized", + "email_address_not_provided": "ErrorCodeEmailAddressNotProvided", + "email_conflict_identity_not_deletable": "ErrorCodeEmailConflictIdentityNotDeletable", + "email_exists": "ErrorCodeEmailExists", + "email_not_confirmed": "ErrorCodeEmailNotConfirmed", + "email_provider_disabled": "ErrorCodeEmailProviderDisabled", + "feature_disabled": "ErrorCodeFeatureDisabled", + "flow_state_expired": "ErrorCodeFlowStateExpired", + "flow_state_not_found": "ErrorCodeFlowStateNotFound", + "hook_payload_invalid_content_type": "ErrorCodeHookPayloadInvalidContentType", + "hook_payload_over_size_limit": "ErrorCodeHookPayloadOverSizeLimit", + "hook_timeout": "ErrorCodeHookTimeout", + "hook_timeout_after_retry": "ErrorCodeHookTimeoutAfterRetry", + "identity_already_exists": "ErrorCodeIdentityAlreadyExists", + "identity_not_found": "ErrorCodeIdentityNotFound", + "insufficient_aal": "ErrorCodeInsufficientAAL", + "invalid_credentials": "ErrorCodeInvalidCredentials", + "invite_not_found": "ErrorCodeInviteNotFound", + "manual_linking_disabled": "ErrorCodeManualLinkingDisabled", + "mfa_challenge_expired": "ErrorCodeMFAChallengeExpired", + "mfa_factor_name_conflict": "ErrorCodeMFAFactorNameConflict", + "mfa_factor_not_found": "ErrorCodeMFAFactorNotFound", + "mfa_ip_address_mismatch": "ErrorCodeMFAIPAddressMismatch", + "mfa_phone_enroll_not_enabled": "ErrorCodeMFAPhoneEnrollDisabled", + "mfa_phone_verify_not_enabled": "ErrorCodeMFAPhoneVerifyDisabled", + "mfa_totp_enroll_not_enabled": "ErrorCodeMFATOTPEnrollDisabled", + "mfa_totp_verify_not_enabled": "ErrorCodeMFATOTPVerifyDisabled", + "mfa_verification_failed": "ErrorCodeMFAVerificationFailed", + "mfa_verification_rejected": "ErrorCodeMFAVerificationRejected", + "mfa_verified_factor_exists": "ErrorCodeMFAVerifiedFactorExists", + "mfa_webauthn_enroll_not_enabled": "ErrorCodeMFAWebAuthnEnrollDisabled", + "mfa_webauthn_verify_not_enabled": "ErrorCodeMFAWebAuthnVerifyDisabled", + "no_authorization": "ErrorCodeNoAuthorization", + "not_admin": "ErrorCodeNotAdmin", + "oauth_authorization_not_found": "ErrorCodeOAuthAuthorizationNotFound", + "oauth_client_not_found": "ErrorCodeOAuthClientNotFound", + "oauth_client_state_expired": "ErrorCodeOAuthClientStateExpired", + "oauth_client_state_not_found": "ErrorCodeOAuthClientStateNotFound", + "oauth_consent_not_found": "ErrorCodeOAuthConsentNotFound", + "oauth_dynamic_client_registration_disabled": "ErrorCodeOAuthDynamicClientRegistrationDisabled", + "oauth_invalid_state": "ErrorCodeOAuthInvalidState", + "oauth_provider_not_supported": "ErrorCodeOAuthProviderNotSupported", + "otp_disabled": "ErrorCodeOTPDisabled", + "otp_expired": "ErrorCodeOTPExpired", + "over_custom_provider_quota": "ErrorCodeOverCustomProviderQuota", + "over_email_send_rate_limit": "ErrorCodeOverEmailSendRateLimit", + "over_request_rate_limit": "ErrorCodeOverRequestRateLimit", + "over_sms_send_rate_limit": "ErrorCodeOverSMSSendRateLimit", + "phone_exists": "ErrorCodePhoneExists", + "phone_not_confirmed": "ErrorCodePhoneNotConfirmed", + "phone_provider_disabled": "ErrorCodePhoneProviderDisabled", + "provider_disabled": "ErrorCodeProviderDisabled", + "provider_email_needs_verification": "ErrorCodeProviderEmailNeedsVerification", + "reauthentication_needed": "ErrorCodeReauthenticationNeeded", + "reauthentication_not_valid": "ErrorCodeReauthenticationNotValid", + "refresh_token_already_used": "ErrorCodeRefreshTokenAlreadyUsed", + "refresh_token_not_found": "ErrorCodeRefreshTokenNotFound", + "request_timeout": "ErrorCodeRequestTimeout", + "same_password": "ErrorCodeSamePassword", + "saml_assertion_no_email": "ErrorCodeSAMLAssertionNoEmail", + "saml_assertion_no_user_id": "ErrorCodeSAMLAssertionNoUserID", + "saml_entity_id_mismatch": "ErrorCodeSAMLEntityIDMismatch", + "saml_idp_already_exists": "ErrorCodeSAMLIdPAlreadyExists", + "saml_idp_not_found": "ErrorCodeSAMLIdPNotFound", + "saml_metadata_fetch_failed": "ErrorCodeSAMLMetadataFetchFailed", + "saml_provider_disabled": "ErrorCodeSAMLProviderDisabled", + "saml_relay_state_expired": "ErrorCodeSAMLRelayStateExpired", + "saml_relay_state_not_found": "ErrorCodeSAMLRelayStateNotFound", + "session_expired": "ErrorCodeSessionExpired", + "session_not_found": "ErrorCodeSessionNotFound", + "signup_disabled": "ErrorCodeSignupDisabled", + "single_identity_not_deletable": "ErrorCodeSingleIdentityNotDeletable", + "sms_send_failed": "ErrorCodeSMSSendFailed", + "sso_domain_already_exists": "ErrorCodeSSODomainAlreadyExists", + "sso_provider_disabled": "ErrorCodeSSOProviderDisabled", + "sso_provider_not_found": "ErrorCodeSSOProviderNotFound", + "too_many_enrolled_mfa_factors": "ErrorCodeTooManyEnrolledMFAFactors", + "unexpected_audience": "ErrorCodeUnexpectedAudience", + "unexpected_failure": "ErrorCodeUnexpectedFailure", + "unknown": "ErrorCodeUnknown", + "user_already_exists": "ErrorCodeUserAlreadyExists", + "user_banned": "ErrorCodeUserBanned", + "user_not_found": "ErrorCodeUserNotFound", + "user_sso_managed": "ErrorCodeUserSSOManaged", + "validation_failed": "ErrorCodeValidationFailed", + "weak_password": "ErrorCodeWeakPassword", + "web3_provider_disabled": "ErrorCodeWeb3ProviderDisabled", + "web3_unsupported_chain": "ErrorCodeWeb3UnsupportedChain", +} diff --git a/internal/api/apierrors/errorcode_test.go b/internal/api/apierrors/errorcode_test.go new file mode 100644 index 0000000000..7f34c1c4bc --- /dev/null +++ b/internal/api/apierrors/errorcode_test.go @@ -0,0 +1,151 @@ +package apierrors + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "maps" + "os" + "slices" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +var generateFlag = flag.Bool("generate", false, "Run tests that generate code") + +func TestErrorCodesMap(t *testing.T) { + cur := helpParseErrorCodes(t) + gen := errorCodesMap + + for curCode, curName := range cur { + genName, ok := gen[curCode] + if !ok { + t.Fatalf("error code %q: (%v) missing in errorCodesMap", + curCode, curName) + } + if genName != curName { + t.Fatalf("error code %q: (%v) has different name (%q) in errorCodesMap", + curCode, curName, genName) + } + } + if a, b := len(cur), len(gen); a != b { + const msg = "generated code out of sync:" + + " errorCodeSlice len(%v) != constant declaration len (%v)" + t.Fatalf(msg, a, b) + } +} + +func TestGenerate(t *testing.T) { + if !*generateFlag { + t.SkipNow() + } + + ecm := helpParseErrorCodes(t) + ecs := slices.Sorted(maps.Keys(ecm)) + + var sb strings.Builder + sb.WriteString("package apierrors\n\n") + sb.WriteString("//go:generate go test -run TestGenerate -args -generate\n") + sb.WriteString("//go:generate go fmt\n\n") + + { + sb.WriteString("var errorCodesMap = map[string]string{\n") + for _, ec := range ecs { + fmt.Fprintf(&sb, "\t%q: %q,\n", ec, ecm[ec]) + } + sb.WriteString("}\n\n") + } + + os.WriteFile("errorcode_gen.go", []byte(sb.String()), 0644) +} + +func helpParseErrorCodes(t *testing.T) map[string]string { + ecm, err := parseErrorCodesOnce() + require.NoError(t, err) + require.NotEmpty(t, ecm) + return maps.Clone(ecm) +} + +var parseErrorCodesOnce = sync.OnceValues(func() (map[string]string, error) { + return parseErrorCodes() +}) + +func parseErrorCodes() (map[string]string, error) { + data, err := os.ReadFile(`errorcode.go`) + if err != nil { + const msg = "parseErrorCodes: os.ReadFile(`errorcode.go`): %w" + return nil, fmt.Errorf(msg, err) + } + src := string(data) + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "", src, parser.SkipObjectResolution) + if err != nil { + const msg = "parseErrorCodes: parser.ParseFile: %w" + return nil, fmt.Errorf(msg, err) + } + + ecm := make(map[string]string) + for declIdx, decl := range f.Decls { + if err := parseErrorCodesDecl(ecm, declIdx, decl); err != nil { + return nil, fmt.Errorf("parseErrorCodes %w", err) + } + } + return ecm, nil +} + +func parseErrorCodesDecl(ecm map[string]string, decIdx int, decl ast.Decl) error { + dec, ok := decl.(*ast.GenDecl) + if !ok || dec.Tok != token.CONST { + return nil + } + if n := len(dec.Specs); n == 0 { + return fmt.Errorf("decl[%d]: specs are empty", decIdx) + } + for idx, spec := range dec.Specs { + valSpec, ok := spec.(*ast.ValueSpec) + if !ok { + return fmt.Errorf("const[%d]: unexpected type: %T", idx, spec) + } + if n := len(valSpec.Names); n != 1 { + return fmt.Errorf("const[%d]: unexpected const len: %T", idx, n) + } + + constName := valSpec.Names[0].Name + if !strings.HasPrefix(constName, "ErrorCode") { + return fmt.Errorf("const[%d]: missing ErrorCode prefix: %v", idx, constName) + } + if n := len(valSpec.Values); n != 1 { + return fmt.Errorf("const[%d]: unexpected const value len: %v", idx, n) + } + + constExpr := valSpec.Values[0] + basicLit, ok := constExpr.(*ast.BasicLit) + if !ok { + return fmt.Errorf("const[%d]: unexpected const value expr type: %T", idx, constExpr) + } + + constValue := basicLit.Value + if n := len(constValue); n <= 3 { + return fmt.Errorf("const[%d]: unexpected const value string len: %v (%q)", + idx, n, constValue) + } + if constValue[0] != '"' || constValue[len(constValue)-1] != '"' { + return fmt.Errorf("const[%d]: unexpected const value string quoting (%q)", + idx, constValue) + } + constValue = constValue[1 : len(constValue)-1] + + if prev, found := ecm[constValue]; found { + msg := "const[%d]: duplicate error code: %q: already defined by %q" + return fmt.Errorf(msg, idx, constValue, prev) + } + ecm[constValue] = constName + } + return nil +} diff --git a/internal/api/apierrors/metrics.go b/internal/api/apierrors/metrics.go new file mode 100644 index 0000000000..c3a2007ad4 --- /dev/null +++ b/internal/api/apierrors/metrics.go @@ -0,0 +1,81 @@ +package apierrors + +import ( + "context" + "errors" + "fmt" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// TODO(cstockton): Don't like how these are global variables here. I think we +// should probably have a metrics package which is initialized before the api +// server is created and then passed in as an option to the *API. +var ( + errorCodeCounter metric.Int64Counter + errorCodeAttrsByCode = make(map[string]metric.MeasurementOption) +) + +func RecordErrorCode(ctx context.Context, errorCode ErrorCode) { + attrs, ok := errorCodeAttrsByCode[errorCode] + if !ok { + attrs = errorCodeAttrsByCode[ErrorCodeUnknown] + } + errorCodeCounter.Add(ctx, 1, attrs) +} + +func RecordPostgresCode(ctx context.Context, code string) { + attrs := metric.WithAttributeSet( + attribute.NewSet( + attribute.String("type", "postgres"), + attribute.String("error", code), + ), + ) + errorCodeCounter.Add(ctx, 1, attrs) +} + +func InitMetrics() error { + return initMetrics(errorCodesMap) +} + +func initMetrics(ecm map[string]string) error { + if len(errorCodesMap) == 0 { + const msg = "InitMetrics: errorCodesMap is empty" + return errors.New(msg) + } + + counter, err := otel.Meter("gotrue").Int64Counter( + "global_auth_errors_total", + metric.WithDescription("Number of error codes returned by type and error."), + metric.WithUnit("{type}"), + metric.WithUnit("{error}"), + ) + if err != nil { + return fmt.Errorf("InitMetrics: %w", err) + } + + // TODO(cstockton): I'm not sure about having a single dimension of + // "error_code", as I begin trying to dig into the types of errors we + // raise I might want to add a type specifier. For example OAuthError does + // not have an auth error code, but may wrap one internally. + // + // This is really about deciding how to strike the balance between caller + // burden and best effort inferrence like we are doing here. + errorCodeAttrsByCode[ErrorCodeUnknown] = metric.WithAttributes( + attribute.String("error_code", ErrorCodeUnknown), + ) + for code := range ecm { + attrs := metric.WithAttributeSet( + attribute.NewSet( + attribute.String("type", "api"), + attribute.String("error", code), + ), + ) + errorCodeAttrsByCode[code] = attrs + } + + errorCodeCounter = counter + return nil +} diff --git a/internal/api/apierrors/metrics_test.go b/internal/api/apierrors/metrics_test.go new file mode 100644 index 0000000000..f8fb737840 --- /dev/null +++ b/internal/api/apierrors/metrics_test.go @@ -0,0 +1,12 @@ +package apierrors + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMetrics(t *testing.T) { + err := initMetrics(errorCodesMap) + require.NoError(t, err) +} diff --git a/internal/api/errors.go b/internal/api/errors.go index a9b467f36a..7aa3a88ee4 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -46,6 +46,8 @@ type ( func recoverer(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { defer func() { + apierrors.RecordPostgresCode(r.Context(), "25005") + apierrors.RecordErrorCode(r.Context(), apierrors.ErrorCodeUserSSOManaged) if rvr := recover(); rvr != nil { logEntry := observability.GetLogEntry(r) if logEntry != nil { @@ -103,6 +105,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { output.Message = e.Message output.Payload.Reasons = e.Reasons + apierrors.RecordErrorCode(r.Context(), output.Code) if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } @@ -122,6 +125,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { w.Header().Set("x-sb-error-code", output.ErrorCode) + apierrors.RecordErrorCode(r.Context(), output.ErrorCode) if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } @@ -157,6 +161,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { } } + apierrors.RecordErrorCode(r.Context(), resp.Code) if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } @@ -171,12 +176,14 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { // Provide better error messages for certain user-triggered Postgres errors. if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { + apierrors.RecordPostgresCode(r.Context(), pgErr.Code) if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } return } + apierrors.RecordErrorCode(r.Context(), e.ErrorCode) if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } @@ -184,6 +191,18 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { case *OAuthError: log.WithError(e.Cause()).Info(e.Error()) + + // TODO(cstockton): We could either log oauth errors under a new type + // using the e.Code or try to unwrap an internal error if it exists. + // We could also add ErrorCodeBadRequest for these few edge cases as + // the visibility into oauth specific errors (user triggered errors) + // may not be useful for discovering auth server issues. Though the + // same argument could be made for most error codes. You could also + // argue some class of auth server issues might only surface through + // codes such as these. + // + // The reality here is it's about capturing as many failure modes as + // you can afford to so you have baselines to detect anomolies. if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } @@ -200,6 +219,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { Message: "Unexpected failure, please check server logs for more information", } + apierrors.RecordErrorCode(r.Context(), resp.Code) if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } @@ -210,6 +230,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { Message: "Unexpected failure, please check server logs for more information", } + apierrors.RecordErrorCode(r.Context(), httpError.ErrorCode) if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 30986b875f..c6d6749b54 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -14,6 +14,7 @@ import ( chimiddleware "github.com/go-chi/chi/v5/middleware" "github.com/gofrs/uuid" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/oauthserver" @@ -487,7 +488,7 @@ func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { case <-ctx.Done(): err := ctx.Err() - if err == context.DeadlineExceeded { + if errors.Is(err, context.DeadlineExceeded) { httpError := &HTTPError{ HTTPStatus: http.StatusGatewayTimeout, ErrorCode: apierrors.ErrorCodeRequestTimeout, diff --git a/internal/models/errors.go b/internal/models/errors.go index c78c145229..313f4f1671 100644 --- a/internal/models/errors.go +++ b/internal/models/errors.go @@ -1,42 +1,16 @@ package models +import "errors" + +// sentinel error for all not found errors. +var errNotFound = errors.New("not found") + +// sentinel error for unique constraint violations. +var errUniqueConstraintViolated = errors.New("unique constraint violated") + // IsNotFoundError returns whether an error represents a "not found" error. func IsNotFoundError(err error) bool { - switch err.(type) { - case UserNotFoundError, *UserNotFoundError: - return true - case SessionNotFoundError, *SessionNotFoundError: - return true - case ConfirmationTokenNotFoundError, *ConfirmationTokenNotFoundError: - return true - case ConfirmationOrRecoveryTokenNotFoundError, *ConfirmationOrRecoveryTokenNotFoundError: - return true - case RefreshTokenNotFoundError, *RefreshTokenNotFoundError: - return true - case IdentityNotFoundError, *IdentityNotFoundError: - return true - case ChallengeNotFoundError, *ChallengeNotFoundError: - return true - case FactorNotFoundError, *FactorNotFoundError: - return true - case SSOProviderNotFoundError, *SSOProviderNotFoundError: - return true - case SAMLRelayStateNotFoundError, *SAMLRelayStateNotFoundError: - return true - case FlowStateNotFoundError, *FlowStateNotFoundError: - return true - case OneTimeTokenNotFoundError, *OneTimeTokenNotFoundError: - return true - case OAuthServerClientNotFoundError, *OAuthServerClientNotFoundError: - return true - case OAuthServerAuthorizationNotFoundError, *OAuthServerAuthorizationNotFoundError: - return true - case OAuthClientStateNotFoundError, *OAuthClientStateNotFoundError: - return true - case CustomOAuthProviderNotFoundError, *CustomOAuthProviderNotFoundError: - return true - } - return false + return errors.Is(err, errNotFound) } type SessionNotFoundError struct{} @@ -45,6 +19,10 @@ func (e SessionNotFoundError) Error() string { return "Session not found" } +func (e SessionNotFoundError) Is(target error) bool { + return target == errNotFound +} + // UserNotFoundError represents when a user is not found. type UserNotFoundError struct{} @@ -52,6 +30,10 @@ func (e UserNotFoundError) Error() string { return "User not found" } +func (e UserNotFoundError) Is(target error) bool { + return target == errNotFound +} + // IdentityNotFoundError represents when an identity is not found. type IdentityNotFoundError struct{} @@ -59,6 +41,10 @@ func (e IdentityNotFoundError) Error() string { return "Identity not found" } +func (e IdentityNotFoundError) Is(target error) bool { + return target == errNotFound +} + // ConfirmationOrRecoveryTokenNotFoundError represents when a confirmation or recovery token is not found. type ConfirmationOrRecoveryTokenNotFoundError struct{} @@ -66,6 +52,10 @@ func (e ConfirmationOrRecoveryTokenNotFoundError) Error() string { return "Confirmation or Recovery Token not found" } +func (e ConfirmationOrRecoveryTokenNotFoundError) Is(target error) bool { + return target == errNotFound +} + // ConfirmationTokenNotFoundError represents when a confirmation token is not found. type ConfirmationTokenNotFoundError struct{} @@ -73,6 +63,10 @@ func (e ConfirmationTokenNotFoundError) Error() string { return "Confirmation Token not found" } +func (e ConfirmationTokenNotFoundError) Is(target error) bool { + return target == errNotFound +} + // RefreshTokenNotFoundError represents when a refresh token is not found. type RefreshTokenNotFoundError struct{} @@ -80,6 +74,10 @@ func (e RefreshTokenNotFoundError) Error() string { return "Refresh Token not found" } +func (e RefreshTokenNotFoundError) Is(target error) bool { + return target == errNotFound +} + // FactorNotFoundError represents when a user is not found. type FactorNotFoundError struct{} @@ -87,6 +85,10 @@ func (e FactorNotFoundError) Error() string { return "Factor not found" } +func (e FactorNotFoundError) Is(target error) bool { + return target == errNotFound +} + // ChallengeNotFoundError represents when a user is not found. type ChallengeNotFoundError struct{} @@ -94,6 +96,10 @@ func (e ChallengeNotFoundError) Error() string { return "Challenge not found" } +func (e ChallengeNotFoundError) Is(target error) bool { + return target == errNotFound +} + // SSOProviderNotFoundError represents an error when a SSO Provider can't be // found. type SSOProviderNotFoundError struct{} @@ -102,6 +108,10 @@ func (e SSOProviderNotFoundError) Error() string { return "SSO Identity Provider not found" } +func (e SSOProviderNotFoundError) Is(target error) bool { + return target == errNotFound +} + // SAMLRelayStateNotFoundError represents an error when a SAML relay state // can't be found. type SAMLRelayStateNotFoundError struct{} @@ -110,6 +120,10 @@ func (e SAMLRelayStateNotFoundError) Error() string { return "SAML RelayState not found" } +func (e SAMLRelayStateNotFoundError) Is(target error) bool { + return target == errNotFound +} + // FlowStateNotFoundError represents an error when an FlowState can't be // found. type FlowStateNotFoundError struct{} @@ -118,12 +132,12 @@ func (e FlowStateNotFoundError) Error() string { return "Flow State not found" } +func (e FlowStateNotFoundError) Is(target error) bool { + return target == errNotFound +} + func IsUniqueConstraintViolatedError(err error) bool { - switch err.(type) { - case UserEmailUniqueConflictError, *UserEmailUniqueConflictError: - return true - } - return false + return errors.Is(err, errUniqueConstraintViolated) } type UserEmailUniqueConflictError struct{} @@ -132,15 +146,27 @@ func (e UserEmailUniqueConflictError) Error() string { return "User email unique constraint violated" } +func (e UserEmailUniqueConflictError) Is(target error) bool { + return target == errUniqueConstraintViolated +} + type OAuthClientStateNotFoundError struct{} func (e OAuthClientStateNotFoundError) Error() string { return "OAuth state not found" } +func (e OAuthClientStateNotFoundError) Is(target error) bool { + return target == errNotFound +} + // CustomOAuthProviderNotFoundError represents an error when a custom OAuth/OIDC provider can't be found type CustomOAuthProviderNotFoundError struct{} func (e CustomOAuthProviderNotFoundError) Error() string { return "Custom OAuth provider not found" } + +func (e CustomOAuthProviderNotFoundError) Is(target error) bool { + return target == errNotFound +} diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index 4605c5601e..d731719120 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -303,3 +303,7 @@ type OAuthServerAuthorizationNotFoundError struct{} func (e OAuthServerAuthorizationNotFoundError) Error() string { return "OAuth authorization not found" } + +func (e OAuthServerAuthorizationNotFoundError) Is(target error) bool { + return target == errNotFound +} diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go index 3dc5dc8932..1e34c18795 100644 --- a/internal/models/oauth_client.go +++ b/internal/models/oauth_client.go @@ -205,6 +205,10 @@ func (e OAuthServerClientNotFoundError) Error() string { return "OAuth client not found" } +func (e OAuthServerClientNotFoundError) Is(target error) bool { + return target == errNotFound +} + type InvalidRedirectURIError struct { URI string } diff --git a/internal/models/one_time_token.go b/internal/models/one_time_token.go index 23b2c17a12..34e4309f3f 100644 --- a/internal/models/one_time_token.go +++ b/internal/models/one_time_token.go @@ -99,6 +99,10 @@ func (e OneTimeTokenNotFoundError) Error() string { return "One-time token not found" } +func (e OneTimeTokenNotFoundError) Is(target error) bool { + return target == errNotFound +} + type OneTimeToken struct { ID uuid.UUID `json:"id" db:"id"` diff --git a/internal/observability/metrics.go b/internal/observability/metrics.go index df6defb7e8..d616bf3070 100644 --- a/internal/observability/metrics.go +++ b/internal/observability/metrics.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/utilities" "github.com/supabase/auth/internal/utilities/version" @@ -201,6 +202,9 @@ func ConfigureMetrics(ctx context.Context, mc *conf.MetricsConfig) error { if err = version.InitVersionMetrics(ctx, utilities.Version); err != nil { logrus.WithError(err).Error("unable to configure version metrics") } + if err = apierrors.InitMetrics(); err != nil { + logrus.WithError(err).Error("unable to configure version metrics") + } }) return err