diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 9fe134ed6..c73a8e4b1 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -292,7 +292,7 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, answerAPIRouter := router.NewAnswerAPIRouter(langController, userController, commentController, reportController, voteController, tagController, followController, collectionController, questionController, answerController, searchController, revisionController, rankController, userAdminController, reasonController, themeController, siteInfoController, controllerSiteInfoController, notificationController, dashboardController, uploadController, activityController, roleController, pluginController, permissionController, userPluginController, reviewController, metaController, badgeController, controller_adminBadgeController, adminAPIKeyController, aiController, aiConversationController, aiConversationAdminController, mcpController) swaggerRouter := router.NewSwaggerRouter(swaggerConf) uiRouter := router.NewUIRouter(controllerSiteInfoController, siteInfoCommonService) - authUserMiddleware := middleware.NewAuthUserMiddleware(authService, siteInfoCommonService) + authUserMiddleware := middleware.NewAuthUserMiddleware(authService, siteInfoCommonService, userRepo, userRoleRelService) avatarMiddleware := middleware.NewAvatarMiddleware(serviceConf, uploaderService) shortIDMiddleware := middleware.NewShortIDMiddleware(siteInfoCommonService) templateRenderController := templaterender.NewTemplateRenderController(questionService, userService, tagService, answerService, commentService, siteInfoCommonService, questionRepo) diff --git a/internal/base/middleware/api_key_auth.go b/internal/base/middleware/api_key_auth.go index cc8182d40..fe003f1df 100644 --- a/internal/base/middleware/api_key_auth.go +++ b/internal/base/middleware/api_key_auth.go @@ -20,14 +20,46 @@ package middleware import ( + "strings" + "github.com/apache/answer/internal/base/handler" "github.com/apache/answer/internal/base/reason" + "github.com/apache/answer/internal/entity" "github.com/gin-gonic/gin" "github.com/segmentfault/pacman/errors" + "github.com/segmentfault/pacman/log" ) -// AuthAPIKey middleware to authenticate API key -func (am *AuthUserMiddleware) AuthAPIKey() gin.HandlerFunc { +// apiKeyAllowedPrefixes lists the URL path prefixes accessible via API key. +// Routes not matching any prefix require a session token. +var apiKeyAllowedPrefixes = []string{ + "/answer/api/v1/question", + "/answer/api/v1/answer", + "/answer/api/v1/comment", + "/answer/api/v1/tag", + "/answer/api/v1/search", + "/answer/api/v1/collection", + "/answer/api/v1/vote", + "/answer/api/v1/follow", + "/answer/api/v1/revisions", + "/answer/api/v1/chat/completions", + "/answer/api/v1/ai/conversation", + "/answer/api/v1/mcp", +} + +func isAPIKeyAllowed(path string) bool { + for _, prefix := range apiKeyAllowedPrefixes { + if strings.HasPrefix(path, prefix) { + return true + } + } + return false +} + +// AuthSessionOrAPIKey tries session-based auth first, then falls back to API key auth. +// In both cases it injects a UserCacheInfo into the Gin context so that downstream +// handlers can use GetLoginUserIDFromContext() as usual. +func (am *AuthUserMiddleware) AuthSessionOrAPIKey() gin.HandlerFunc { return func(ctx *gin.Context) { token := ExtractToken(ctx) if len(token) == 0 { @@ -35,17 +67,61 @@ func (am *AuthUserMiddleware) AuthAPIKey() gin.HandlerFunc { ctx.Abort() return } - pass, err := am.authService.AuthAPIKey(ctx, ctx.Request.Method == "GET", token) - if err != nil { + + // 1. Try session-based auth + userInfo, err := am.authService.GetUserCacheInfo(ctx, token) + if err == nil && userInfo != nil { + if !am.validateUserStatus(ctx, userInfo) { + return + } + ctx.Set(ctxUUIDKey, userInfo) + ctx.Next() + return + } + + // 2. Fallback to API key auth (only for whitelisted routes) + if !isAPIKeyAllowed(ctx.Request.URL.Path) { handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil) ctx.Abort() return } - if !pass { + + isRead := ctx.Request.Method == "GET" + apiKeyInfo, err := am.authService.GetAPIKeyInfo(ctx, isRead, token) + if err != nil || apiKeyInfo == nil { handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil) ctx.Abort() return } + + // Resolve user from the API key's UserID + userEntity, exist, err := am.userRepo.GetByUserID(ctx, apiKeyInfo.UserID) + if err != nil || !exist { + log.Errorf("API key %s references unknown user %s", apiKeyInfo.AccessKey, apiKeyInfo.UserID) + handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil) + ctx.Abort() + return + } + if userEntity.Status == entity.UserStatusDeleted || userEntity.Status == entity.UserStatusSuspended { + handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil) + ctx.Abort() + return + } + + roleID, err := am.userRoleService.GetUserRole(ctx, userEntity.ID) + if err != nil { + log.Errorf("failed to get role for user %s: %v", userEntity.ID, err) + handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil) + ctx.Abort() + return + } + + ctx.Set(ctxUUIDKey, &entity.UserCacheInfo{ + UserID: userEntity.ID, + UserStatus: userEntity.Status, + EmailStatus: userEntity.MailStatus, + RoleID: roleID, + }) ctx.Next() } } diff --git a/internal/base/middleware/api_key_auth_test.go b/internal/base/middleware/api_key_auth_test.go new file mode 100644 index 000000000..cfe554570 --- /dev/null +++ b/internal/base/middleware/api_key_auth_test.go @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/apache/answer/internal/entity" + "github.com/apache/answer/internal/service/auth" + "github.com/apache/answer/internal/service/role" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +// --- Mock repos for AuthService --- + +type mockAuthRepo struct { + userCache *entity.UserCacheInfo + err error +} + +func (m *mockAuthRepo) GetUserCacheInfo(_ context.Context, _ string) (*entity.UserCacheInfo, error) { + return m.userCache, m.err +} +func (m *mockAuthRepo) SetUserCacheInfo(_ context.Context, _, _ string, _ *entity.UserCacheInfo) error { + return nil +} +func (m *mockAuthRepo) GetUserVisitCacheInfo(_ context.Context, _ string) (string, error) { + return "", nil +} +func (m *mockAuthRepo) RemoveUserCacheInfo(_ context.Context, _ string) error { return nil } +func (m *mockAuthRepo) RemoveUserVisitCacheInfo(_ context.Context, _ string) error { return nil } +func (m *mockAuthRepo) SetUserStatus(_ context.Context, _ string, _ *entity.UserCacheInfo) error { + return nil +} +func (m *mockAuthRepo) GetUserStatus(_ context.Context, _ string) (*entity.UserCacheInfo, error) { + return nil, nil +} +func (m *mockAuthRepo) RemoveUserStatus(_ context.Context, _ string) error { return nil } +func (m *mockAuthRepo) GetAdminUserCacheInfo(_ context.Context, _ string) (*entity.UserCacheInfo, error) { + return nil, nil +} +func (m *mockAuthRepo) SetAdminUserCacheInfo(_ context.Context, _ string, _ *entity.UserCacheInfo) error { + return nil +} +func (m *mockAuthRepo) RemoveAdminUserCacheInfo(_ context.Context, _ string) error { return nil } +func (m *mockAuthRepo) AddUserTokenMapping(_ context.Context, _, _ string) error { return nil } +func (m *mockAuthRepo) RemoveUserTokens(_ context.Context, _ string, _ string) {} + +type mockAPIKeyRepo struct { + key *entity.APIKey + exist bool + err error +} + +func (m *mockAPIKeyRepo) GetAPIKeyList(_ context.Context) ([]*entity.APIKey, error) { return nil, nil } +func (m *mockAPIKeyRepo) GetAPIKey(_ context.Context, _ string) (*entity.APIKey, bool, error) { + return m.key, m.exist, m.err +} +func (m *mockAPIKeyRepo) UpdateAPIKey(_ context.Context, _ entity.APIKey) error { return nil } +func (m *mockAPIKeyRepo) AddAPIKey(_ context.Context, _ entity.APIKey) error { return nil } +func (m *mockAPIKeyRepo) DeleteAPIKey(_ context.Context, _ int) error { return nil } + +type mockUserRepo struct { + user *entity.User + exist bool + err error +} + +func (m *mockUserRepo) AddUser(_ context.Context, _ *entity.User) error { return nil } +func (m *mockUserRepo) IncreaseAnswerCount(_ context.Context, _ string, _ int) error { return nil } +func (m *mockUserRepo) IncreaseQuestionCount(_ context.Context, _ string, _ int) error { return nil } +func (m *mockUserRepo) UpdateQuestionCount(_ context.Context, _ string, _ int64) error { return nil } +func (m *mockUserRepo) UpdateAnswerCount(_ context.Context, _ string, _ int) error { return nil } +func (m *mockUserRepo) UpdateLastLoginDate(_ context.Context, _ string) error { return nil } +func (m *mockUserRepo) UpdateEmailStatus(_ context.Context, _ string, _ int) error { return nil } +func (m *mockUserRepo) UpdateNoticeStatus(_ context.Context, _ string, _ int) error { return nil } +func (m *mockUserRepo) UpdateEmail(_ context.Context, _, _ string) error { return nil } +func (m *mockUserRepo) UpdateUserInterface(_ context.Context, _, _, _ string) error { return nil } +func (m *mockUserRepo) UpdatePass(_ context.Context, _, _ string) error { return nil } +func (m *mockUserRepo) UpdateInfo(_ context.Context, _ *entity.User) error { return nil } +func (m *mockUserRepo) UpdateUserProfile(_ context.Context, _ *entity.User) error { return nil } +func (m *mockUserRepo) BatchGetByID(_ context.Context, _ []string) ([]*entity.User, error) { return nil, nil } +func (m *mockUserRepo) GetByUsername(_ context.Context, _ string) (*entity.User, bool, error) { + return nil, false, nil +} +func (m *mockUserRepo) GetByUsernames(_ context.Context, _ []string) ([]*entity.User, error) { + return nil, nil +} +func (m *mockUserRepo) GetByEmail(_ context.Context, _ string) (*entity.User, bool, error) { + return nil, false, nil +} +func (m *mockUserRepo) GetUserCount(_ context.Context) (int64, error) { return 0, nil } +func (m *mockUserRepo) SearchUserListByName(_ context.Context, _ string, _ int, _ bool) ([]*entity.User, error) { + return nil, nil +} +func (m *mockUserRepo) IsAvatarFileUsed(_ context.Context, _ string) (bool, error) { + return false, nil +} +func (m *mockUserRepo) GetByUserID(_ context.Context, _ string) (*entity.User, bool, error) { + return m.user, m.exist, m.err +} + +// mockUserRoleRelRepo implements role.UserRoleRelRepo for testing +type mockUserRoleRelRepo struct { + roleID int + exist bool +} + +func (m *mockUserRoleRelRepo) SaveUserRoleRel(_ context.Context, _ string, _ int) error { return nil } +func (m *mockUserRoleRelRepo) GetUserRoleRelList(_ context.Context, _ []string) ([]*entity.UserRoleRel, error) { + return nil, nil +} +func (m *mockUserRoleRelRepo) GetUserRoleRelListByRoleID(_ context.Context, _ []int) ([]*entity.UserRoleRel, error) { + return nil, nil +} +func (m *mockUserRoleRelRepo) GetUserRoleRel(_ context.Context, _ string) (*entity.UserRoleRel, bool, error) { + if !m.exist { + return nil, false, nil + } + return &entity.UserRoleRel{RoleID: m.roleID}, true, nil +} + +// --- Helper --- + +func newTestMiddleware( + authRepo *mockAuthRepo, + apiKeyRepo *mockAPIKeyRepo, + userRepo *mockUserRepo, + roleID int, +) *AuthUserMiddleware { + svc := auth.NewAuthService(authRepo, apiKeyRepo) + userRoleRelService := role.NewUserRoleRelService(&mockUserRoleRelRepo{roleID: roleID, exist: true}, nil) + return NewAuthUserMiddleware(svc, nil, userRepo, userRoleRelService) +} + +func performRequest(mw gin.HandlerFunc, method, path string) *httptest.ResponseRecorder { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + _, engine := gin.CreateTestContext(w) + engine.Use(mw) + engine.Handle(method, path, func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + req, _ := http.NewRequest(method, path, nil) + req.Header.Set("Authorization", "Bearer test-token") + engine.ServeHTTP(w, req) + return w +} + +func performRequestNoToken(mw gin.HandlerFunc) *httptest.ResponseRecorder { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + _, engine := gin.CreateTestContext(w) + engine.Use(mw) + engine.Handle("GET", "/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + req, _ := http.NewRequest("GET", "/test", nil) + engine.ServeHTTP(w, req) + return w +} + +// --- Tests --- + +func TestAuthSessionOrAPIKey_ValidSession(t *testing.T) { + m := newTestMiddleware( + &mockAuthRepo{userCache: &entity.UserCacheInfo{ + UserID: "100", + UserStatus: entity.UserStatusAvailable, + EmailStatus: entity.EmailStatusAvailable, + RoleID: 1, + }}, + &mockAPIKeyRepo{exist: false}, + &mockUserRepo{exist: false}, + 1, + ) + w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/test") + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthSessionOrAPIKey_InvalidSessionFallbackValidAPIKey(t *testing.T) { + m := newTestMiddleware( + &mockAuthRepo{userCache: nil}, // session fails + &mockAPIKeyRepo{ + key: &entity.APIKey{AccessKey: "sk_test", Scope: "read-write", UserID: "200"}, + exist: true, + }, + &mockUserRepo{ + user: &entity.User{ID: "200", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable}, + exist: true, + }, + 1, + ) + w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/answer/api/v1/question") + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthSessionOrAPIKey_BothFail(t *testing.T) { + m := newTestMiddleware( + &mockAuthRepo{userCache: nil}, + &mockAPIKeyRepo{exist: false}, + &mockUserRepo{exist: false}, + 1, + ) + w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/answer/api/v1/question") + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestAuthSessionOrAPIKey_NoToken(t *testing.T) { + m := newTestMiddleware( + &mockAuthRepo{userCache: nil}, + &mockAPIKeyRepo{exist: false}, + &mockUserRepo{exist: false}, + 1, + ) + w := performRequestNoToken(m.AuthSessionOrAPIKey()) + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestAuthSessionOrAPIKey_ReadOnlyKeyPostRequest(t *testing.T) { + m := newTestMiddleware( + &mockAuthRepo{userCache: nil}, // session fails + &mockAPIKeyRepo{ + key: &entity.APIKey{AccessKey: "sk_ro", Scope: "read-only", UserID: "300"}, + exist: true, + }, + &mockUserRepo{ + user: &entity.User{ID: "300", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable}, + exist: true, + }, + 1, + ) + w := performRequest(m.AuthSessionOrAPIKey(), "POST", "/answer/api/v1/question") + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestAuthSessionOrAPIKey_APIKeyBlockedOnNonWhitelistedRoute(t *testing.T) { + m := newTestMiddleware( + &mockAuthRepo{userCache: nil}, + &mockAPIKeyRepo{ + key: &entity.APIKey{AccessKey: "sk_test", Scope: "read-write", UserID: "400"}, + exist: true, + }, + &mockUserRepo{ + user: &entity.User{ID: "400", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable}, + exist: true, + }, + 1, + ) + w := performRequest(m.AuthSessionOrAPIKey(), "PUT", "/answer/api/v1/user/password") + assert.Equal(t, http.StatusUnauthorized, w.Code) +} diff --git a/internal/base/middleware/auth.go b/internal/base/middleware/auth.go index 57bbaae21..045d9d7a4 100644 --- a/internal/base/middleware/auth.go +++ b/internal/base/middleware/auth.go @@ -26,6 +26,7 @@ import ( "github.com/apache/answer/internal/schema" "github.com/apache/answer/internal/service/role" "github.com/apache/answer/internal/service/siteinfo_common" + usercommon "github.com/apache/answer/internal/service/user_common" "github.com/apache/answer/ui" "github.com/gin-gonic/gin" @@ -44,15 +45,22 @@ var ctxUUIDKey = "ctxUuidKey" type AuthUserMiddleware struct { authService *auth.AuthService siteInfoCommonService siteinfo_common.SiteInfoCommonService + userRepo usercommon.UserRepo + userRoleService *role.UserRoleRelService } // NewAuthUserMiddleware new auth user middleware func NewAuthUserMiddleware( authService *auth.AuthService, - siteInfoCommonService siteinfo_common.SiteInfoCommonService) *AuthUserMiddleware { + siteInfoCommonService siteinfo_common.SiteInfoCommonService, + userRepo usercommon.UserRepo, + userRoleService *role.UserRoleRelService, +) *AuthUserMiddleware { return &AuthUserMiddleware{ authService: authService, siteInfoCommonService: siteInfoCommonService, + userRepo: userRepo, + userRoleService: userRoleService, } } @@ -132,6 +140,29 @@ func (am *AuthUserMiddleware) MustAuthWithoutAccountAvailable() gin.HandlerFunc } } +// validateUserStatus checks email and user status, writes the error response and aborts if invalid. +// Returns true if the user is valid, false if the request was aborted. +func (am *AuthUserMiddleware) validateUserStatus(ctx *gin.Context, userInfo *entity.UserCacheInfo) bool { + if userInfo.EmailStatus != entity.EmailStatusAvailable { + handler.HandleResponse(ctx, errors.Forbidden(reason.EmailNeedToBeVerified), + &schema.ForbiddenResp{Type: schema.ForbiddenReasonTypeInactive}) + ctx.Abort() + return false + } + if userInfo.UserStatus == entity.UserStatusSuspended { + handler.HandleResponse(ctx, errors.Forbidden(reason.UserSuspended), + &schema.ForbiddenResp{Type: schema.ForbiddenReasonTypeUserSuspended}) + ctx.Abort() + return false + } + if userInfo.UserStatus == entity.UserStatusDeleted { + handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil) + ctx.Abort() + return false + } + return true +} + // MustAuthAndAccountAvailable auth user info and check user status, only allow active user access. func (am *AuthUserMiddleware) MustAuthAndAccountAvailable() gin.HandlerFunc { return func(ctx *gin.Context) { @@ -147,21 +178,7 @@ func (am *AuthUserMiddleware) MustAuthAndAccountAvailable() gin.HandlerFunc { ctx.Abort() return } - if userInfo.EmailStatus != entity.EmailStatusAvailable { - handler.HandleResponse(ctx, errors.Forbidden(reason.EmailNeedToBeVerified), - &schema.ForbiddenResp{Type: schema.ForbiddenReasonTypeInactive}) - ctx.Abort() - return - } - if userInfo.UserStatus == entity.UserStatusSuspended { - handler.HandleResponse(ctx, errors.Forbidden(reason.UserSuspended), - &schema.ForbiddenResp{Type: schema.ForbiddenReasonTypeUserSuspended}) - ctx.Abort() - return - } - if userInfo.UserStatus == entity.UserStatusDeleted { - handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil) - ctx.Abort() + if !am.validateUserStatus(ctx, userInfo) { return } ctx.Set(ctxUUIDKey, userInfo) diff --git a/internal/base/server/http.go b/internal/base/server/http.go index 765cbf6be..55528376c 100644 --- a/internal/base/server/http.go +++ b/internal/base/server/http.go @@ -93,7 +93,7 @@ func NewHTTPServer(debug bool, // register api that must be authenticated authV1 := r.Group(uiConf.APIBaseURL + "/answer/api/v1") - authV1.Use(authUserMiddleware.MustAuthAndAccountAvailable()) + authV1.Use(authUserMiddleware.AuthSessionOrAPIKey()) answerRouter.RegisterAnswerAPIRouter(authV1) adminauthV1 := r.Group(uiConf.APIBaseURL + "/answer/admin/api") @@ -116,7 +116,7 @@ func NewHTTPServer(debug bool, // mcp mcpAPIGroup := r.Group(uiConf.APIBaseURL + "/answer/api/v1") - mcpAPIGroup.Use(authUserMiddleware.AuthMcpEnable(), authUserMiddleware.AuthAPIKey()) + mcpAPIGroup.Use(authUserMiddleware.AuthMcpEnable(), authUserMiddleware.AuthSessionOrAPIKey()) answerRouter.RegisterMCPRouter(mcpAPIGroup) return r } diff --git a/internal/service/auth/api_key_auth_test.go b/internal/service/auth/api_key_auth_test.go new file mode 100644 index 000000000..fd0703758 --- /dev/null +++ b/internal/service/auth/api_key_auth_test.go @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package auth + +import ( + "context" + "testing" + + "github.com/apache/answer/internal/entity" + "github.com/stretchr/testify/assert" +) + +// --- Mocks --- + +type mockAPIKeyRepo struct { + key *entity.APIKey + exist bool + err error +} + +func (m *mockAPIKeyRepo) GetAPIKeyList(_ context.Context) ([]*entity.APIKey, error) { + return nil, nil +} +func (m *mockAPIKeyRepo) GetAPIKey(_ context.Context, _ string) (*entity.APIKey, bool, error) { + return m.key, m.exist, m.err +} +func (m *mockAPIKeyRepo) UpdateAPIKey(_ context.Context, _ entity.APIKey) error { return nil } +func (m *mockAPIKeyRepo) AddAPIKey(_ context.Context, _ entity.APIKey) error { return nil } +func (m *mockAPIKeyRepo) DeleteAPIKey(_ context.Context, _ int) error { return nil } + +// --- Tests --- + +func TestGetAPIKeyInfo_ValidReadWriteKey(t *testing.T) { + svc := &AuthService{ + apiKeyRepo: &mockAPIKeyRepo{ + key: &entity.APIKey{AccessKey: "sk_test", Scope: "read-write", UserID: "100"}, + exist: true, + }, + } + + info, err := svc.GetAPIKeyInfo(context.Background(), false, "sk_test") + assert.NoError(t, err) + assert.NotNil(t, info) + assert.Equal(t, "100", info.UserID) + assert.Equal(t, "read-write", info.Scope) +} + +func TestGetAPIKeyInfo_ReadOnlyKeyGetRequest(t *testing.T) { + svc := &AuthService{ + apiKeyRepo: &mockAPIKeyRepo{ + key: &entity.APIKey{AccessKey: "sk_ro", Scope: "read-only", UserID: "200"}, + exist: true, + }, + } + + info, err := svc.GetAPIKeyInfo(context.Background(), true, "sk_ro") + assert.NoError(t, err) + assert.NotNil(t, info) + assert.Equal(t, "200", info.UserID) +} + +func TestGetAPIKeyInfo_ReadOnlyKeyNonGetRequest(t *testing.T) { + svc := &AuthService{ + apiKeyRepo: &mockAPIKeyRepo{ + key: &entity.APIKey{AccessKey: "sk_ro", Scope: "read-only", UserID: "200"}, + exist: true, + }, + } + + info, err := svc.GetAPIKeyInfo(context.Background(), false, "sk_ro") + assert.NoError(t, err) + assert.Nil(t, info) +} + +func TestGetAPIKeyInfo_KeyNotFound(t *testing.T) { + svc := &AuthService{ + apiKeyRepo: &mockAPIKeyRepo{exist: false}, + } + + info, err := svc.GetAPIKeyInfo(context.Background(), true, "sk_invalid") + assert.NoError(t, err) + assert.Nil(t, info) +} diff --git a/internal/service/auth/auth.go b/internal/service/auth/auth.go index 8f539bf11..1941cd963 100644 --- a/internal/service/auth/auth.go +++ b/internal/service/auth/auth.go @@ -156,19 +156,20 @@ func (as *AuthService) SetAdminUserCacheInfo(ctx context.Context, accessToken st func (as *AuthService) RemoveAdminUserCacheInfo(ctx context.Context, accessToken string) (err error) { return as.authRepo.RemoveAdminUserCacheInfo(ctx, accessToken) } -func (as *AuthService) AuthAPIKey(ctx context.Context, read bool, apiKey string) (pass bool, err error) { - apiKeyInfo, exist, err := as.apiKeyRepo.GetAPIKey(ctx, apiKey) + +// GetAPIKeyInfo validates an API key and checks its scope. +// Returns the APIKey entity so callers can use the UserID. +func (as *AuthService) GetAPIKeyInfo(ctx context.Context, isRead bool, token string) (apiKeyInfo *entity.APIKey, err error) { + apiKeyInfo, exist, err := as.apiKeyRepo.GetAPIKey(ctx, token) if err != nil { - return false, err + return nil, err } if !exist { - return false, nil + return nil, nil } - // If the request is not read-only, check if the API key has write permissions - if !read && apiKeyInfo.Scope == "read-only" { + if !isRead && apiKeyInfo.Scope == "read-only" { log.Warnf("API key %s does not have write permissions", apiKeyInfo.AccessKey) - return false, nil + return nil, nil } - log.Infof("API key %s is valid, scope: %s", apiKeyInfo.AccessKey, apiKeyInfo.Scope) - return true, nil + return apiKeyInfo, nil }