From 0a32ca10d78753c00c12f33c35bb5b9b57783637 Mon Sep 17 00:00:00 2001 From: Marc Ole Bulling Date: Tue, 10 Mar 2026 21:58:12 +0100 Subject: [PATCH 1/5] Initital rework, move to single error page, allow more generic errors --- internal/webserver/Webserver.go | 160 ++++++++---------- internal/webserver/api/Api.go | 2 +- .../authentication/Authentication.go | 72 ++++---- .../authentication/Authentication_test.go | 23 +-- .../webserver/authentication/oauth/Oauth.go | 24 +-- .../webserver/errorHandling/ErrorHandling.go | 155 +++++++++++++++++ .../errorcodes/Errorcodes.go | 0 internal/webserver/fileupload/FileUpload.go | 2 +- .../webserver/web/templates/html_error.tmpl | 71 +++++--- .../web/templates/html_error_auth.tmpl | 18 -- .../web/templates/html_error_header.tmpl | 17 -- .../web/templates/html_error_int_oauth.tmpl | 30 ---- 12 files changed, 324 insertions(+), 250 deletions(-) create mode 100644 internal/webserver/errorHandling/ErrorHandling.go rename internal/webserver/{api => errorHandling}/errorcodes/Errorcodes.go (100%) delete mode 100644 internal/webserver/web/templates/html_error_auth.tmpl delete mode 100644 internal/webserver/web/templates/html_error_header.tmpl delete mode 100644 internal/webserver/web/templates/html_error_int_oauth.tmpl diff --git a/internal/webserver/Webserver.go b/internal/webserver/Webserver.go index 6cdac42b..5a333986 100644 --- a/internal/webserver/Webserver.go +++ b/internal/webserver/Webserver.go @@ -42,6 +42,7 @@ import ( "github.com/forceu/gokapi/internal/webserver/authentication/oauth" "github.com/forceu/gokapi/internal/webserver/authentication/sessionmanager" "github.com/forceu/gokapi/internal/webserver/authentication/tokengeneration" + "github.com/forceu/gokapi/internal/webserver/errorHandling" "github.com/forceu/gokapi/internal/webserver/favicon" "github.com/forceu/gokapi/internal/webserver/fileupload" "github.com/forceu/gokapi/internal/webserver/ratelimiter" @@ -111,14 +112,12 @@ func Start() { mux.HandleFunc("/downloadPresigned", requireLogin(downloadPresigned, false, false)) mux.HandleFunc("/e2eSetup", requireLogin(showE2ESetup, true, false)) mux.HandleFunc("/error", showError) - mux.HandleFunc("/error-auth", showErrorAuth) - mux.HandleFunc("/error-header", showErrorHeader) - mux.HandleFunc("/error-oauth", showErrorIntOAuth) mux.HandleFunc("/filerequests", requireLogin(showUploadRequest, true, false)) mux.HandleFunc("/forgotpw", forgotPassword) mux.HandleFunc("/h/", showHotlink) mux.HandleFunc("/hotlink/", showHotlink) // backward compatibility mux.HandleFunc("/index", showIndex) + mux.HandleFunc("/test", doTest) mux.HandleFunc("/login", showLogin) mux.HandleFunc("/logs", requireLogin(showLogs, true, false)) mux.HandleFunc("/logout", doLogout) @@ -226,13 +225,13 @@ func initTemplates(templateFolderEmbedded embed.FS) { } // Sends a redirect HTTP output to the client. Variable url is used to redirect to ./url -func redirect(w http.ResponseWriter, url string) { - _, _ = io.WriteString(w, "") +func redirect(w http.ResponseWriter, r *http.Request, url string) { + http.Redirect(w, r, url, http.StatusTemporaryRedirect) } func redirectOnIncorrectId(w http.ResponseWriter, r *http.Request, url string) { ratelimiter.WaitOnFailedId(r) - redirect(w, url) + redirect(w, r, url) } type redirectValues struct { @@ -251,7 +250,7 @@ func redirectFromFilename(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") file, ok := storage.GetFile(id) if !ok { - redirect(w, "../../error") + redirect(w, r, "../../error") return } @@ -298,6 +297,10 @@ func showIndex(w http.ResponseWriter, r *http.Request) { helper.CheckIgnoreTimeout(err) } +func doTest(w http.ResponseWriter, r *http.Request) { + errorHandling.RedirectGenericErrorPage(w, r, 5) +} + func handleGenerateAuthToken(w http.ResponseWriter, r *http.Request) { user, err := authentication.GetUserFromRequest(r) if err != nil { @@ -325,7 +328,7 @@ func changePassword(w http.ResponseWriter, r *http.Request) { panic(err) } if !user.ResetPassword { - redirect(w, "admin") + redirect(w, r, "admin") return } err = r.ParseForm() @@ -344,7 +347,7 @@ func changePassword(w http.ResponseWriter, r *http.Request) { user.Password = pwHash user.ResetPassword = false database.SaveUser(user, false) - redirect(w, "admin") + redirect(w, r, "admin") return } } @@ -377,60 +380,24 @@ func validateNewPassword(newPassword string, user models.User, userCsrfToken str // Handling of /error func showError(w http.ResponseWriter, r *http.Request) { - const ( - invalidFile = iota - noCipherSupplied - wrongCipher - invalidFileRequest - ) - - errorReason := invalidFile - cardWidth := 18 - if r.URL.Query().Has("e2e") { - errorReason = noCipherSupplied - cardWidth = 25 - } - if r.URL.Query().Has("key") { - errorReason = wrongCipher - cardWidth = 25 - } - if r.URL.Query().Has("fr") { - errorReason = invalidFileRequest - cardWidth = 30 - } - err := templateFolder.ExecuteTemplate(w, "error", genericView{ - ErrorId: errorReason, - ErrorCardWidth: cardWidth, - PublicName: configuration.Get().PublicName, - CustomContent: customStaticInfo}) - helper.CheckIgnoreTimeout(err) -} -// Handling of /error-auth -func showErrorAuth(w http.ResponseWriter, r *http.Request) { - err := templateFolder.ExecuteTemplate(w, "error_auth", genericView{ - PublicName: configuration.Get().PublicName, - CustomContent: customStaticInfo}) - helper.CheckIgnoreTimeout(err) -} + displayedError := errorHandling.Get(r) -// Handling of /error-header -func showErrorHeader(w http.ResponseWriter, r *http.Request) { - err := templateFolder.ExecuteTemplate(w, "error_auth_header", genericView{ - PublicName: configuration.Get().PublicName, - CustomContent: customStaticInfo}) - helper.CheckIgnoreTimeout(err) -} + if r.URL.Query().Has("e2e") { + displayedError.ErrorId = errorHandling.TypeE2ECipher + displayedError.IsGeneric = true + displayedError.CardWidth = "25rem" + } -// Handling of /error-oauth -func showErrorIntOAuth(w http.ResponseWriter, r *http.Request) { - view := oauthErrorView{PublicName: configuration.Get().PublicName, - CustomContent: customStaticInfo} - view.IsAuthDenied = r.URL.Query().Get("isDenied") == "true" - view.ErrorProvidedName = r.URL.Query().Get("error") - view.ErrorProvidedMessage = r.URL.Query().Get("error_description") - view.ErrorGenericMessage = r.URL.Query().Get("error_generic") - err := templateFolder.ExecuteTemplate(w, "error_int_oauth", view) + err := templateFolder.ExecuteTemplate(w, "error", genericView{ + ErrorId: displayedError.ErrorId, + ErrorCardWidth: displayedError.CardWidth, + IsGenericError: displayedError.IsGeneric, + ErrorTitle: displayedError.Title, + ErrorMessage: displayedError.Message, + ErrorOauthMessage: displayedError.OAuthProviderMessage, + PublicName: configuration.Get().PublicName, + CustomContent: customStaticInfo}) helper.CheckIgnoreTimeout(err) } @@ -451,7 +418,7 @@ func showUploadRequest(w http.ResponseWriter, r *http.Request) { view := (&AdminView{}).convertGlobalConfig(ViewFileRequests, userId) if !view.ActiveUser.HasPermissionCreateFileRequests() { - redirect(w, "admin") + redirect(w, r, "admin") return } err = templateFolder.ExecuteTemplate(w, "uploadreq", view) @@ -468,7 +435,7 @@ func showApiAdmin(w http.ResponseWriter, r *http.Request) { view := (&AdminView{}).convertGlobalConfig(ViewAPI, userId) if configuration.GetEnvironment().DisableApiMenu && !view.ActiveUser.IsAdmin() { - redirect(w, "admin") + redirect(w, r, "admin") return } @@ -485,7 +452,7 @@ func showUserAdmin(w http.ResponseWriter, r *http.Request) { } view := (&AdminView{}).convertGlobalConfig(ViewUsers, userId) if !view.ActiveUser.HasPermissionManageUsers() || configuration.Get().Authentication.Method == models.AuthenticationDisabled { - redirect(w, "admin") + redirect(w, r, "admin") return } err = templateFolder.ExecuteTemplate(w, "users", view) @@ -501,25 +468,29 @@ func processApi(w http.ResponseWriter, r *http.Request) { // Shows a login form. If not authenticated, client needs to wait for three seconds. // If correct, a new session is created and the user is redirected to the admin menu func showLogin(w http.ResponseWriter, r *http.Request) { - _, ok := authentication.IsAuthenticated(w, r) + _, ok, err := authentication.IsAuthenticated(w, r) + if err != nil { + errorHandling.RedirectToErrorPage(w, r, "Unable to log in", err.Error(), errorHandling.WidthDefault) + return + } if ok { - redirect(w, "admin") + redirect(w, r, "admin") return } if configuration.Get().Authentication.Method == models.AuthenticationHeader { - redirect(w, "error-header") + errorHandling.RedirectToErrorPage(w, r, "Unauthorised", "No login information was sent from the authentication provider.", errorHandling.WidthDefault) return } if configuration.Get().Authentication.Method == models.AuthenticationOAuth2 { // If user clicked logout, force consent if r.URL.Query().Has("consent") { - redirect(w, "oauth-login?consent=true") + redirect(w, r, "oauth-login?consent=true") } else { - redirect(w, "oauth-login") + redirect(w, r, "oauth-login") } return } - err := r.ParseForm() + err = r.ParseForm() if err != nil { fmt.Println("Invalid form data sent to server for /login") fmt.Println(err) @@ -536,7 +507,7 @@ func showLogin(w http.ResponseWriter, r *http.Request) { if validCredentials { logging.LogValidLogin(user) sessionmanager.CreateSession(w, false, 0, retrievedUser.Id) - redirect(w, "admin") + redirect(w, r, "admin") return } logging.LogInvalidLogin(user, ip) @@ -569,7 +540,7 @@ type LoginView struct { // If it exists, a download form is shown, or a password needs to be entered. func showDownload(w http.ResponseWriter, r *http.Request) { addNoCacheHeader(w) - keyId := queryUrl(w, r, "id", "error") + keyId := queryUrl(w, r, "id", errorHandling.TypeFileNotFound) file, ok := storage.GetFile(keyId) if !ok || file.IsFileRequest() { redirectOnIncorrectId(w, r, "error") @@ -616,7 +587,7 @@ func showDownload(w http.ResponseWriter, r *http.Request) { if configuration.HashPassword(enteredPassword, true) == file.PasswordHash { writeFilePwCookie(w, file) // redirect so that there is no post data to be resent if user refreshes page - redirect(w, "d?id="+file.Id) + redirect(w, r, "d?id="+file.Id) return } view.IsFailedLogin = true @@ -646,10 +617,10 @@ func showHotlink(w http.ResponseWriter, r *http.Request) { } // Checks if a file is associated with the GET parameter from the current URL -func queryUrl(w http.ResponseWriter, r *http.Request, keyword string, redirectUrl string) string { +func queryUrl(w http.ResponseWriter, r *http.Request, keyword string, errorType int) string { keys, ok := r.URL.Query()[keyword] if !ok || len(keys[0]) < environment.MinLengthId { - redirect(w, redirectUrl) + errorHandling.RedirectGenericErrorPage(w, r, errorType) return "" } return keys[0] @@ -667,7 +638,7 @@ func showAdminMenu(w http.ResponseWriter, r *http.Request) { if config.Encryption.Level == encryption.EndToEndEncryption { e2einfo := database.GetEnd2EndInfo(user.Id) if !e2einfo.HasBeenSetUp() { - redirect(w, "e2eSetup") + redirect(w, r, "e2eSetup") return } } @@ -692,7 +663,7 @@ func showLogs(w http.ResponseWriter, r *http.Request) { } view := (&AdminView{}).convertGlobalConfig(ViewLogs, user) if !view.ActiveUser.HasPermissionManageLogs() { - redirect(w, "admin") + redirect(w, r, "admin") return } err = templateFolder.ExecuteTemplate(w, "logs", view) @@ -701,7 +672,7 @@ func showLogs(w http.ResponseWriter, r *http.Request) { func showE2ESetup(w http.ResponseWriter, r *http.Request) { if configuration.Get().Encryption.Level != encryption.EndToEndEncryption { - redirect(w, "admin") + redirect(w, r, "admin") return } @@ -961,23 +932,23 @@ type userInfo struct { // Handling of /publicUpload func showPublicUpload(w http.ResponseWriter, r *http.Request) { addNoCacheHeader(w) - fileRequestId := queryUrl(w, r, "id", "error?fr") + fileRequestId := queryUrl(w, r, "id", errorHandling.TypeInvalidFileRequest) request, ok := filerequest.Get(fileRequestId) if !ok { - redirect(w, "error?fr") + errorHandling.RedirectGenericErrorPage(w, r, errorHandling.TypeInvalidFileRequest) return } if !request.IsUnlimitedTime() && request.Expiry < time.Now().Unix() { - redirect(w, "error?fr") + errorHandling.RedirectGenericErrorPage(w, r, errorHandling.TypeInvalidFileRequest) return } if !request.IsUnlimitedFiles() && request.UploadedFiles >= request.MaxFiles { - redirect(w, "error?fr") + errorHandling.RedirectGenericErrorPage(w, r, errorHandling.TypeInvalidFileRequest) return } - apiKey := queryUrl(w, r, "key", "error?fr") + apiKey := queryUrl(w, r, "key", errorHandling.TypeInvalidFileRequest) if subtle.ConstantTimeCompare([]byte(request.ApiKey), []byte(apiKey)) != 1 { - redirect(w, "error?fr") + errorHandling.RedirectGenericErrorPage(w, r, errorHandling.TypeInvalidFileRequest) return } @@ -1030,7 +1001,7 @@ func downloadFileWithNameInUrl(w http.ResponseWriter, r *http.Request) { // Handling of /downloadFile // Outputs the file to the user and reduces the download remaining count for the file func downloadFile(w http.ResponseWriter, r *http.Request) { - id := queryUrl(w, r, "id", "error") + id := queryUrl(w, r, "id", errorHandling.TypeFileNotFound) serveFile(id, true, w, r) } @@ -1082,9 +1053,9 @@ func serveFile(id string, isRootUrl bool, w http.ResponseWriter, r *http.Request if savedFile.PasswordHash != "" { if !(isValidPwCookie(r, savedFile)) { if isRootUrl { - redirect(w, "d?id="+savedFile.Id) + redirect(w, r, "d?id="+savedFile.Id) } else { - redirect(w, "../../d?id="+savedFile.Id) + redirect(w, r, "../../d?id="+savedFile.Id) } return } @@ -1095,11 +1066,15 @@ func serveFile(id string, isRootUrl bool, w http.ResponseWriter, r *http.Request func requireLogin(next http.HandlerFunc, isUiCall, isPwChangeView bool) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { addNoCacheHeader(w) - user, isLoggedIn := authentication.IsAuthenticated(w, r) + user, isLoggedIn, err := authentication.IsAuthenticated(w, r) + if err != nil { + errorHandling.RedirectToErrorPage(w, r, "Unable to log in", err.Error(), errorHandling.WidthDefault) + return + } if isLoggedIn { if user.ResetPassword && isUiCall && configuration.Get().Authentication.Method == models.AuthenticationInternal { if !isPwChangeView { - redirect(w, "changePassword") + redirect(w, r, "changePassword") return } } @@ -1112,7 +1087,7 @@ func requireLogin(next http.HandlerFunc, isUiCall, isPwChangeView bool) http.Han _, _ = io.WriteString(w, "{\"Result\":\"error\",\"ErrorMessage\":\"Not authenticated\"}") return } - redirect(w, "login") + redirect(w, r, "login") } } @@ -1121,7 +1096,7 @@ type adminButtonContext struct { ActiveUser *models.User } -// Used internally in templates, to create buttons with user context +// Used internally in templates to create buttons with user context func newAdminButtonContext(file models.FileApiOutput, user models.User) adminButtonContext { return adminButtonContext{CurrentFile: file, ActiveUser: &user} } @@ -1164,12 +1139,15 @@ func addCacheHeader(w http.ResponseWriter) { type genericView struct { IsAdminView bool IsDownloadView bool + IsGenericError bool PublicName string RedirectUrl string + ErrorTitle string ErrorMessage string + ErrorOauthMessage string CsrfToken string + ErrorCardWidth string ErrorId int - ErrorCardWidth int MinPasswordLength int CustomContent customStatic } diff --git a/internal/webserver/api/Api.go b/internal/webserver/api/Api.go index 4dea564b..110ff42b 100644 --- a/internal/webserver/api/Api.go +++ b/internal/webserver/api/Api.go @@ -23,8 +23,8 @@ import ( "github.com/forceu/gokapi/internal/storage/filerequest" "github.com/forceu/gokapi/internal/storage/presign" "github.com/forceu/gokapi/internal/webserver/api/apiMutex" - "github.com/forceu/gokapi/internal/webserver/api/errorcodes" "github.com/forceu/gokapi/internal/webserver/authentication/users" + "github.com/forceu/gokapi/internal/webserver/errorHandling/errorcodes" "github.com/forceu/gokapi/internal/webserver/fileupload" "github.com/forceu/gokapi/internal/webserver/ratelimiter" ) diff --git a/internal/webserver/authentication/Authentication.go b/internal/webserver/authentication/Authentication.go index c1a201f5..ef2f0d5a 100644 --- a/internal/webserver/authentication/Authentication.go +++ b/internal/webserver/authentication/Authentication.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "log" "net/http" "os" @@ -19,6 +18,8 @@ import ( "github.com/forceu/gokapi/internal/models" "github.com/forceu/gokapi/internal/webserver/authentication/csrftoken" "github.com/forceu/gokapi/internal/webserver/authentication/sessionmanager" + "github.com/forceu/gokapi/internal/webserver/authentication/users" + "github.com/forceu/gokapi/internal/webserver/errorHandling" ) type userNameContext string @@ -95,41 +96,44 @@ func SetUserInRequest(r *http.Request, user models.User) *http.Request { } // IsAuthenticated returns true and the user ID if authenticated -func IsAuthenticated(w http.ResponseWriter, r *http.Request) (models.User, bool) { +func IsAuthenticated(w http.ResponseWriter, r *http.Request) (models.User, bool, error) { switch authSettings.Method { case models.AuthenticationInternal: user, ok := isGrantedSession(w, r) if ok { - return user, true + return user, true, nil } case models.AuthenticationOAuth2: user, ok := isGrantedSession(w, r) if ok { - return user, true + return user, true, nil } case models.AuthenticationHeader: - user, ok := isGrantedHeader(r) + user, ok, err := isGrantedHeader(r) + if err != nil { + return models.User{}, false, err + } if ok { - return user, true + return user, true, nil } case models.AuthenticationDisabled: adminUser, ok := database.GetSuperAdmin() if !ok { panic("no super admin found") } - return adminUser, true + return adminUser, true, nil } - return models.User{}, false + return models.User{}, false, nil } // isGrantedHeader returns true if the user was authenticated by a proxy header if enabled -func isGrantedHeader(r *http.Request) (models.User, bool) { +func isGrantedHeader(r *http.Request) (models.User, bool, error) { if authSettings.HeaderKey == "" { - return models.User{}, false + return models.User{}, false, errors.New("header key is not set") } userName := r.Header.Get(authSettings.HeaderKey) if userName == "" { - return models.User{}, false + return models.User{}, false, errors.New("no user found in header") } return getOrCreateUser(userName) } @@ -217,7 +221,7 @@ type OAuthUserClaims interface { } // CheckOauthUserAndRedirect checks if the user is allowed to use the Gokapi instance -func CheckOauthUserAndRedirect(userInfo OAuthUserInfo, w http.ResponseWriter) error { +func CheckOauthUserAndRedirect(w http.ResponseWriter, r *http.Request, userInfo OAuthUserInfo) error { var groups []string var err error @@ -228,34 +232,33 @@ func CheckOauthUserAndRedirect(userInfo OAuthUserInfo, w http.ResponseWriter) er } } if isValidOauthUser(userInfo, groups) { - user, ok := getOrCreateUser(userInfo.Email) + user, ok, errCreate := getOrCreateUser(userInfo.Email) + if errCreate != nil { + return errCreate + } if ok { sessionmanager.CreateSession(w, true, authSettings.OAuthRecheckInterval, user.Id) - redirect(w, "admin") + http.Redirect(w, r, "admin", http.StatusTemporaryRedirect) return nil } } - redirect(w, "error-auth") + errorHandling.RedirectGenericErrorPage(w, r, errorHandling.TypeOAuthNotAuthorised) return nil } -func getOrCreateUser(username string) (models.User, bool) { +func getOrCreateUser(username string) (models.User, bool, error) { user, ok := database.GetUserByName(username) - if !ok { - if authSettings.OnlyRegisteredUsers { - return models.User{}, false - } - user = models.User{ - Name: username, - UserLevel: models.UserLevelUser, - } - database.SaveUser(user, true) - user, ok = database.GetUserByName(username) - if !ok { - panic("unable to read new user") - } + if ok { + return user, true, nil + } + if authSettings.OnlyRegisteredUsers { + return models.User{}, false, nil } - return user, true + newUser, err := users.Create(username) + if err != nil { + return models.User{}, false, err + } + return newUser, true, nil } func isValidOauthUser(userInfo OAuthUserInfo, groups []string) bool { @@ -297,20 +300,15 @@ func IsEqualStringConstantTime(s1, s2 string) bool { return subtle.ConstantTimeCompare([]byte(s1), []byte(s2)) == 1 } -// Sends a redirect HTTP output to the client. Variable url is used to redirect to ./url -func redirect(w http.ResponseWriter, url string) { - _, _ = io.WriteString(w, "") -} - // Logout logs the user out and removes the session func Logout(w http.ResponseWriter, r *http.Request) { if authSettings.Method == models.AuthenticationInternal || authSettings.Method == models.AuthenticationOAuth2 { sessionmanager.LogoutSession(w, r) } if authSettings.Method == models.AuthenticationOAuth2 { - redirect(w, "login?consent=true") + http.Redirect(w, r, "login?consent=true", http.StatusTemporaryRedirect) } else { - redirect(w, "login") + http.Redirect(w, r, "login", http.StatusTemporaryRedirect) } } diff --git a/internal/webserver/authentication/Authentication_test.go b/internal/webserver/authentication/Authentication_test.go index 46934829..8066eaed 100644 --- a/internal/webserver/authentication/Authentication_test.go +++ b/internal/webserver/authentication/Authentication_test.go @@ -362,28 +362,28 @@ func TestCheckOauthUser(t *testing.T) { info := OAuthUserInfo{ ClaimsSent: testInfo{Output: []byte(`{"amr":["pwd","hwk","user","pin","mfa"],"aud":["gokapi-dev"],"auth_time":1705573822,"azp":"gokapi-dev","client_id":"gokapi-dev","email":"test@test.com","email_verified":true,"groups":["admins","dev"],"iat":1705577400,"iss":"https://auth.test.com","name":"gokapi","preferred_username":"gokapi","rat":1705577400,"sub":"944444cf3e-0546-44f2-acfa-a94444444360"}`)}, } - output, err := getOuthUserOutput(t, info) + output, err := getOauthUserOutput(t, info) test.IsNil(t, err) test.IsEqualString(t, redirectsToSite(output), "error-auth") info.Subject = "random" - output, err = getOuthUserOutput(t, info) + output, err = getOauthUserOutput(t, info) test.IsNil(t, err) test.IsEqualString(t, redirectsToSite(output), "error-auth") info.Email = "random" - output, err = getOuthUserOutput(t, info) + output, err = getOauthUserOutput(t, info) test.IsNil(t, err) test.IsEqualString(t, redirectsToSite(output), "admin") info.Email = "test@test-invalid.com" authSettings.OnlyRegisteredUsers = true - output, err = getOuthUserOutput(t, info) + output, err = getOauthUserOutput(t, info) test.IsNil(t, err) test.IsEqualString(t, redirectsToSite(output), "error-auth") info.Email = "random" - output, err = getOuthUserOutput(t, info) + output, err = getOauthUserOutput(t, info) test.IsNil(t, err) test.IsEqualString(t, redirectsToSite(output), "admin") @@ -392,7 +392,7 @@ func TestCheckOauthUser(t *testing.T) { authSettings.OAuthGroupScope = "groupscope" newClaims := testInfo{Output: []byte("{invalid")} info.ClaimsSent = newClaims - _, err = getOuthUserOutput(t, info) + _, err = getOauthUserOutput(t, info) test.IsNotNil(t, err) } @@ -406,17 +406,6 @@ func redirectsToSite(input string) string { return "other" } -func getOuthUserOutput(t *testing.T, info OAuthUserInfo) (string, error) { - t.Helper() - w := httptest.NewRecorder() - err := CheckOauthUserAndRedirect(info, w) - if err != nil { - return "", err - } - output, _ := io.ReadAll(w.Result().Body) - return string(output), nil -} - var modelUserPW = models.AuthenticationConfig{ Method: models.AuthenticationInternal, SaltAdmin: testconfiguration.SaltAdmin, diff --git a/internal/webserver/authentication/oauth/Oauth.go b/internal/webserver/authentication/oauth/Oauth.go index b9b3d5f8..b3861326 100644 --- a/internal/webserver/authentication/oauth/Oauth.go +++ b/internal/webserver/authentication/oauth/Oauth.go @@ -11,6 +11,7 @@ import ( "github.com/forceu/gokapi/internal/helper" "github.com/forceu/gokapi/internal/models" "github.com/forceu/gokapi/internal/webserver/authentication" + "github.com/forceu/gokapi/internal/webserver/errorHandling" "golang.org/x/oauth2" ) @@ -72,11 +73,11 @@ func isLoginRequired(r *http.Request) bool { func HandlerCallback(w http.ResponseWriter, r *http.Request) { state, err := r.Cookie(authentication.CookieOauth) if err != nil { - showOauthErrorPage(w, r, "Parameter state was not provided") + errorHandling.RedirectToOAuthErrorPage(w, r, "Parameter state was not provided", err) return } if r.URL.Query().Get("state") != state.Value { - showOauthErrorPage(w, r, "Parameter state did not match") + errorHandling.RedirectToOAuthErrorPage(w, r, "Parameter state did not match", err) return } @@ -87,17 +88,18 @@ func HandlerCallback(w http.ResponseWriter, r *http.Request) { oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code")) if err != nil { - showOauthErrorPage(w, r, "Failed to exchange token: "+err.Error()) + errorHandling.RedirectToOAuthErrorPage(w, r, "Failed to exchange token", err) return } userInfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)) if err != nil { - showOauthErrorPage(w, r, "Failed to get userinfo: "+err.Error()) + errorHandling.RedirectToOAuthErrorPage(w, r, "Failed to get user info", err) return } if userInfo.Email == "" { - showOauthErrorPage(w, r, "An empty email address was provided.\nPlease make sure that you have your email address set in your authentication user backend.") + errorHandling.RedirectToOAuthErrorPage(w, r, "An empty email address was provided.\nPlease make sure that you have your"+ + " email address set in your authentication user backend.", nil) return } info := authentication.OAuthUserInfo{ @@ -105,20 +107,12 @@ func HandlerCallback(w http.ResponseWriter, r *http.Request) { Email: userInfo.Email, ClaimsSent: userInfo, } - err = authentication.CheckOauthUserAndRedirect(info, w) + err = authentication.CheckOauthUserAndRedirect(w, r, info) if err != nil { - showOauthErrorPage(w, r, "Failed to extract scope value: "+err.Error()) + errorHandling.RedirectToOAuthErrorPage(w, r, "Failed to continue with login", err) } } -func showOauthErrorPage(w http.ResponseWriter, r *http.Request, errorMessage string) { - // Extract the query parameters from the original URL - queryParams := r.URL.Query() - queryParams.Add("error_generic", errorMessage) - redirectURL := "./error-oauth?" + queryParams.Encode() - http.Redirect(w, r, redirectURL, http.StatusSeeOther) -} - func setCallbackCookie(w http.ResponseWriter, value string) { c := &http.Cookie{ Name: authentication.CookieOauth, diff --git a/internal/webserver/errorHandling/ErrorHandling.go b/internal/webserver/errorHandling/ErrorHandling.go new file mode 100644 index 00000000..7e7f11d7 --- /dev/null +++ b/internal/webserver/errorHandling/ErrorHandling.go @@ -0,0 +1,155 @@ +package errorHandling + +import ( + "net/http" + "strconv" + "sync" + "time" + + "github.com/forceu/gokapi/internal/helper" +) + +var tokens = make(map[string]DisplayedError) +var mutex sync.RWMutex +var cleanupOnce sync.Once + +const ttl = 5 * time.Minute + +const WidthDefault = "20rem" +const WidthWide = "30rem" +const WidthVeryWide = "65%" + +const ( + TypeFileNotFound = iota + TypeInvalidFileRequest + TypeE2ECipher + TypeOAuthNotAuthorised + TypeOAuthNonGeneric +) + +type DisplayedError struct { + Title string + Message string + OAuthProviderMessage string + CardWidth string + ErrorId int + IsGeneric bool + expiry int64 +} + +func (d DisplayedError) IsExpired() bool { + return d.expiry < time.Now().Unix() +} + +func (d DisplayedError) GetWidth() bool { + return d.expiry < time.Now().Unix() +} + +func RedirectToErrorPage(w http.ResponseWriter, r *http.Request, errorTitle, errorMessage, cardWidth string) { + result := DisplayedError{ + Title: errorTitle, + Message: errorMessage, + expiry: time.Now().Add(ttl).Unix(), + CardWidth: cardWidth, + } + redirectToError(w, r, result) +} + +func RedirectGenericErrorPage(w http.ResponseWriter, r *http.Request, genericType int) { + var cardWidth string + switch genericType { + case TypeFileNotFound: + cardWidth = WidthDefault + case TypeInvalidFileRequest: + cardWidth = WidthWide + case TypeE2ECipher: + cardWidth = WidthVeryWide + case TypeOAuthNotAuthorised: + cardWidth = WidthWide + default: + redirectToError(w, r, DisplayedError{ + Title: "Unknown error", + Message: "Gokapi cannot display this error (error code " + strconv.Itoa(genericType) + ")", + CardWidth: WidthWide, + expiry: time.Now().Add(ttl).Unix(), + }) + return + } + + result := DisplayedError{ + expiry: time.Now().Add(ttl).Unix(), + ErrorId: genericType, + IsGeneric: true, + CardWidth: cardWidth, + } + redirectToError(w, r, result) +} + +func RedirectToOAuthErrorPage(w http.ResponseWriter, r *http.Request, errorMessage string, err error) { + if r.URL.Query().Get("error") == "access_denied" { + result := DisplayedError{ + Title: "Access denied", + Message: "The request was denied by the user or authentication provider.", + expiry: time.Now().Add(ttl).Unix(), + ErrorId: TypeOAuthNonGeneric, + IsGeneric: false, + } + redirectToError(w, r, result) + return + } + result := DisplayedError{ + Title: r.URL.Query().Get("error"), + Message: errorMessage, + OAuthProviderMessage: r.URL.Query().Get("error_description"), + expiry: time.Now().Add(ttl).Unix(), + ErrorId: TypeOAuthNonGeneric, + IsGeneric: false, + } + redirectToError(w, r, result) +} + +func redirectToError(w http.ResponseWriter, r *http.Request, displayedError DisplayedError) { + token := helper.GenerateRandomString(30) + mutex.Lock() + tokens[token] = displayedError + mutex.Unlock() + + cleanupOnce.Do(func() { + go cleanup(true) + }) + http.Redirect(w, r, "./error?e="+token, http.StatusTemporaryRedirect) +} + +func Get(r *http.Request) DisplayedError { + if !r.URL.Query().Has("e") { + return DisplayedError{ + IsGeneric: true, + ErrorId: TypeFileNotFound, + CardWidth: WidthDefault, + } + } + displayedError, ok := tokens[r.URL.Query().Get("e")] + if !ok { + return DisplayedError{ + Title: "Unknown error ID", + Message: "Unfortunately, an error occurred and the error message could not be displayed.", + CardWidth: WidthDefault, + } + } + return displayedError +} + +func cleanup(periodic bool) { + mutex.Lock() + for id, token := range tokens { + if token.IsExpired() { + delete(tokens, id) + } + } + mutex.Unlock() + if periodic { + time.Sleep(time.Hour) + go cleanup(true) + } + +} diff --git a/internal/webserver/api/errorcodes/Errorcodes.go b/internal/webserver/errorHandling/errorcodes/Errorcodes.go similarity index 100% rename from internal/webserver/api/errorcodes/Errorcodes.go rename to internal/webserver/errorHandling/errorcodes/Errorcodes.go diff --git a/internal/webserver/fileupload/FileUpload.go b/internal/webserver/fileupload/FileUpload.go index 898a9200..bfd80293 100644 --- a/internal/webserver/fileupload/FileUpload.go +++ b/internal/webserver/fileupload/FileUpload.go @@ -14,7 +14,7 @@ import ( "github.com/forceu/gokapi/internal/storage" "github.com/forceu/gokapi/internal/storage/chunking" "github.com/forceu/gokapi/internal/storage/chunking/chunkreservation" - "github.com/forceu/gokapi/internal/webserver/api/errorcodes" + "github.com/forceu/gokapi/internal/webserver/errorHandling/errorcodes" ) const minChunkSize = 5 * 1024 * 1024 diff --git a/internal/webserver/web/templates/html_error.tmpl b/internal/webserver/web/templates/html_error.tmpl index a6a4ba75..a7b10f89 100644 --- a/internal/webserver/web/templates/html_error.tmpl +++ b/internal/webserver/web/templates/html_error.tmpl @@ -2,35 +2,19 @@
-
+
-{{ if eq .ErrorId 0 }} +{{ if .IsGenericError }} + + {{ if eq .ErrorId 0 }}

File not found


The link may have expired or the file has been downloaded too many times. -{{ end }} - -{{ if eq .ErrorId 1 }} -

- Missing decryption key -

-
- This file is encrypted, but no key was provided.
- Please contact the uploader and ensure the full link (including the value after the hash) is used. -{{ end }} - -{{ if eq .ErrorId 2 }} -

- Invalid decryption key -

-
- This file is encrypted, but the provided key is incorrect.
- If this file is end-to-end encrypted, please request the correct link from the uploader. -{{ end }} - -{{ if eq .ErrorId 3 }} + {{ end }} + + {{ if eq .ErrorId 1 }}

Unable to upload files

@@ -41,7 +25,48 @@
  • - The file limit for this upload request has been reached
  • - An invalid upload URL was submitted
  • + {{ end }} + + {{ if eq .ErrorId 2 }} +

    + Missing or invalid decryption key +

    +
    + This file is encrypted, but no key was provided or the key is invalid.

    + Please contact the uploader and make sure the complete URL is used, including the value after the hash. + {{ end }} + + {{ if eq .ErrorId 3 }} +

    Unauthorised user

    +
    +

    Login with OAuth provider was sucessful, however this user is not authorised to use Gokapi.



    + Log in as different user + {{ end }} + +{{ else }} + + {{ if eq .ErrorId 4 }} +

    OIDC Provider Error: {{.ErrorTitle}}

    +
    +

    Login with OAuth provider was not sucessful, the following error was raised:

    + {{ if .ErrorOauthMessage }} +

    {{ .ErrorOauthMessage }}

    + {{ end}} +

    {{ .ErrorMessage }}


    + Try again + {{ else }} + +

    + {{ .ErrorTitle }} +

    +
    + {{ .ErrorMessage }} + {{ end }} + {{ end }} + + +

    diff --git a/internal/webserver/web/templates/html_error_auth.tmpl b/internal/webserver/web/templates/html_error_auth.tmpl deleted file mode 100644 index b8e6f91d..00000000 --- a/internal/webserver/web/templates/html_error_auth.tmpl +++ /dev/null @@ -1,18 +0,0 @@ -{{define "error_auth"}}{{template "header" .}} - -
    -
    -
    -
    -

    Unauthorised user

    -
    -

    Login with OAuth provider was sucessful, however this user is not authorised to use Gokapi.



    - Log in as different user -
    -
    -
    -
    -{{ template "pagename" "LoginError"}} -{{ template "customjs" .}} -{{template "footer"}} -{{end}} diff --git a/internal/webserver/web/templates/html_error_header.tmpl b/internal/webserver/web/templates/html_error_header.tmpl deleted file mode 100644 index f48398c0..00000000 --- a/internal/webserver/web/templates/html_error_header.tmpl +++ /dev/null @@ -1,17 +0,0 @@ -{{define "error_auth_header"}}{{template "header" .}} - -
    -
    -
    -
    -

    Unauthorised

    -
    -

    Error: No login information was sent from the authentication provider.


    -
    -
    -
    -
    -{{ template "pagename" "LoginErrorHeader"}} -{{ template "customjs" .}} -{{template "footer"}} -{{end}} diff --git a/internal/webserver/web/templates/html_error_int_oauth.tmpl b/internal/webserver/web/templates/html_error_int_oauth.tmpl deleted file mode 100644 index 45bfa5c0..00000000 --- a/internal/webserver/web/templates/html_error_int_oauth.tmpl +++ /dev/null @@ -1,30 +0,0 @@ -{{define "error_int_oauth"}}{{template "header" .}} - -
    -
    -{{ if eq .ErrorProvidedName "access_denied"}} -
    -
    -

    Access denied

    -
    -

    The request was denied by the user or authentication provider.


    -{{ else }} -
    -
    -

    OIDC Provider Error {{.ErrorProvidedName}}

    -
    -

    Login with OAuth provider was not sucessful, the following error was raised:

    -{{ if .ErrorProvidedMessage }} -

    {{ .ErrorProvidedMessage }}

    -{{ end}} -

    {{ .ErrorGenericMessage }}


    -{{ end }} - Try again -
    -
    -
    -
    -{{ template "pagename" "LoginErrorOauth"}} -{{ template "customjs" .}} -{{template "footer"}} -{{end}} From 3e6e4f9d0ee3b204b72a5020069cf920c3f8080a Mon Sep 17 00:00:00 2001 From: Marc Ole Bulling Date: Wed, 11 Mar 2026 19:23:29 +0100 Subject: [PATCH 2/5] Updated docs, removed debug endpoint, fixed some tests --- docs/examples.rst | 4 +- internal/webserver/Webserver.go | 12 ++---- .../authentication/Authentication.go | 2 +- .../authentication/Authentication_test.go | 41 +++++++++++-------- .../webserver/authentication/oauth/Oauth.go | 2 +- .../webserver/errorHandling/ErrorHandling.go | 2 +- 6 files changed, 32 insertions(+), 31 deletions(-) diff --git a/docs/examples.rst b/docs/examples.rst index 4e262c95..d6197b89 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -165,7 +165,7 @@ Keycloak ^^^^^^^^^^^^ .. note:: - This guide has been written for version 24.0.3 + This guide has been written for version 26.5.5 .. warning:: In a previous version of this guide, the client mapping was for the predefined mapper "Group memberships", which in some cases always returned the value "admin". Please make sure that you are using a custom mapper, as described in :ref:`oidcconfig_keycloak_opt` @@ -233,7 +233,7 @@ Gokapi Configuration +---------------------------+-----------------------------------------------------------------------+--------------------------------------------+ | Client Secret | Client secret provided | AhXeV7_EXAMPLE_KEY | +---------------------------+-----------------------------------------------------------------------+--------------------------------------------+ -| Recheck identity | If open ``Consent required`` is disabled, use a low interval | 12 hours | +| Recheck identity | If ``Consent required`` is disabled, use a low interval | 12 hours | +---------------------------+-----------------------------------------------------------------------+--------------------------------------------+ | Admin email address | The email address for the super-admin | gokapi@example.com | +---------------------------+-----------------------------------------------------------------------+--------------------------------------------+ diff --git a/internal/webserver/Webserver.go b/internal/webserver/Webserver.go index 5a333986..fe4ba703 100644 --- a/internal/webserver/Webserver.go +++ b/internal/webserver/Webserver.go @@ -117,7 +117,6 @@ func Start() { mux.HandleFunc("/h/", showHotlink) mux.HandleFunc("/hotlink/", showHotlink) // backward compatibility mux.HandleFunc("/index", showIndex) - mux.HandleFunc("/test", doTest) mux.HandleFunc("/login", showLogin) mux.HandleFunc("/logs", requireLogin(showLogs, true, false)) mux.HandleFunc("/logout", doLogout) @@ -297,10 +296,6 @@ func showIndex(w http.ResponseWriter, r *http.Request) { helper.CheckIgnoreTimeout(err) } -func doTest(w http.ResponseWriter, r *http.Request) { - errorHandling.RedirectGenericErrorPage(w, r, 5) -} - func handleGenerateAuthToken(w http.ResponseWriter, r *http.Request) { user, err := authentication.GetUserFromRequest(r) if err != nil { @@ -470,7 +465,7 @@ func processApi(w http.ResponseWriter, r *http.Request) { func showLogin(w http.ResponseWriter, r *http.Request) { _, ok, err := authentication.IsAuthenticated(w, r) if err != nil { - errorHandling.RedirectToErrorPage(w, r, "Unable to log in", err.Error(), errorHandling.WidthDefault) + errorHandling.RedirectToErrorPage(w, r, "Unable to log in", "The following error was raised: "+err.Error(), errorHandling.WidthDefault) return } if ok { @@ -478,7 +473,8 @@ func showLogin(w http.ResponseWriter, r *http.Request) { return } if configuration.Get().Authentication.Method == models.AuthenticationHeader { - errorHandling.RedirectToErrorPage(w, r, "Unauthorised", "No login information was sent from the authentication provider.", errorHandling.WidthDefault) + errorHandling.RedirectToErrorPage(w, r, "Unauthorised", + "No login information was sent from the authentication provider.", errorHandling.WidthDefault) return } if configuration.Get().Authentication.Method == models.AuthenticationOAuth2 { @@ -1068,7 +1064,7 @@ func requireLogin(next http.HandlerFunc, isUiCall, isPwChangeView bool) http.Han addNoCacheHeader(w) user, isLoggedIn, err := authentication.IsAuthenticated(w, r) if err != nil { - errorHandling.RedirectToErrorPage(w, r, "Unable to log in", err.Error(), errorHandling.WidthDefault) + errorHandling.RedirectToErrorPage(w, r, "Unable to log in", "The following error was raised: "+err.Error(), errorHandling.WidthDefault) return } if isLoggedIn { diff --git a/internal/webserver/authentication/Authentication.go b/internal/webserver/authentication/Authentication.go index 8dd7ce21..9a2af768 100644 --- a/internal/webserver/authentication/Authentication.go +++ b/internal/webserver/authentication/Authentication.go @@ -133,7 +133,7 @@ func isGrantedHeader(r *http.Request) (models.User, bool, error) { } userName := r.Header.Get(authSettings.HeaderKey) if userName == "" { - return models.User{}, false, errors.New("no user found in header") + return models.User{}, false, errors.New("header key is not set or empty") } return getOrCreateUser(userName) } diff --git a/internal/webserver/authentication/Authentication_test.go b/internal/webserver/authentication/Authentication_test.go index 8066eaed..02d3e5d9 100644 --- a/internal/webserver/authentication/Authentication_test.go +++ b/internal/webserver/authentication/Authentication_test.go @@ -99,7 +99,7 @@ func TestIsAuthenticated(t *testing.T) { testAuthDisabled(t) w, r := test.GetRecorder("GET", "/", nil, nil, nil) authSettings.Method = -1 - _, ok := IsAuthenticated(w, r) + _, ok, _ := IsAuthenticated(w, r) test.IsEqualBool(t, ok, false) } @@ -112,17 +112,17 @@ func testAuthSession(t *testing.T) { w, r := test.GetRecorder("GET", "/", nil, nil, nil) Init(modelUserPW) - _, ok := IsAuthenticated(w, r) + _, ok, _ := IsAuthenticated(w, r) test.IsEqualBool(t, ok, false) Init(modelOauth) - _, ok = IsAuthenticated(w, r) + _, ok, _ = IsAuthenticated(w, r) test.IsEqualBool(t, ok, false) Init(modelUserPW) w, r = test.GetRecorder("GET", "/", []test.Cookie{{ Name: "session_token", Value: "validsession", }}, nil, nil) - user, ok := IsAuthenticated(w, r) + user, ok, _ := IsAuthenticated(w, r) test.IsEqualBool(t, ok, true) test.IsEqualInt(t, user.Id, 7) test.IsEqualInt(t, exitCode, 0) @@ -137,28 +137,31 @@ func testAuthSession(t *testing.T) { func testAuthHeader(t *testing.T) { w, r := test.GetRecorder("GET", "/", nil, nil, nil) Init(modelHeader) - _, ok := IsAuthenticated(w, r) + _, ok, err := IsAuthenticated(w, r) test.IsEqualBool(t, ok, false) + test.IsNotNil(t, err) w, r = test.GetRecorder("GET", "/", nil, []test.Header{{ Name: "testHeader", Value: "testUser", }}, nil) - user, ok := IsAuthenticated(w, r) + user, ok, err := IsAuthenticated(w, r) test.IsEqualString(t, user.Name, "testuser") test.IsEqualBool(t, ok, true) + test.IsNil(t, err) authSettings.OnlyRegisteredUsers = true w, r = test.GetRecorder("GET", "/", nil, []test.Header{{ Name: "testHeader", Value: "testUser", }}, nil) - _, ok = IsAuthenticated(w, r) + _, ok, err = IsAuthenticated(w, r) + test.IsNil(t, err) test.IsEqualBool(t, ok, true) w, r = test.GetRecorder("GET", "/", nil, []test.Header{{ Name: "testHeader", Value: "otherUser2", }}, nil) - _, ok = IsAuthenticated(w, r) + _, ok, _ = IsAuthenticated(w, r) test.IsEqualBool(t, ok, false) authSettings.OnlyRegisteredUsers = false } @@ -166,7 +169,7 @@ func testAuthHeader(t *testing.T) { func testAuthDisabled(t *testing.T) { w, r := test.GetRecorder("GET", "/", nil, nil, nil) Init(modelDisabled) - user, ok := IsAuthenticated(w, r) + user, ok, _ := IsAuthenticated(w, r) test.IsEqualBool(t, ok, true) test.IsEqualInt(t, user.Id, 5) } @@ -187,14 +190,6 @@ func TestEqualString(t *testing.T) { test.IsEqualBool(t, IsEqualStringConstantTime("yes", "yes"), true) } -func TestRedirect(t *testing.T) { - w := httptest.NewRecorder() - redirect(w, "test") - output, err := io.ReadAll(w.Body) - test.IsNil(t, err) - test.IsEqualString(t, string(output), "") -} - func TestGetUserFromRequest(t *testing.T) { _, r := test.GetRecorder("GET", "/", nil, nil, nil) _, err := GetUserFromRequest(r) @@ -356,7 +351,17 @@ func (t testInfo) Claims(v interface{}) error { } return json.Unmarshal(t.Output, v) } - +func getOauthUserOutput(t *testing.T, info OAuthUserInfo) (string, error) { + t.Helper() + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + err := CheckOauthUserAndRedirect(w, r, info) + if err != nil { + return "", err + } + output, _ := io.ReadAll(w.Result().Body) + return string(output), nil +} func TestCheckOauthUser(t *testing.T) { Init(modelOauth) info := OAuthUserInfo{ diff --git a/internal/webserver/authentication/oauth/Oauth.go b/internal/webserver/authentication/oauth/Oauth.go index b3861326..59a93103 100644 --- a/internal/webserver/authentication/oauth/Oauth.go +++ b/internal/webserver/authentication/oauth/Oauth.go @@ -109,7 +109,7 @@ func HandlerCallback(w http.ResponseWriter, r *http.Request) { } err = authentication.CheckOauthUserAndRedirect(w, r, info) if err != nil { - errorHandling.RedirectToOAuthErrorPage(w, r, "Failed to continue with login", err) + errorHandling.RedirectToOAuthErrorPage(w, r, "Failed to continue with login: ", err) } } diff --git a/internal/webserver/errorHandling/ErrorHandling.go b/internal/webserver/errorHandling/ErrorHandling.go index 7e7f11d7..dff1b34f 100644 --- a/internal/webserver/errorHandling/ErrorHandling.go +++ b/internal/webserver/errorHandling/ErrorHandling.go @@ -99,7 +99,7 @@ func RedirectToOAuthErrorPage(w http.ResponseWriter, r *http.Request, errorMessa } result := DisplayedError{ Title: r.URL.Query().Get("error"), - Message: errorMessage, + Message: errorMessage + " " + err.Error(), OAuthProviderMessage: r.URL.Query().Get("error_description"), expiry: time.Now().Add(ttl).Unix(), ErrorId: TypeOAuthNonGeneric, From b928b9d1779442ebab7b7b215bd7937a44d154ff Mon Sep 17 00:00:00 2001 From: Marc Ole Bulling Date: Thu, 12 Mar 2026 06:14:43 +0100 Subject: [PATCH 3/5] Fixed more tests, fixed nil exception --- internal/test/TestHelper.go | 25 ++++++++++- internal/webserver/Webserver_test.go | 16 ++----- internal/webserver/api/Api.go | 12 +----- internal/webserver/api/Api_test.go | 7 +-- .../authentication/Authentication_test.go | 43 +++++++------------ .../authentication/oauth/Oauth_test.go | 2 +- .../webserver/errorHandling/ErrorHandling.go | 5 ++- internal/webserver/ratelimiter/RateLimiter.go | 20 +++++++-- 8 files changed, 67 insertions(+), 63 deletions(-) diff --git a/internal/test/TestHelper.go b/internal/test/TestHelper.go index f5634ba3..5e5de59e 100644 --- a/internal/test/TestHelper.go +++ b/internal/test/TestHelper.go @@ -47,6 +47,19 @@ func ResponseBodyContains(t MockT, got *httptest.ResponseRecorder, want string) } } +// ResponseIsRedirect fails test if not correct redirect +func ResponseIsRedirect(t MockT, got *httptest.ResponseRecorder, wantUrl string, ignoreParam bool) { + t.Helper() + IsEqualInt(t, got.Code, http.StatusTemporaryRedirect) + location := got.Header().Get("Location") + if ignoreParam { + location = strings.Split(location, "?")[0] + } + if !strings.HasSuffix(location, wantUrl) { + t.Errorf("Redirect Location mismatch: got %s, want to end with %s", location, wantUrl) + } +} + // ResponseBodyIs fails test if http response is not the exact string func ResponseBodyIs(t MockT, got *httptest.ResponseRecorder, want string) { t.Helper() @@ -309,7 +322,10 @@ func HttpPageResultJson(t MockT, config HttpTestConfig) []*http.Cookie { func checkResponse(t MockT, response *http.Response, config HttpTestConfig) { t.Helper() - IsEqualBool(t, response != nil, true) + if response == nil { + t.Errorf("No response received") + return + } if response.StatusCode != config.ResultCode { t.Errorf("Status Code - Got: %d Want: %d", response.StatusCode, config.ResultCode) } @@ -319,6 +335,12 @@ func checkResponse(t MockT, response *http.Response, config HttpTestConfig) { if config.IsHtml && !bytes.Contains(content, []byte("")) { t.Errorf(config.Url + ": Incorrect response, no HTML tag") } + if config.RedirectUrl != "" { + location := response.Header.Get("Location") + if !strings.HasSuffix(location, config.RedirectUrl) { + t.Errorf("Redirect Location mismatch: got %s, want to end with %s", location, config.RedirectUrl) + } + } for _, requiredString := range config.RequiredContent { if !bytes.Contains(content, []byte(requiredString)) { t.Errorf(config.Url + ": Incorrect response. Got:\n" + string(content)) @@ -343,6 +365,7 @@ type HttpTestConfig struct { Headers []Header UploadFileName string UploadFieldName string + RedirectUrl string ResultCode int Body io.Reader } diff --git a/internal/webserver/Webserver_test.go b/internal/webserver/Webserver_test.go index 26d44104..5546b7a0 100644 --- a/internal/webserver/Webserver_test.go +++ b/internal/webserver/Webserver_test.go @@ -19,9 +19,9 @@ import ( "github.com/forceu/gokapi/internal/storage/processingstatus" "github.com/forceu/gokapi/internal/test" "github.com/forceu/gokapi/internal/test/testconfiguration" - "github.com/forceu/gokapi/internal/webserver/api" "github.com/forceu/gokapi/internal/webserver/authentication" "github.com/forceu/gokapi/internal/webserver/authentication/csrftoken" + "github.com/forceu/gokapi/internal/webserver/ratelimiter" ) func TestMain(m *testing.M) { @@ -31,7 +31,7 @@ func TestMain(m *testing.M) { authentication.Init(configuration.Get().Authentication) go Start() time.Sleep(1 * time.Second) - api.SetDebugTrue() + ratelimiter.SetUnitTestMode(true) exitVal := m.Run() testconfiguration.Delete() os.Exit(exitVal) @@ -594,7 +594,8 @@ func TestApiPageNotAuthorized(t *testing.T) { test.HttpPageResult(t, test.HttpTestConfig{ Url: "http://127.0.0.1:53843/apiKeys", IsHtml: true, - RequiredContent: []string{"URL=./login"}, + RedirectUrl: "login", + ResultCode: http.StatusTemporaryRedirect, ExcludedContent: []string{"Click on the API key name to give it a new name."}, Cookies: []test.Cookie{{ Name: "session_token", @@ -674,15 +675,6 @@ func TestResponseError(t *testing.T) { test.ResponseBodyContains(t, w, "testerror") } -func TestShowErrorAuth(t *testing.T) { - t.Parallel() - test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/error-auth", - RequiredContent: []string{"Log in as different user"}, - IsHtml: true, - }) -} - func TestServeWasmDownloader(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ diff --git a/internal/webserver/api/Api.go b/internal/webserver/api/Api.go index 110ff42b..98663998 100644 --- a/internal/webserver/api/Api.go +++ b/internal/webserver/api/Api.go @@ -35,10 +35,6 @@ const LengthPublicId = 35 // LengthApiKey is the length of the private API key used for authentication const LengthApiKey = 30 -// isDebug must be false and is only set to true for running test units -// If true, rate limiting is disabled for API calls -var isDebug = false - // Process parses the request and executes the API call or returns an error message to the sender func Process(w http.ResponseWriter, r *http.Request) { w.Header().Set("cache-control", "no-store") @@ -77,12 +73,6 @@ func parseRequestUrl(r *http.Request) string { return strings.Replace(r.URL.String(), "/api", "", 1) } -// SetDebugTrue should never be called in production -// It is used to disable API rate limiting -func SetDebugTrue() { - isDebug = true -} - func apiEditFile(w http.ResponseWriter, r requestParser, user models.User) { request, ok := r.(*paramFilesModify) if !ok { @@ -1248,7 +1238,7 @@ func apiUploadRequestListSingle(w http.ResponseWriter, r requestParser, user mod func isAuthorisedForApi(r *http.Request, routing apiRoute) (models.User, bool) { keyId := r.Header.Get("apikey") - ratelimiter.WaitOnApiAuthentication(logging.GetIpAddress(r), isDebug) + ratelimiter.WaitOnApiAuthentication(logging.GetIpAddress(r)) user, apiKey, ok := isValidApiKey(keyId, true, routing.ApiPerm) if !ok { return models.User{}, false diff --git a/internal/webserver/api/Api_test.go b/internal/webserver/api/Api_test.go index 09ae3371..89a59126 100644 --- a/internal/webserver/api/Api_test.go +++ b/internal/webserver/api/Api_test.go @@ -20,6 +20,7 @@ import ( "github.com/forceu/gokapi/internal/storage" "github.com/forceu/gokapi/internal/test" "github.com/forceu/gokapi/internal/test/testconfiguration" + "github.com/forceu/gokapi/internal/webserver/ratelimiter" ) func TestMain(m *testing.M) { @@ -27,6 +28,7 @@ func TestMain(m *testing.M) { configuration.Load() configuration.ConnectDatabase() generateTestData() + ratelimiter.SetUnitTestMode(true) exitVal := m.Run() testconfiguration.Delete() os.Exit(exitVal) @@ -119,11 +121,6 @@ func getRecorderWithBody(url, apikey, method string, headers []test.Header, body return test.GetRecorder(method, url, nil, passedHeaders, body) } -func TestIsDebugModeFalse(t *testing.T) { - test.IsEqualBool(t, isDebug, false) - SetDebugTrue() -} - func testAuthorisation(t *testing.T, url string, requiredPermission models.ApiPermission) models.ApiKey { w, r := getRecorder(url, "", []test.Header{{}}) Process(w, r) diff --git a/internal/webserver/authentication/Authentication_test.go b/internal/webserver/authentication/Authentication_test.go index 02d3e5d9..96e57a04 100644 --- a/internal/webserver/authentication/Authentication_test.go +++ b/internal/webserver/authentication/Authentication_test.go @@ -5,11 +5,9 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/http/httptest" "os" - "strings" "testing" "github.com/forceu/gokapi/internal/configuration" @@ -324,7 +322,7 @@ func TestLogout(t *testing.T) { test.IsEqualBool(t, ok, false) _, ok = sessionmanager.IsValidSession(w, r, false, 0) test.IsEqualBool(t, ok, false) - test.ResponseBodyContains(t, w, "") + test.ResponseIsRedirect(t, w, "login", false) Init(modelOauth) w, r, _, _ = getRecorder([]test.Cookie{{ @@ -338,7 +336,7 @@ func TestLogout(t *testing.T) { test.IsEqualBool(t, ok, false) _, ok = sessionmanager.IsValidSession(w, r, false, 0) test.IsEqualBool(t, ok, false) - test.ResponseBodyContains(t, w, "") + test.ResponseIsRedirect(t, w, "login?consent=true", false) } type testInfo struct { @@ -351,46 +349,45 @@ func (t testInfo) Claims(v interface{}) error { } return json.Unmarshal(t.Output, v) } -func getOauthUserOutput(t *testing.T, info OAuthUserInfo) (string, error) { +func getOauthUserOutput(t *testing.T, info OAuthUserInfo) (*httptest.ResponseRecorder, error) { t.Helper() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) err := CheckOauthUserAndRedirect(w, r, info) if err != nil { - return "", err + return w, err } - output, _ := io.ReadAll(w.Result().Body) - return string(output), nil + return w, nil } func TestCheckOauthUser(t *testing.T) { Init(modelOauth) info := OAuthUserInfo{ ClaimsSent: testInfo{Output: []byte(`{"amr":["pwd","hwk","user","pin","mfa"],"aud":["gokapi-dev"],"auth_time":1705573822,"azp":"gokapi-dev","client_id":"gokapi-dev","email":"test@test.com","email_verified":true,"groups":["admins","dev"],"iat":1705577400,"iss":"https://auth.test.com","name":"gokapi","preferred_username":"gokapi","rat":1705577400,"sub":"944444cf3e-0546-44f2-acfa-a94444444360"}`)}, } - output, err := getOauthUserOutput(t, info) + w, err := getOauthUserOutput(t, info) test.IsNil(t, err) - test.IsEqualString(t, redirectsToSite(output), "error-auth") + test.ResponseIsRedirect(t, w, "error", true) info.Subject = "random" - output, err = getOauthUserOutput(t, info) + w, err = getOauthUserOutput(t, info) test.IsNil(t, err) - test.IsEqualString(t, redirectsToSite(output), "error-auth") + test.ResponseIsRedirect(t, w, "error", true) info.Email = "random" - output, err = getOauthUserOutput(t, info) + w, err = getOauthUserOutput(t, info) test.IsNil(t, err) - test.IsEqualString(t, redirectsToSite(output), "admin") + test.ResponseIsRedirect(t, w, "admin", false) info.Email = "test@test-invalid.com" authSettings.OnlyRegisteredUsers = true - output, err = getOauthUserOutput(t, info) + w, err = getOauthUserOutput(t, info) test.IsNil(t, err) - test.IsEqualString(t, redirectsToSite(output), "error-auth") + test.ResponseIsRedirect(t, w, "error", true) info.Email = "random" - output, err = getOauthUserOutput(t, info) + w, err = getOauthUserOutput(t, info) test.IsNil(t, err) - test.IsEqualString(t, redirectsToSite(output), "admin") + test.ResponseIsRedirect(t, w, "admin", false) authSettings.OnlyRegisteredUsers = false authSettings.OAuthGroups = []string{"otheruser@test"} @@ -401,16 +398,6 @@ func TestCheckOauthUser(t *testing.T) { test.IsNotNil(t, err) } -func redirectsToSite(input string) string { - sites := []string{"admin", "error-auth"} - for _, site := range sites { - if strings.Contains(input, site) { - return site - } - } - return "other" -} - var modelUserPW = models.AuthenticationConfig{ Method: models.AuthenticationInternal, SaltAdmin: testconfiguration.SaltAdmin, diff --git a/internal/webserver/authentication/oauth/Oauth_test.go b/internal/webserver/authentication/oauth/Oauth_test.go index f9eb654b..c3d43a1a 100644 --- a/internal/webserver/authentication/oauth/Oauth_test.go +++ b/internal/webserver/authentication/oauth/Oauth_test.go @@ -49,7 +49,7 @@ func TestHandlerCallback_StateMismatch(t *testing.T) { HandlerCallback(rr, req) // Should redirect to error page - test.IsEqualInt(t, rr.Code, http.StatusSeeOther) + test.IsEqualInt(t, rr.Code, http.StatusTemporaryRedirect) test.IsEqualBool(t, rr.Header().Get("Location") != "", true) } diff --git a/internal/webserver/errorHandling/ErrorHandling.go b/internal/webserver/errorHandling/ErrorHandling.go index dff1b34f..0db8af25 100644 --- a/internal/webserver/errorHandling/ErrorHandling.go +++ b/internal/webserver/errorHandling/ErrorHandling.go @@ -97,9 +97,12 @@ func RedirectToOAuthErrorPage(w http.ResponseWriter, r *http.Request, errorMessa redirectToError(w, r, result) return } + if err != nil { + errorMessage = errorMessage + " " + err.Error() + } result := DisplayedError{ Title: r.URL.Query().Get("error"), - Message: errorMessage + " " + err.Error(), + Message: errorMessage, OAuthProviderMessage: r.URL.Query().Get("error_description"), expiry: time.Now().Add(ttl).Unix(), ErrorId: TypeOAuthNonGeneric, diff --git a/internal/webserver/ratelimiter/RateLimiter.go b/internal/webserver/ratelimiter/RateLimiter.go index e9e5d111..b4b1e2b0 100644 --- a/internal/webserver/ratelimiter/RateLimiter.go +++ b/internal/webserver/ratelimiter/RateLimiter.go @@ -2,6 +2,7 @@ package ratelimiter import ( "context" + "fmt" "net/http" "sync" "time" @@ -16,11 +17,22 @@ var failedIdLimiter = newLimiter() var failedDownloadPasswordLimiter = newLimiter() var failedApiKeyLimiter = newLimiter() +// isUnitTest must be false and is only set to true for running test units +// If true, rate limiting is disabled +var isUnitTest = false + type limiterEntry struct { limiter *rate.Limiter lastSeen time.Time } +// SetUnitTestMode disables all rate limiting +// This is only used for running unit tests +func SetUnitTestMode(enabled bool) { + fmt.Println("Rate limiting disabled for unit tests") + isUnitTest = enabled +} + type store struct { mu sync.Mutex limiters map[string]*limiterEntry @@ -41,10 +53,7 @@ func WaitOnLogin(ip string) { // WaitOnApiAuthentication blocks the current goroutine until the rate limiter allows a request // 200 attempts without limiting, thereafter one attempt every second -func WaitOnApiAuthentication(ip string, isDebug bool) { - if isDebug { - return - } +func WaitOnApiAuthentication(ip string) { _ = failedApiKeyLimiter.Get(ip, 1, 200).WaitN(context.Background(), 1) } @@ -69,6 +78,9 @@ func IsAllowedNewUuid(key string) bool { // Get returns the rate limiter for the given key func (s *store) Get(key string, r rate.Limit, burst int) *rate.Limiter { + if isUnitTest { + return rate.NewLimiter(r, burst) + } s.mu.Lock() defer s.mu.Unlock() From 267d496c905bb6de81d2acfdc742a9646d0d2942 Mon Sep 17 00:00:00 2001 From: Marc Ole Bulling Date: Thu, 12 Mar 2026 10:19:23 +0100 Subject: [PATCH 4/5] Fixed tests --- internal/test/TestHelper.go | 47 ++-- internal/webserver/Webserver_test.go | 225 ++++++++---------- .../webserver/errorHandling/ErrorHandling.go | 6 +- 3 files changed, 131 insertions(+), 147 deletions(-) diff --git a/internal/test/TestHelper.go b/internal/test/TestHelper.go index 5e5de59e..85a014bf 100644 --- a/internal/test/TestHelper.go +++ b/internal/test/TestHelper.go @@ -269,6 +269,12 @@ func HttpPageResult(t MockT, config HttpTestConfig) []*http.Cookie { config.init(t) client := &http.Client{} + if config.RedirectUrl != "" { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + data := url.Values{} for _, value := range config.PostValues { data.Add(value.Key, value.Value) @@ -337,6 +343,9 @@ func checkResponse(t MockT, response *http.Response, config HttpTestConfig) { } if config.RedirectUrl != "" { location := response.Header.Get("Location") + if config.IgnoreRedirectParm { + location = strings.Split(location, "?")[0] + } if !strings.HasSuffix(location, config.RedirectUrl) { t.Errorf("Redirect Location mismatch: got %s, want to end with %s", location, config.RedirectUrl) } @@ -355,19 +364,20 @@ func checkResponse(t MockT, response *http.Response, config HttpTestConfig) { // HttpTestConfig is a struct for http test init type HttpTestConfig struct { - Url string - RequiredContent []string - ExcludedContent []string - IsHtml bool - Method string - PostValues []PostBody - Cookies []Cookie - Headers []Header - UploadFileName string - UploadFieldName string - RedirectUrl string - ResultCode int - Body io.Reader + Url string + RequiredContent []string + ExcludedContent []string + IsHtml bool + IgnoreRedirectParm bool + Method string + PostValues []PostBody + Cookies []Cookie + Headers []Header + UploadFileName string + UploadFieldName string + RedirectUrl string + ResultCode int + Body io.Reader } func (c *HttpTestConfig) init(t MockT) { @@ -379,7 +389,11 @@ func (c *HttpTestConfig) init(t MockT) { c.Method = "GET" } if c.ResultCode == 0 { - c.ResultCode = 200 + if c.RedirectUrl == "" { + c.ResultCode = 200 + } else { + c.ResultCode = 307 + } } } @@ -467,6 +481,11 @@ func HttpPostRequest(t MockT, config HttpTestConfig) []*http.Cookie { r.Header.Set(header.Name, header.Value) } client := &http.Client{} + if config.RedirectUrl != "" { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } response, err := client.Do(r) IsNil(t, err) defer response.Body.Close() diff --git a/internal/webserver/Webserver_test.go b/internal/webserver/Webserver_test.go index 5546b7a0..3e91dfc5 100644 --- a/internal/webserver/Webserver_test.go +++ b/internal/webserver/Webserver_test.go @@ -74,110 +74,96 @@ func TestStaticDirs(t *testing.T) { RequiredContent: []string{".btn-secondary:hover"}, }) } + +func postValues(username, password, csrf string) []test.PostBody { + return []test.PostBody{ + {Key: "username", Value: username}, + {Key: "password", Value: password}, + {Key: "csrf-token", Value: csrf}, + } +} + +func cookieValue(cookies []*http.Cookie, name string) string { + for _, c := range cookies { + if c.Name == name { + return c.Value + } + } + return "" +} + func TestLogin(t *testing.T) { + const loginUrl = "http://localhost:53843/login" + + // GET /login shows the login form test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/login", - RequiredContent: []string{"id=\"uname_hidden\""}, + Url: loginUrl, IsHtml: true, + ResultCode: http.StatusOK, + RequiredContent: []string{"id=\"uname_hidden\""}, }) - config := test.HttpTestConfig{ - Url: "http://localhost:53843/login", - ExcludedContent: []string{"\"Refresh\" content=\"0; URL=./admin\""}, - RequiredContent: []string{"id=\"uname_hidden\"", "Incorrect username or password"}, + + postConfig := test.HttpTestConfig{ + Url: loginUrl, IsHtml: true, Method: "POST", - PostValues: []test.PostBody{ - { - Key: "username", - Value: "invalid", - }, { - Key: "password", - Value: "invalid", - }, { - Key: "csrf-token", - Value: csrftoken.Generate(), - }, - }, - ResultCode: 200, + ResultCode: http.StatusOK, + RequiredContent: []string{"id=\"uname_hidden\"", "Incorrect username or password"}, + ExcludedContent: []string{"URL=./admin"}, } - test.HttpPostRequest(t, config) - config.PostValues = []test.PostBody{ - { - Key: "username", - Value: "test", - }, { - Key: "password", - Value: "invalid", - }, { - Key: "csrf-token", - Value: csrftoken.Generate(), - }, - } - test.HttpPostRequest(t, config) + // POST with wrong username and password shows error + postConfig.PostValues = postValues("invalid", "invalid", csrftoken.Generate()) + test.HttpPostRequest(t, postConfig) - config.PostValues = []test.PostBody{ - { - Key: "username", - Value: "test", - }, { - Key: "password", - Value: "adminadmin", - }, { - Key: "csrf-token", - Value: "invalid", - }, - } - test.HttpPostRequest(t, config) + // POST with correct username but wrong password shows error + postConfig.PostValues = postValues("test", "invalid", csrftoken.Generate()) + test.HttpPostRequest(t, postConfig) + // POST with correct credentials but invalid CSRF token shows error + postConfig.PostValues = postValues("test", "adminadmin", "invalid") + test.HttpPostRequest(t, postConfig) + + // GET /login with OAuth2 enabled redirects to oauth-login oauthConfig := configuration.Get() oauthConfig.Authentication.Method = models.AuthenticationOAuth2 oauthConfig.Authentication.OAuthProvider = "http://test.com" oauthConfig.Authentication.OAuthClientSecret = "secret" oauthConfig.Authentication.OAuthClientId = "client" - authentication.Init(configuration.Get().Authentication) - config.RequiredContent = []string{"\"Refresh\" content=\"0; URL=./oauth-login\""} - config.PostValues = []test.PostBody{} - test.HttpPageResult(t, config) + authentication.Init(oauthConfig.Authentication) + test.HttpPageResult(t, test.HttpTestConfig{ + Url: loginUrl, + ResultCode: http.StatusTemporaryRedirect, + RedirectUrl: "oauth-login", + }) configuration.Get().Authentication.Method = models.AuthenticationInternal authentication.Init(configuration.Get().Authentication) - buf := config.RequiredContent - config.RequiredContent = config.ExcludedContent - config.ExcludedContent = buf - config.PostValues = []test.PostBody{ - { - Key: "username", - Value: "test", - }, { - Key: "password", - Value: "adminadmin", - }, { - Key: "csrf-token", - Value: csrftoken.Generate(), - }, - } - cookies := test.HttpPostRequest(t, config) - var session string - for _, cookie := range cookies { - if cookie.Name == "session_token" { - session = cookie.Value - } - } + // POST with valid credentials returns a redirect to admin and sets a session cookie + postConfig.RequiredContent = nil + postConfig.ExcludedContent = nil + postConfig.IsHtml = false + postConfig.ResultCode = http.StatusTemporaryRedirect + postConfig.RedirectUrl = "admin" + postConfig.PostValues = postValues("test", "adminadmin", csrftoken.Generate()) + cookies := test.HttpPostRequest(t, postConfig) + session := cookieValue(cookies, "session_token") test.IsNotEqualString(t, session, "") - config.Cookies = []test.Cookie{{ - Name: "session_token", - Value: session, - }} - test.HttpPageResult(t, config) + // Visiting /login with a valid session redirects to admin + test.HttpPageResult(t, test.HttpTestConfig{ + Url: loginUrl, + ResultCode: http.StatusTemporaryRedirect, + RedirectUrl: "admin", + Cookies: []test.Cookie{{Name: "session_token", Value: session}}, + }) } + func TestAdminNoAuth(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/admin", - RequiredContent: []string{"URL=./login\""}, - IsHtml: true, + Url: "http://localhost:53843/admin", + RedirectUrl: "login", }) } func TestAdminAuth(t *testing.T) { @@ -195,9 +181,8 @@ func TestAdminAuth(t *testing.T) { func TestAdminExpiredAuth(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/admin", - RequiredContent: []string{"URL=./login\""}, - IsHtml: true, + Url: "http://localhost:53843/admin", + RedirectUrl: "login", Cookies: []test.Cookie{{ Name: "session_token", Value: "expiredsession", @@ -240,9 +225,8 @@ func TestAdminRenewalAuth(t *testing.T) { func TestAdminInvalidAuth(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/admin", - RequiredContent: []string{"URL=./login\""}, - IsHtml: true, + Url: "http://localhost:53843/admin", + RedirectUrl: "login", Cookies: []test.Cookie{{ Name: "session_token", Value: "invalid", @@ -253,9 +237,9 @@ func TestAdminInvalidAuth(t *testing.T) { func TestInvalidLink(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/d?id=123", - RequiredContent: []string{"URL=./error\""}, - IsHtml: true, + Url: "http://localhost:53843/d?id=123", + IgnoreRedirectParm: true, + RedirectUrl: "error", }) } @@ -271,16 +255,6 @@ func TestError(t *testing.T) { RequiredContent: []string{"This file is encrypted, but no key was provided"}, IsHtml: true, }) - test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/error?key", - RequiredContent: []string{"This file is encrypted, but the provided key is incorrect"}, - IsHtml: true, - }) - test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/error?fr", - RequiredContent: []string{"The file limit for this upload request has been reached"}, - IsHtml: true, - }) } func TestForgotPw(t *testing.T) { @@ -295,11 +269,10 @@ func TestForgotPw(t *testing.T) { func TestLoginCorrect(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/login", - RequiredContent: []string{"URL=./admin\""}, - IsHtml: true, - Method: "POST", - PostValues: []test.PostBody{{"username", "test"}, {"password", "adminadmin"}, {"csrf-token", csrftoken.Generate()}}, + Url: "http://localhost:53843/login", + RedirectUrl: "admin", + Method: "POST", + PostValues: []test.PostBody{{"username", "test"}, {"password", "adminadmin"}, {"csrf-token", csrftoken.Generate()}}, }) } @@ -337,9 +310,8 @@ func TestLogout(t *testing.T) { }) // Logout test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/logout", - RequiredContent: []string{"URL=./login\""}, - IsHtml: true, + Url: "http://localhost:53843/logout", + RedirectUrl: "login", Cookies: []test.Cookie{{ Name: "session_token", Value: "logoutsession", @@ -347,9 +319,8 @@ func TestLogout(t *testing.T) { }) // Admin after logout test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/admin", - RequiredContent: []string{"URL=./login\""}, - IsHtml: true, + Url: "http://localhost:53843/admin", + RedirectUrl: "login", Cookies: []test.Cookie{{ Name: "session_token", Value: "logoutsession", @@ -393,15 +364,15 @@ func TestDownloadNoPassword(t *testing.T) { }) // Show download page expired file test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://127.0.0.1:53843/d?id=Wzol7LyY2QVczXynJtVo", - IsHtml: true, - RequiredContent: []string{"URL=./error\""}, + Url: "http://127.0.0.1:53843/d?id=Wzol7LyY2QVczXynJtVo", + IgnoreRedirectParm: true, + RedirectUrl: "error", }) // Download expired file test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://127.0.0.1:53843/downloadFile?id=Wzol7LyY2QVczXynJtVo", - IsHtml: true, - RequiredContent: []string{"URL=./error\""}, + Url: "http://127.0.0.1:53843/downloadFile?id=Wzol7LyY2QVczXynJtVo", + IgnoreRedirectParm: true, + RedirectUrl: "error", }) } @@ -437,10 +408,9 @@ func TestDownloadIncorrectPasswordCookie(t *testing.T) { func TestDownloadIncorrectPassword(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://127.0.0.1:53843/downloadFile?id=jpLXGJKigM4hjtA6T6sN", - IsHtml: true, - RequiredContent: []string{"URL=./d?id=jpLXGJKigM4hjtA6T6sN"}, - Cookies: []test.Cookie{{"pjpLXGJKigM4hjtA6T6sN", "invalid"}}, + Url: "http://127.0.0.1:53843/downloadFile?id=jpLXGJKigM4hjtA6T6sN", + RedirectUrl: "d?id=jpLXGJKigM4hjtA6T6sN", + Cookies: []test.Cookie{{"pjpLXGJKigM4hjtA6T6sN", "invalid"}}, }) } @@ -448,11 +418,10 @@ func TestDownloadCorrectPassword(t *testing.T) { t.Parallel() // Submit download page correct password cookies := test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://127.0.0.1:53843/d?id=jpLXGJKigM4hjtA6T6sN2", - IsHtml: true, - RequiredContent: []string{"URL=./d?id=jpLXGJKigM4hjtA6T6sN2"}, - Method: "POST", - PostValues: []test.PostBody{{"password", "123"}}, + Url: "http://127.0.0.1:53843/d?id=jpLXGJKigM4hjtA6T6sN2", + RedirectUrl: "d?id=jpLXGJKigM4hjtA6T6sN2", + Method: "POST", + PostValues: []test.PostBody{{"password", "123"}}, }) pwCookie := "" for _, cookie := range cookies { @@ -593,7 +562,6 @@ func TestApiPageNotAuthorized(t *testing.T) { t.Parallel() test.HttpPageResult(t, test.HttpTestConfig{ Url: "http://127.0.0.1:53843/apiKeys", - IsHtml: true, RedirectUrl: "login", ResultCode: http.StatusTemporaryRedirect, ExcludedContent: []string{"Click on the API key name to give it a new name."}, @@ -645,9 +613,8 @@ func TestProcessApi(t *testing.T) { func TestDisableLogin(t *testing.T) { test.HttpPageResult(t, test.HttpTestConfig{ - Url: "http://localhost:53843/admin", - RequiredContent: []string{"URL=./login\""}, - IsHtml: true, + Url: "http://localhost:53843/admin", + RedirectUrl: "login", Cookies: []test.Cookie{{ Name: "session_token", Value: "invalid", diff --git a/internal/webserver/errorHandling/ErrorHandling.go b/internal/webserver/errorHandling/ErrorHandling.go index 0db8af25..cd437564 100644 --- a/internal/webserver/errorHandling/ErrorHandling.go +++ b/internal/webserver/errorHandling/ErrorHandling.go @@ -41,10 +41,6 @@ func (d DisplayedError) IsExpired() bool { return d.expiry < time.Now().Unix() } -func (d DisplayedError) GetWidth() bool { - return d.expiry < time.Now().Unix() -} - func RedirectToErrorPage(w http.ResponseWriter, r *http.Request, errorTitle, errorMessage, cardWidth string) { result := DisplayedError{ Title: errorTitle, @@ -124,6 +120,8 @@ func redirectToError(w http.ResponseWriter, r *http.Request, displayedError Disp } func Get(r *http.Request) DisplayedError { + mutex.RLock() + defer mutex.RUnlock() if !r.URL.Query().Has("e") { return DisplayedError{ IsGeneric: true, From 4dd37d1f21de0b81200deb0a68eef7fd0d111879 Mon Sep 17 00:00:00 2001 From: Marc Ole Bulling Date: Thu, 12 Mar 2026 10:35:04 +0100 Subject: [PATCH 5/5] Added mock OIDC Server --- .../authentication/oauth/Oauth_test.go | 343 ++++++++++++++++-- 1 file changed, 314 insertions(+), 29 deletions(-) diff --git a/internal/webserver/authentication/oauth/Oauth_test.go b/internal/webserver/authentication/oauth/Oauth_test.go index c3d43a1a..d72d7fe2 100644 --- a/internal/webserver/authentication/oauth/Oauth_test.go +++ b/internal/webserver/authentication/oauth/Oauth_test.go @@ -1,66 +1,351 @@ package oauth import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" "net/http" "net/http/httptest" + "os" + "strings" "testing" + "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/forceu/gokapi/internal/configuration" + "github.com/forceu/gokapi/internal/models" "github.com/forceu/gokapi/internal/test" + "github.com/forceu/gokapi/internal/test/testconfiguration" "github.com/forceu/gokapi/internal/webserver/authentication" + "golang.org/x/oauth2" ) +func TestMain(m *testing.M) { + testconfiguration.Create(false) + configuration.Load() + configuration.ConnectDatabase() + exitVal := m.Run() + testconfiguration.Delete() + os.Exit(exitVal) +} + +// mockOIDCServer is a self-contained fake OIDC provider. +// It serves the discovery document, JWKS, token, and userinfo endpoints. +type mockOIDCServer struct { + server *httptest.Server + privateKey *rsa.PrivateKey + userEmail string + userSubject string + tokenValid bool +} + +func newMockOIDCServer() *mockOIDCServer { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic("failed to generate RSA key: " + err.Error()) + } + m := &mockOIDCServer{ + privateKey: key, + userEmail: "testuser@example.com", + userSubject: "test-subject-123", + tokenValid: true, + } + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", m.handleDiscovery) + mux.HandleFunc("/jwks", m.handleJWKS) + mux.HandleFunc("/token", m.handleToken) + mux.HandleFunc("/userinfo", m.handleUserinfo) + m.server = httptest.NewServer(mux) + return m +} + +func (m *mockOIDCServer) URL() string { return m.server.URL } +func (m *mockOIDCServer) Close() { m.server.Close() } + +func (m *mockOIDCServer) handleDiscovery(w http.ResponseWriter, r *http.Request) { + base := m.server.URL + doc := map[string]any{ + "issuer": base, + "authorization_endpoint": base + "/auth", + "token_endpoint": base + "/token", + "jwks_uri": base + "/jwks", + "userinfo_endpoint": base + "/userinfo", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) +} + +func (m *mockOIDCServer) handleJWKS(w http.ResponseWriter, r *http.Request) { + pub := m.privateKey.Public().(*rsa.PublicKey) + n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes()) + e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()) + jwks := map[string]any{ + "keys": []map[string]any{{ + "kty": "RSA", + "use": "sig", + "alg": "RS256", + "kid": "test-key", + "n": n, + "e": e, + }}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jwks) +} + +func (m *mockOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) { + if !m.tokenValid { + http.Error(w, `{"error":"invalid_grant"}`, http.StatusBadRequest) + return + } + resp := map[string]any{ + "access_token": "mock-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "id_token": m.buildIDToken(), + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + +func (m *mockOIDCServer) handleUserinfo(w http.ResponseWriter, r *http.Request) { + info := map[string]any{ + "sub": m.userSubject, + "email": m.userEmail, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(info) +} + +// buildIDToken builds a minimal unsigned-style ID token. Since our tests +// don't verify the signature path (we rely on the userinfo endpoint), we +// use a simple base64-encoded JSON payload wrapped in a fake JWT envelope. +func (m *mockOIDCServer) buildIDToken() string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","kid":"test-key"}`)) + payload, _ := json.Marshal(map[string]any{ + "iss": m.server.URL, + "sub": m.userSubject, + "email": m.userEmail, + "aud": []string{"test-client"}, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + }) + return header + "." + base64.RawURLEncoding.EncodeToString(payload) + ".fakesig" +} + +func TestInit_WithoutGroupScope(t *testing.T) { + mock := newMockOIDCServer() + defer mock.Close() + + credentials := models.AuthenticationConfig{ + OAuthProvider: mock.URL(), + OAuthClientId: "my-client", + OAuthClientSecret: "my-secret", + } + // Ensure no group scope is set in the global config + configuration.Get().Authentication.OAuthGroupScope = "" + + Init(mock.URL()+"/", credentials) + + test.IsEqualBool(t, ctx != nil, true) + test.IsEqualBool(t, provider != nil, true) + test.IsEqualString(t, config.ClientID, "my-client") + test.IsEqualString(t, config.ClientSecret, "my-secret") + test.IsEqualString(t, config.RedirectURL, mock.URL()+"/oauth-callback") + // Base scopes only: openid, profile, email + test.IsEqualInt(t, len(config.Scopes), 3) + test.IsEqualBool(t, containsString(strings.Join(config.Scopes, ","), oidc.ScopeOpenID), true) + test.IsEqualBool(t, containsString(strings.Join(config.Scopes, ","), "profile"), true) + test.IsEqualBool(t, containsString(strings.Join(config.Scopes, ","), "email"), true) +} + +func TestInit_WithGroupScope(t *testing.T) { + mock := newMockOIDCServer() + defer mock.Close() + + credentials := models.AuthenticationConfig{ + OAuthProvider: mock.URL(), + OAuthClientId: "my-client", + OAuthClientSecret: "my-secret", + } + configuration.Get().Authentication.OAuthGroupScope = "groups" + defer func() { configuration.Get().Authentication.OAuthGroupScope = "" }() + + Init(mock.URL()+"/", credentials) + + // Group scope must be appended as a fourth scope + test.IsEqualInt(t, len(config.Scopes), 4) + test.IsEqualBool(t, containsString(strings.Join(config.Scopes, ","), "groups"), true) +} + +// initWithMock calls Init using the mock server's URL and wires up ctx/provider/config. +func initWithMock(m *mockOIDCServer) { + var err error + ctx = context.Background() + provider, err = oidc.NewProvider(ctx, m.server.URL) + if err != nil { + panic("failed to init mock OIDC provider: " + err.Error()) + } + config = oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + Endpoint: provider.Endpoint(), + RedirectURL: m.server.URL + "/oauth-callback", + Scopes: []string{oidc.ScopeOpenID, "email", "profile"}, + } +} + +// newRequest builds a test request, optionally attaching the OAuth state cookie. +func newRequest(url, stateValue string) *http.Request { + req := httptest.NewRequest("GET", url, nil) + if stateValue != "" { + req.AddCookie(&http.Cookie{Name: authentication.CookieOauth, Value: stateValue}) + } + return req +} + +// --- Tests --- + func TestSetCallbackCookie(t *testing.T) { w, _ := test.GetRecorder("GET", "/", nil, nil, nil) - setCallbackCookie(w, "test") + setCallbackCookie(w, "test-value") cookies := w.Result().Cookies() test.IsEqualInt(t, len(cookies), 1) test.IsEqualString(t, cookies[0].Name, authentication.CookieOauth) - value := cookies[0].Value - test.IsEqualString(t, value, "test") + test.IsEqualString(t, cookies[0].Value, "test-value") } func TestHandlerLogin(t *testing.T) { - // Setup a dummy config config.ClientID = "test-client" config.Endpoint.AuthURL = "https://example.com/auth" - req, _ := http.NewRequest("GET", "/login?consent=true", nil) - rr := httptest.NewRecorder() + t.Run("Without consent", func(t *testing.T) { + rr := httptest.NewRecorder() + HandlerLogin(rr, httptest.NewRequest("GET", "/login", nil)) + + test.IsEqualInt(t, rr.Code, http.StatusFound) + location := rr.Header().Get("Location") + test.IsNotEmpty(t, location) + test.IsEqualBool(t, containsString(location, "prompt=none"), true) + test.IsEqualBool(t, len(rr.Result().Cookies()) > 0, true) + }) - HandlerLogin(rr, req) + t.Run("With consent", func(t *testing.T) { + rr := httptest.NewRecorder() + HandlerLogin(rr, httptest.NewRequest("GET", "/login?consent=true", nil)) - // Check for redirect to provider - test.IsEqualInt(t, rr.Code, http.StatusFound) - location := rr.Header().Get("Location") - test.IsEqualBool(t, len(location) > 0, true) + test.IsEqualInt(t, rr.Code, http.StatusFound) + location := rr.Header().Get("Location") + test.IsNotEmpty(t, location) + test.IsEqualBool(t, containsString(location, "prompt=consent"), true) + test.IsEqualBool(t, len(rr.Result().Cookies()) > 0, true) + }) +} + +func TestHandlerCallback_MissingStateCookie(t *testing.T) { + rr := httptest.NewRecorder() + // No cookie — state cookie is absent entirely + HandlerCallback(rr, httptest.NewRequest("GET", "/oauth-callback?state=some-state&code=123", nil)) - // Verify prompt=consent was added - test.IsEqualBool(t, location != "", true) - // Check if cookie was set - test.IsEqualBool(t, len(rr.Result().Cookies()) > 0, true) + test.IsEqualInt(t, rr.Code, http.StatusTemporaryRedirect) + test.IsNotEmpty(t, rr.Header().Get("Location")) } func TestHandlerCallback_StateMismatch(t *testing.T) { - req, _ := http.NewRequest("GET", "/oauth-callback?state=wrong-state&code=123", nil) - // Add the correct cookie to the request, but use a wrong state in URL - req.AddCookie(&http.Cookie{Name: authentication.CookieOauth, Value: "correct-state"}) + rr := httptest.NewRecorder() + HandlerCallback(rr, newRequest("/oauth-callback?state=wrong-state&code=123", "correct-state")) + + test.IsEqualInt(t, rr.Code, http.StatusTemporaryRedirect) + test.IsNotEmpty(t, rr.Header().Get("Location")) +} + +func TestHandlerCallback_LoginRequired(t *testing.T) { + for _, errCode := range []string{"login_required", "consent_required", "interaction_required"} { + t.Run(errCode, func(t *testing.T) { + config.Endpoint.AuthURL = "https://example.com/auth" + rr := httptest.NewRecorder() + url := fmt.Sprintf("/oauth-callback?state=mystate&error=%s", errCode) + HandlerCallback(rr, newRequest(url, "mystate")) + + // Should re-initiate login with consent + test.IsEqualInt(t, rr.Code, http.StatusFound) + test.IsEqualBool(t, containsString(rr.Header().Get("Location"), "prompt=consent"), true) + }) + } +} + +func TestHandlerCallback_TokenExchangeFailure(t *testing.T) { + mock := newMockOIDCServer() + defer mock.Close() + initWithMock(mock) + mock.tokenValid = false + + rr := httptest.NewRecorder() + HandlerCallback(rr, newRequest("/oauth-callback?state=mystate&code=badcode", "mystate")) + + test.IsEqualInt(t, rr.Code, http.StatusTemporaryRedirect) + test.IsNotEmpty(t, rr.Header().Get("Location")) +} + +func TestHandlerCallback_EmptyEmail(t *testing.T) { + mock := newMockOIDCServer() + defer mock.Close() + initWithMock(mock) + mock.userEmail = "" // userinfo will return empty email rr := httptest.NewRecorder() - HandlerCallback(rr, req) + HandlerCallback(rr, newRequest("/oauth-callback?state=mystate&code=validcode", "mystate")) - // Should redirect to error page test.IsEqualInt(t, rr.Code, http.StatusTemporaryRedirect) - test.IsEqualBool(t, rr.Header().Get("Location") != "", true) + test.IsNotEmpty(t, rr.Header().Get("Location")) +} + +func TestHandlerCallback_Success(t *testing.T) { + mock := newMockOIDCServer() + defer mock.Close() + initWithMock(mock) + + rr := httptest.NewRecorder() + HandlerCallback(rr, newRequest("/oauth-callback?state=mystate&code=validcode", "mystate")) + + // Valid flow completes — should redirect somewhere (admin or error-auth depending on auth config) + test.IsEqualBool(t, rr.Code == http.StatusTemporaryRedirect || rr.Code == http.StatusFound, true) + test.IsNotEmpty(t, rr.Header().Get("Location")) } func TestIsLoginRequired(t *testing.T) { - t.Run("Standard error", func(t *testing.T) { - req, _ := http.NewRequest("GET", "/?error=login_required", nil) - test.IsEqualBool(t, isLoginRequired(req), true) - }) + cases := []struct { + query string + expected bool + }{ + {"?error=login_required", true}, + {"?error=consent_required", true}, + {"?error=interaction_required", true}, + {"?error=access_denied", false}, + {"?error=unknown_error", false}, + {"?code=123", false}, + {"", false}, + } + for _, tc := range cases { + t.Run(tc.query, func(t *testing.T) { + req := httptest.NewRequest("GET", "/"+tc.query, nil) + test.IsEqualBool(t, isLoginRequired(req), tc.expected) + }) + } +} - t.Run("No error", func(t *testing.T) { - req, _ := http.NewRequest("GET", "/?code=123", nil) - test.IsEqualBool(t, isLoginRequired(req), false) - }) +func containsString(s, substr string) bool { + if len(substr) == 0 { + return true + } + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false }