diff --git a/internal/handler/users.go b/internal/handler/users.go index 340a592..13dba98 100644 --- a/internal/handler/users.go +++ b/internal/handler/users.go @@ -40,7 +40,7 @@ func (h *Handler) GetUser(w http.ResponseWriter, r *http.Request) { // UpdateUser updates a user's profile // @Summary Update user -// @Description Updates a user's profile. Users can only update their own profile. +// @Description Updates a user's profile. Only role changes require the caller to have the dev role. // @Tags users // @Accept json // @Produce json @@ -60,7 +60,7 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - authenticatedUser, ok := h.requireSelfOrDev(w, r, uid) + authenticatedUser, _, ok := h.requireAuthenticatedUserRecord(w, r) if !ok { return } @@ -71,13 +71,23 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } + targetUser, err := h.queries.GetUserByID(r.Context(), uid) + if err != nil { + h.handleDBError(w, err) + return + } + var role database.NullUserRole if req.Role != nil { - if !authenticatedUser.Role.Valid || authenticatedUser.Role.UserRole != database.UserRoleDev { + requestedRole := database.UserRole(*req.Role) + roleChanged := !targetUser.Role.Valid || targetUser.Role.UserRole != requestedRole + if roleChanged && (!authenticatedUser.Role.Valid || authenticatedUser.Role.UserRole != database.UserRoleDev) { h.respondError(w, http.StatusForbidden, "Only dev may update user roles") return } - role = database.NullUserRole{UserRole: database.UserRole(*req.Role), Valid: true} + if roleChanged { + role = database.NullUserRole{UserRole: requestedRole, Valid: true} + } } user, err := h.queries.UpdateUser(r.Context(), database.UpdateUserParams{ diff --git a/internal/handler/users_test.go b/internal/handler/users_test.go index 44fe5a3..62c8b19 100644 --- a/internal/handler/users_test.go +++ b/internal/handler/users_test.go @@ -1,6 +1,9 @@ package handler_test import ( + "bytes" + "context" + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -9,7 +12,9 @@ import ( "github.com/capyrpi/api/internal/config" "github.com/capyrpi/api/internal/database" "github.com/capyrpi/api/internal/database/mocks" + "github.com/capyrpi/api/internal/dto" "github.com/capyrpi/api/internal/handler" + "github.com/capyrpi/api/internal/middleware" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -83,3 +88,128 @@ func TestGetUser(t *testing.T) { }) } } + +func TestUpdateUser(t *testing.T) { + targetUID := uuid.New() + authenticatedUID := uuid.New() + firstName := "Updated" + currentRole := "student" + role := "faculty" + + tests := []struct { + name string + requestBody dto.UpdateUserRequest + mockSetup func(*mocks.Querier) + setupContext func() context.Context + expectedStatus int + }{ + { + name: "NonDevCanUpdateWhenSubmittedRoleMatchesCurrentRole", + requestBody: dto.UpdateUserRequest{ + FirstName: &firstName, + Role: ¤tRole, + }, + mockSetup: func(m *mocks.Querier) { + m.On("GetUserByID", mock.Anything, authenticatedUID).Return(database.User{ + Uid: authenticatedUID, + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + m.On("GetUserByID", mock.Anything, targetUID).Return(database.User{ + Uid: targetUID, + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + m.On("UpdateUser", mock.Anything, mock.MatchedBy(func(arg database.UpdateUserParams) bool { + return arg.Uid == targetUID && + arg.FirstName.Valid && arg.FirstName.String == firstName && + !arg.Role.Valid + })).Return(database.User{ + Uid: targetUID, + FirstName: firstName, + LastName: "Doe", + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + }, + setupContext: func() context.Context { + ctx := context.Background() + claims := &middleware.UserClaims{UserID: authenticatedUID.String()} + ctx = context.WithValue(ctx, middleware.UserClaimsKey, claims) + return context.WithValue(ctx, middleware.AuthTypeKey, "human") + }, + expectedStatus: http.StatusOK, + }, + { + name: "NonDevCannotUpdateRole", + requestBody: dto.UpdateUserRequest{ + Role: &role, + }, + mockSetup: func(m *mocks.Querier) { + m.On("GetUserByID", mock.Anything, authenticatedUID).Return(database.User{ + Uid: authenticatedUID, + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + m.On("GetUserByID", mock.Anything, targetUID).Return(database.User{ + Uid: targetUID, + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + }, + setupContext: func() context.Context { + ctx := context.Background() + claims := &middleware.UserClaims{UserID: authenticatedUID.String()} + ctx = context.WithValue(ctx, middleware.UserClaimsKey, claims) + return context.WithValue(ctx, middleware.AuthTypeKey, "human") + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "DevCanUpdateRole", + requestBody: dto.UpdateUserRequest{ + Role: &role, + }, + mockSetup: func(m *mocks.Querier) { + m.On("GetUserByID", mock.Anything, authenticatedUID).Return(database.User{ + Uid: authenticatedUID, + Role: database.NullUserRole{UserRole: database.UserRoleDev, Valid: true}, + }, nil) + m.On("GetUserByID", mock.Anything, targetUID).Return(database.User{ + Uid: targetUID, + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + m.On("UpdateUser", mock.Anything, mock.MatchedBy(func(arg database.UpdateUserParams) bool { + return arg.Uid == targetUID && + arg.Role.Valid && + arg.Role.UserRole == database.UserRoleFaculty + })).Return(database.User{ + Uid: targetUID, + Role: database.NullUserRole{UserRole: database.UserRoleFaculty, Valid: true}, + }, nil) + }, + setupContext: func() context.Context { + ctx := context.Background() + claims := &middleware.UserClaims{UserID: authenticatedUID.String()} + ctx = context.WithValue(ctx, middleware.UserClaimsKey, claims) + return context.WithValue(ctx, middleware.AuthTypeKey, "human") + }, + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + tt.mockSetup(mockQueries) + + h := handler.New(mockQueries, &config.Config{}) + r := chi.NewRouter() + r.Put("/users/{uid}", h.UpdateUser) + + body, _ := json.Marshal(tt.requestBody) + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/users/%s", targetUID), bytes.NewBuffer(body)) + req = req.WithContext(tt.setupContext()) + rr := httptest.NewRecorder() + + r.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +}