diff --git a/.gitignore b/.gitignore index a1929ea..94c36bc 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ coverage.html # Benchmarks benchmarks/results/*.txt + +# Agents +.sisyphus/ diff --git a/cmd/server/main.go b/cmd/server/main.go index 532e69c..c2f361f 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -9,6 +9,7 @@ import ( "syscall" "time" + swaggerdocs "github.com/capyrpi/api/docs/swagger" "github.com/capyrpi/api/internal/config" "github.com/capyrpi/api/internal/database" "github.com/capyrpi/api/internal/handler" @@ -27,7 +28,7 @@ import ( // @license.name Apache 2.0 // @license.url http://www.apache.org/licenses/LICENSE-2.0.html -// @host api.capyrpi.org +// @host localhost:8080 // @BasePath /v1 // @securityDefinitions.apikey CookieAuth @@ -50,6 +51,18 @@ func main() { slog.Info("starting server", "env", cfg.Env) + swaggerHost := cfg.Swagger.Host + if swaggerHost == "" { + swaggerHost = "localhost:" + cfg.Server.Port + } + swaggerdocs.SwaggerInfo.Host = swaggerHost + + if cfg.Env == "production" { + swaggerdocs.SwaggerInfo.Schemes = []string{"https"} + } else { + swaggerdocs.SwaggerInfo.Schemes = []string{"http"} + } + // Connect to database ctx := context.Background() pool, err := database.NewPool(ctx, cfg.Database.URL) diff --git a/docker-compose.yml b/docker-compose.yml index 9f5f0a7..6594f6f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,7 +10,7 @@ services: volumes: - pgdata:/var/lib/postgresql/data healthcheck: - test: [ "CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}" ] + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"] interval: 5s timeout: 5s retries: 5 diff --git a/docs/schema/schema.json b/docs/schema/schema.json index 67eaee4..3830c84 100644 --- a/docs/schema/schema.json +++ b/docs/schema/schema.json @@ -1144,6 +1144,7 @@ "student", "alumni", "faculty", + "dev", "external" ], "comment": "" diff --git a/internal/auth/adapters/oauth.go b/internal/auth/adapters/oauth.go new file mode 100644 index 0000000..af181aa --- /dev/null +++ b/internal/auth/adapters/oauth.go @@ -0,0 +1,68 @@ +package adapters + +import ( + "context" + + "github.com/capyrpi/api/internal/auth/ports" + "github.com/capyrpi/api/internal/oauth" +) + +// GoogleOAuthAdapter wraps internal/oauth/google.go +type GoogleOAuthAdapter struct { + provider *oauth.GoogleProvider +} + +func NewGoogleOAuthAdapter(provider *oauth.GoogleProvider) ports.OAuthProvider { + return &GoogleOAuthAdapter{ + provider: provider, + } +} + +func (a *GoogleOAuthAdapter) GetAuthURL(state string) string { + return a.provider.GetAuthURL(state) +} + +func (a *GoogleOAuthAdapter) ExchangeCode(ctx context.Context, code string) (*ports.UserInfo, error) { + userInfo, err := a.provider.ExchangeCode(ctx, code) + if err != nil { + return nil, err + } + return &ports.UserInfo{ + Email: userInfo.Email, + FirstName: userInfo.GivenName, + LastName: userInfo.FamilyName, + }, nil +} + +// MicrosoftOAuthAdapter wraps internal/oauth/microsoft.go +type MicrosoftOAuthAdapter struct { + provider *oauth.MicrosoftProvider +} + +func NewMicrosoftOAuthAdapter(provider *oauth.MicrosoftProvider) ports.OAuthProvider { + return &MicrosoftOAuthAdapter{ + provider: provider, + } +} + +func (a *MicrosoftOAuthAdapter) GetAuthURL(state string) string { + return a.provider.GetAuthURL(state) +} + +func (a *MicrosoftOAuthAdapter) ExchangeCode(ctx context.Context, code string) (*ports.UserInfo, error) { + userInfo, err := a.provider.ExchangeCode(ctx, code) + if err != nil { + return nil, err + } + + email := userInfo.UserPrincipalName + if email == "" { + email = userInfo.Mail + } + + return &ports.UserInfo{ + Email: email, + FirstName: userInfo.GivenName, + LastName: userInfo.Surname, + }, nil +} diff --git a/internal/auth/adapters/persistence.go b/internal/auth/adapters/persistence.go new file mode 100644 index 0000000..af7fd6d --- /dev/null +++ b/internal/auth/adapters/persistence.go @@ -0,0 +1,36 @@ +package adapters + +import ( + "context" + + "github.com/capyrpi/api/internal/auth/ports" + "github.com/capyrpi/api/internal/database" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +type UserRepoAdapter struct { + queries database.Querier +} + +func NewUserRepoAdapter(queries database.Querier) ports.UserRepo { + return &UserRepoAdapter{ + queries: queries, + } +} + +func (r *UserRepoAdapter) GetUserByEmail(ctx context.Context, email pgtype.Text) (database.User, error) { + return r.queries.GetUserByEmail(ctx, email) +} + +func (r *UserRepoAdapter) CreateUser(ctx context.Context, arg database.CreateUserParams) (database.User, error) { + return r.queries.CreateUser(ctx, arg) +} + +func (r *UserRepoAdapter) GetUserByID(ctx context.Context, uid uuid.UUID) (database.User, error) { + return r.queries.GetUserByID(ctx, uid) +} + +func (r *UserRepoAdapter) CreateBotToken(ctx context.Context, arg database.CreateBotTokenParams) (database.BotToken, error) { + return r.queries.CreateBotToken(ctx, arg) +} diff --git a/internal/auth/adapters/token.go b/internal/auth/adapters/token.go new file mode 100644 index 0000000..60da8ef --- /dev/null +++ b/internal/auth/adapters/token.go @@ -0,0 +1,47 @@ +package adapters + +import ( + "time" + + "github.com/capyrpi/api/internal/auth/ports" + "github.com/capyrpi/api/internal/config" + "github.com/capyrpi/api/internal/database" + "github.com/capyrpi/api/internal/middleware" + "github.com/golang-jwt/jwt/v5" +) + +type JWTAdapter struct { + cfg *config.Config +} + +func NewJWTAdapter(cfg *config.Config) ports.TokenProvider { + return &JWTAdapter{ + cfg: cfg, + } +} + +func (a *JWTAdapter) GenerateToken(user database.User) (string, error) { + claims := &middleware.UserClaims{ + UserID: user.Uid.String(), + Email: getEmail(user), + Role: string(user.Role.UserRole), + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(a.cfg.JWT.ExpiryHours) * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: "capy-api", + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(a.cfg.JWT.Secret)) +} + +func getEmail(user database.User) string { + if user.SchoolEmail.Valid { + return user.SchoolEmail.String + } + if user.PersonalEmail.Valid { + return user.PersonalEmail.String + } + return "" +} diff --git a/internal/auth/ports/interfaces.go b/internal/auth/ports/interfaces.go new file mode 100644 index 0000000..15adf01 --- /dev/null +++ b/internal/auth/ports/interfaces.go @@ -0,0 +1,39 @@ +package ports + +import ( + "context" + + "github.com/capyrpi/api/internal/database" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +// UserInfo represents the normalized user information from OAuth providers +type UserInfo struct { + Email string + FirstName string + LastName string +} + +// UserRepo defines the interface for user persistence +type UserRepo interface { + GetUserByEmail(ctx context.Context, email pgtype.Text) (database.User, error) + CreateUser(ctx context.Context, arg database.CreateUserParams) (database.User, error) + GetUserByID(ctx context.Context, uid uuid.UUID) (database.User, error) +} + +// BotRepo defines the interface for bot token persistence +type BotRepo interface { + CreateBotToken(ctx context.Context, arg database.CreateBotTokenParams) (database.BotToken, error) +} + +// TokenProvider defines the interface for token generation +type TokenProvider interface { + GenerateToken(user database.User) (string, error) +} + +// OAuthProvider defines the interface for OAuth operations +type OAuthProvider interface { + GetAuthURL(state string) string + ExchangeCode(ctx context.Context, code string) (*UserInfo, error) +} diff --git a/internal/auth/service/service.go b/internal/auth/service/service.go new file mode 100644 index 0000000..e2ab5cd --- /dev/null +++ b/internal/auth/service/service.go @@ -0,0 +1,131 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "time" + + "github.com/capyrpi/api/internal/auth/ports" + "github.com/capyrpi/api/internal/database" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "golang.org/x/crypto/bcrypt" +) + +type AuthResult struct { + User database.User + Token string +} + +type BotTokenResult struct { + Token database.BotToken + RawToken string +} + +type AuthService struct { + userRepo ports.UserRepo + botRepo ports.BotRepo + tokenProvider ports.TokenProvider + googleAuth ports.OAuthProvider + microsoftAuth ports.OAuthProvider +} + +func NewAuthService( + userRepo ports.UserRepo, + botRepo ports.BotRepo, + tokenProvider ports.TokenProvider, + googleAuth ports.OAuthProvider, + microsoftAuth ports.OAuthProvider, +) *AuthService { + return &AuthService{ + userRepo: userRepo, + botRepo: botRepo, + tokenProvider: tokenProvider, + googleAuth: googleAuth, + microsoftAuth: microsoftAuth, + } +} + +func (s *AuthService) HandleOAuthCallback(ctx context.Context, providerName string, code string) (*AuthResult, error) { + var provider ports.OAuthProvider + switch providerName { + case "google": + provider = s.googleAuth + case "microsoft": + provider = s.microsoftAuth + default: + return nil, errors.New("invalid provider") + } + + userInfo, err := provider.ExchangeCode(ctx, code) + if err != nil { + return nil, err + } + + pgEmail := pgtype.Text{String: userInfo.Email, Valid: true} + user, err := s.userRepo.GetUserByEmail(ctx, pgEmail) + if err != nil { + if err != pgx.ErrNoRows { + return nil, err + } + + // Create user if not exists + user, err = s.userRepo.CreateUser(ctx, database.CreateUserParams{ + FirstName: userInfo.FirstName, + LastName: userInfo.LastName, + PersonalEmail: pgEmail, + SchoolEmail: pgtype.Text{Valid: false}, + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }) + if err != nil { + return nil, err + } + } + + token, err := s.tokenProvider.GenerateToken(user) + if err != nil { + return nil, err + } + + return &AuthResult{User: user, Token: token}, nil +} + +func (s *AuthService) GenerateBotToken(ctx context.Context, name string, createdBy uuid.UUID, expiresAt *time.Time) (*BotTokenResult, error) { + rawToken, err := generateSecureToken(32) + if err != nil { + return nil, err + } + + hashedToken, err := bcrypt.GenerateFromPassword([]byte(rawToken), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + + pgExpiresAt := pgtype.Timestamp{Valid: false} + if expiresAt != nil { + pgExpiresAt = pgtype.Timestamp{Time: *expiresAt, Valid: true} + } + + token, err := s.botRepo.CreateBotToken(ctx, database.CreateBotTokenParams{ + TokenHash: string(hashedToken), + Name: name, + CreatedBy: createdBy, + ExpiresAt: pgExpiresAt, + }) + if err != nil { + return nil, err + } + + return &BotTokenResult{Token: token, RawToken: rawToken}, nil +} + +func generateSecureToken(length int) (string, error) { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} diff --git a/internal/auth/service/service_test.go b/internal/auth/service/service_test.go new file mode 100644 index 0000000..4c3ae94 --- /dev/null +++ b/internal/auth/service/service_test.go @@ -0,0 +1,180 @@ +package service_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/capyrpi/api/internal/auth/ports" + "github.com/capyrpi/api/internal/auth/service" + "github.com/capyrpi/api/internal/database" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// Mocks +type MockUserRepo struct { + mock.Mock +} + +func (m *MockUserRepo) GetUserByEmail(ctx context.Context, email pgtype.Text) (database.User, error) { + args := m.Called(ctx, email) + return args.Get(0).(database.User), args.Error(1) +} + +func (m *MockUserRepo) CreateUser(ctx context.Context, arg database.CreateUserParams) (database.User, error) { + args := m.Called(ctx, arg) + return args.Get(0).(database.User), args.Error(1) +} + +func (m *MockUserRepo) GetUserByID(ctx context.Context, uid uuid.UUID) (database.User, error) { + args := m.Called(ctx, uid) + return args.Get(0).(database.User), args.Error(1) +} + +type MockBotRepo struct { + mock.Mock +} + +func (m *MockBotRepo) CreateBotToken(ctx context.Context, arg database.CreateBotTokenParams) (database.BotToken, error) { + args := m.Called(ctx, arg) + return args.Get(0).(database.BotToken), args.Error(1) +} + +type MockTokenProvider struct { + mock.Mock +} + +func (m *MockTokenProvider) GenerateToken(user database.User) (string, error) { + args := m.Called(user) + return args.String(0), args.Error(1) +} + +type MockOAuthProvider struct { + mock.Mock +} + +func (m *MockOAuthProvider) GetAuthURL(state string) string { + args := m.Called(state) + return args.String(0) +} + +func (m *MockOAuthProvider) ExchangeCode(ctx context.Context, code string) (*ports.UserInfo, error) { + args := m.Called(ctx, code) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*ports.UserInfo), args.Error(1) +} + +func TestAuthService_HandleOAuthCallback(t *testing.T) { + // Setup + mockUserRepo := new(MockUserRepo) + mockBotRepo := new(MockBotRepo) + mockTokenProvider := new(MockTokenProvider) + mockGoogle := new(MockOAuthProvider) + mockMicrosoft := new(MockOAuthProvider) + + svc := service.NewAuthService(mockUserRepo, mockBotRepo, mockTokenProvider, mockGoogle, mockMicrosoft) + ctx := context.Background() + + t.Run("Success existing user", func(t *testing.T) { + email := "test@example.com" + userInfo := &ports.UserInfo{Email: email, FirstName: "John", LastName: "Doe"} + user := database.User{ + Uid: uuid.New(), + FirstName: "John", + LastName: "Doe", + PersonalEmail: pgtype.Text{String: email, Valid: true}, + } + token := "jwt-token" + + mockGoogle.On("ExchangeCode", ctx, "auth-code").Return(userInfo, nil) + mockUserRepo.On("GetUserByEmail", ctx, pgtype.Text{String: email, Valid: true}).Return(user, nil) + mockTokenProvider.On("GenerateToken", user).Return(token, nil) + + res, err := svc.HandleOAuthCallback(ctx, "google", "auth-code") + + assert.NoError(t, err) + assert.Equal(t, user, res.User) + assert.Equal(t, token, res.Token) + mockGoogle.AssertExpectations(t) + mockUserRepo.AssertExpectations(t) + mockTokenProvider.AssertExpectations(t) + }) + + t.Run("Success new user", func(t *testing.T) { + email := "new@example.com" + userInfo := &ports.UserInfo{Email: email, FirstName: "Jane", LastName: "Doe"} + user := database.User{ + Uid: uuid.New(), + FirstName: "Jane", + LastName: "Doe", + PersonalEmail: pgtype.Text{String: email, Valid: true}, + } + token := "jwt-token" + + mockGoogle.On("ExchangeCode", ctx, "new-code").Return(userInfo, nil) + mockUserRepo.On("GetUserByEmail", ctx, pgtype.Text{String: email, Valid: true}).Return(database.User{}, pgx.ErrNoRows) + mockUserRepo.On("CreateUser", ctx, mock.AnythingOfType("database.CreateUserParams")).Return(user, nil) + mockTokenProvider.On("GenerateToken", user).Return(token, nil) + + res, err := svc.HandleOAuthCallback(ctx, "google", "new-code") + + assert.NoError(t, err) + assert.Equal(t, user, res.User) + assert.Equal(t, token, res.Token) + }) + + t.Run("Invalid provider", func(t *testing.T) { + res, err := svc.HandleOAuthCallback(ctx, "invalid", "code") + assert.Error(t, err) + assert.Nil(t, res) + }) + + t.Run("Exchange code error", func(t *testing.T) { + mockGoogle.On("ExchangeCode", ctx, "bad-code").Return(nil, errors.New("exchange error")) + res, err := svc.HandleOAuthCallback(ctx, "google", "bad-code") + assert.Error(t, err) + assert.Nil(t, res) + }) +} + +func TestAuthService_GenerateBotToken(t *testing.T) { + // Setup + mockUserRepo := new(MockUserRepo) + mockBotRepo := new(MockBotRepo) + mockTokenProvider := new(MockTokenProvider) + mockGoogle := new(MockOAuthProvider) + mockMicrosoft := new(MockOAuthProvider) + + svc := service.NewAuthService(mockUserRepo, mockBotRepo, mockTokenProvider, mockGoogle, mockMicrosoft) + ctx := context.Background() + uid := uuid.New() + + t.Run("Success", func(t *testing.T) { + name := "My Bot" + botToken := database.BotToken{ + TokenID: uuid.New(), + Name: name, + CreatedBy: uid, + CreatedAt: pgtype.Timestamp{Time: time.Now(), Valid: true}, + } + + // Use mock.MatchedBy to check arguments + mockBotRepo.On("CreateBotToken", ctx, mock.MatchedBy(func(arg database.CreateBotTokenParams) bool { + return arg.Name == name && arg.CreatedBy == uid && len(arg.TokenHash) > 0 + })).Return(botToken, nil) + + res, err := svc.GenerateBotToken(ctx, name, uid, nil) + + assert.NoError(t, err) + assert.Equal(t, botToken, res.Token) + assert.NotEmpty(t, res.RawToken) + mockBotRepo.AssertExpectations(t) + }) +} diff --git a/internal/config/config.go b/internal/config/config.go index 0cd7e6d..4823966 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,9 +10,14 @@ type Config struct { JWT JWTConfig Cookie CookieConfig OAuth OAuthConfig + Swagger SwaggerConfig Env string `env:"ENV" env-default:"development"` } +type SwaggerConfig struct { + Host string `env:"SWAGGER_HOST"` +} + type ServerConfig struct { Host string `env:"SERVER_HOST" env-default:"0.0.0.0"` Port string `env:"SERVER_PORT" env-default:"8080"` diff --git a/internal/database/models.go b/internal/database/models.go index a54ed5a..2960bb2 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -18,6 +18,7 @@ const ( UserRoleStudent UserRole = "student" UserRoleAlumni UserRole = "alumni" UserRoleFaculty UserRole = "faculty" + UserRoleDev UserRole = "dev" UserRoleExternal UserRole = "external" ) diff --git a/internal/dto/dto.go b/internal/dto/dto.go index cabfa00..2bbf6b7 100644 --- a/internal/dto/dto.go +++ b/internal/dto/dto.go @@ -17,7 +17,7 @@ type CreateUserRequest struct { SchoolEmail string `json:"school_email,omitempty" validate:"omitempty,email"` Phone string `json:"phone,omitempty"` GradYear int `json:"grad_year,omitempty" validate:"omitempty,gte=2000,lte=2100"` - Role string `json:"role,omitempty" validate:"omitempty,oneof=student alumni faculty external"` + Role string `json:"role,omitempty" validate:"omitempty,oneof=student alumni faculty dev external"` } type UpdateUserRequest struct { @@ -27,7 +27,7 @@ type UpdateUserRequest struct { SchoolEmail *string `json:"school_email,omitempty" validate:"omitempty,email"` Phone *string `json:"phone,omitempty"` GradYear *int `json:"grad_year,omitempty" validate:"omitempty,gte=2000,lte=2100"` - Role *string `json:"role,omitempty" validate:"omitempty,oneof=student alumni faculty external"` + Role *string `json:"role,omitempty" validate:"omitempty,oneof=student alumni faculty dev external"` } type UserResponse struct { diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 5a43190..815a46b 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -1,9 +1,6 @@ package handler import ( - "context" - "crypto/rand" - "encoding/hex" "encoding/json" "net/http" "time" @@ -14,9 +11,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "golang.org/x/crypto/bcrypt" ) // ============================================================================ @@ -96,25 +90,13 @@ func (h *Handler) GoogleCallback(w http.ResponseWriter, r *http.Request) { return } - userInfo, err := h.googleAuth.ExchangeCode(r.Context(), code) + authResult, err := h.authService.HandleOAuthCallback(r.Context(), "google", code) if err != nil { - h.respondError(w, http.StatusInternalServerError, "Failed to exchange code") + h.respondError(w, http.StatusInternalServerError, "Authentication failed") return } - user, err := h.upsertUser(r.Context(), userInfo.Email, userInfo.GivenName, userInfo.FamilyName) - if err != nil { - h.handleDBError(w, err) - return - } - - token, err := h.generateJWT(user) - if err != nil { - h.respondError(w, http.StatusInternalServerError, "Failed to generate session") - return - } - - h.setAuthCookie(w, token) + h.setAuthCookie(w, authResult.Token) h.respondWithCloseWindow(w) } @@ -157,31 +139,13 @@ func (h *Handler) MicrosoftCallback(w http.ResponseWriter, r *http.Request) { return } - userInfo, err := h.microsoftAuth.ExchangeCode(r.Context(), code) - if err != nil { - h.respondError(w, http.StatusInternalServerError, "Failed to exchange code") - return - } - - // Use PrincipalName (email) or Mail - email := userInfo.UserPrincipalName - if email == "" { - email = userInfo.Mail - } - - user, err := h.upsertUser(r.Context(), email, userInfo.GivenName, userInfo.Surname) - if err != nil { - h.handleDBError(w, err) - return - } - - token, err := h.generateJWT(user) + authResult, err := h.authService.HandleOAuthCallback(r.Context(), "microsoft", code) if err != nil { - h.respondError(w, http.StatusInternalServerError, "Failed to generate session") + h.respondError(w, http.StatusInternalServerError, "Authentication failed") return } - h.setAuthCookie(w, token) + h.setAuthCookie(w, authResult.Token) h.respondWithCloseWindow(w) } @@ -304,7 +268,7 @@ func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) { // ListBotTokens lists all bot tokens // @Summary List bot tokens -// @Description Returns all bot tokens (requires faculty role) +// @Description Returns all bot tokens (requires dev role) // @Tags bot // @Accept json // @Produce json @@ -313,7 +277,10 @@ func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) { // @Security CookieAuth // @Router /bot/tokens [get] func (h *Handler) ListBotTokens(w http.ResponseWriter, r *http.Request) { - // TODO: Check faculty role + if _, ok := h.requireDevRole(w, r); !ok { + return + } + tokens, err := h.queries.ListBotTokens(r.Context()) if err != nil { h.handleDBError(w, err) @@ -336,7 +303,7 @@ func (h *Handler) ListBotTokens(w http.ResponseWriter, r *http.Request) { // CreateBotToken creates a new bot token // @Summary Create bot token -// @Description Creates a new bot token (requires faculty role) +// @Description Creates a new bot token (requires dev role) // @Tags bot // @Accept json // @Produce json @@ -347,14 +314,11 @@ func (h *Handler) ListBotTokens(w http.ResponseWriter, r *http.Request) { // @Security CookieAuth // @Router /bot/tokens [post] func (h *Handler) CreateBotToken(w http.ResponseWriter, r *http.Request) { - claims, ok := middleware.GetUserClaims(r.Context()) + uid, ok := h.requireDevRole(w, r) if !ok { - h.respondError(w, http.StatusUnauthorized, "Not authenticated") return } - // TODO: Check faculty role - var req CreateBotTokenRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { h.respondError(w, http.StatusBadRequest, "Invalid request body") @@ -366,46 +330,25 @@ func (h *Handler) CreateBotToken(w http.ResponseWriter, r *http.Request) { return } - // Generate random token - rawToken, err := generateSecureToken(32) - if err != nil { - h.respondError(w, http.StatusInternalServerError, "Failed to generate token") - return - } - - // Hash the token for storage - hashedToken, err := bcrypt.GenerateFromPassword([]byte(rawToken), bcrypt.DefaultCost) - if err != nil { - h.respondError(w, http.StatusInternalServerError, "Failed to hash token") - return - } - - uid, _ := uuid.Parse(claims.UserID) - - token, err := h.queries.CreateBotToken(r.Context(), database.CreateBotTokenParams{ - TokenHash: string(hashedToken), - Name: req.Name, - CreatedBy: uid, - ExpiresAt: toPgTimestamp(req.ExpiresAt), - }) + result, err := h.authService.GenerateBotToken(r.Context(), req.Name, uid, req.ExpiresAt) if err != nil { h.handleDBError(w, err) return } h.respondJSON(w, http.StatusCreated, BotTokenResponse{ - TokenID: token.TokenID, - Name: token.Name, - Token: rawToken, // Only returned on creation! - CreatedAt: token.CreatedAt.Time, - ExpiresAt: fromPgTimestamp(token.ExpiresAt), - IsActive: token.IsActive.Bool, + TokenID: result.Token.TokenID, + Name: result.Token.Name, + Token: result.RawToken, // Only returned on creation! + CreatedAt: result.Token.CreatedAt.Time, + ExpiresAt: fromPgTimestamp(result.Token.ExpiresAt), + IsActive: result.Token.IsActive.Bool, }) } // RevokeBotToken revokes a bot token // @Summary Revoke bot token -// @Description Revokes a bot token (requires faculty role) +// @Description Revokes a bot token (requires dev role) // @Tags bot // @Accept json // @Produce json @@ -423,7 +366,9 @@ func (h *Handler) RevokeBotToken(w http.ResponseWriter, r *http.Request) { return } - // TODO: Check faculty role + if _, ok := h.requireDevRole(w, r); !ok { + return + } if err := h.queries.RevokeBotToken(r.Context(), tokenID); err != nil { h.handleDBError(w, err) @@ -568,29 +513,6 @@ func (h *Handler) verifyStateCookie(w http.ResponseWriter, r *http.Request, stat return cookie.Value == state } -func (h *Handler) upsertUser(ctx context.Context, email, firstName, lastName string) (database.User, error) { - pgEmail := toPgTextFromString(email) - - // Check if user exists - user, err := h.queries.GetUserByEmail(ctx, pgEmail) - if err == nil { - return user, nil - } - - if err != pgx.ErrNoRows { - return database.User{}, err - } - - // Create new user - return h.queries.CreateUser(ctx, database.CreateUserParams{ - FirstName: firstName, - LastName: lastName, - PersonalEmail: pgEmail, // Default to personal email for oauth - SchoolEmail: pgtype.Text{Valid: false}, - Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, // Default role - }) -} - func getEmail(user database.User) string { if user.SchoolEmail.Valid { return user.SchoolEmail.String @@ -601,10 +523,29 @@ func getEmail(user database.User) string { return "" } -func generateSecureToken(length int) (string, error) { - bytes := make([]byte, length) - if _, err := rand.Read(bytes); err != nil { - return "", err +func (h *Handler) requireDevRole(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { + claims, ok := middleware.GetUserClaims(r.Context()) + if !ok { + h.respondError(w, http.StatusUnauthorized, "Not authenticated") + return uuid.Nil, false + } + + uid, err := uuid.Parse(claims.UserID) + if err != nil { + h.respondError(w, http.StatusUnauthorized, "Invalid user ID in token") + return uuid.Nil, false } - return hex.EncodeToString(bytes), nil + + user, err := h.queries.GetUserByID(r.Context(), uid) + if err != nil { + h.handleDBError(w, err) + return uuid.Nil, false + } + + if !user.Role.Valid || user.Role.UserRole != database.UserRoleDev { + h.respondError(w, http.StatusForbidden, "Requires dev role") + return uuid.Nil, false + } + + return uid, true } diff --git a/internal/handler/auth_characterization_test.go b/internal/handler/auth_characterization_test.go new file mode 100644 index 0000000..442d461 --- /dev/null +++ b/internal/handler/auth_characterization_test.go @@ -0,0 +1,213 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/capyrpi/api/internal/config" + "github.com/capyrpi/api/internal/database" + "github.com/capyrpi/api/internal/database/mocks" + "github.com/capyrpi/api/internal/middleware" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// TestRespondWithCloseWindow_Contract verifies the exact HTML response contract +// used by the frontend popup. +func TestRespondWithCloseWindow_Contract(t *testing.T) { + h := &Handler{} + rr := httptest.NewRecorder() + h.respondWithCloseWindow(rr) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) + + body := rr.Body.String() + assert.Contains(t, body, "") + assert.Contains(t, body, "window.close()") + assert.Contains(t, body, "Login Successful") +} + +// TestSetAuthCookie_Contract verifies cookie attributes are preserved. +func TestSetAuthCookie_Contract(t *testing.T) { + cfg := &config.Config{ + Cookie: config.CookieConfig{ + Domain: "example.com", + Secure: true, + }, + JWT: config.JWTConfig{ + ExpiryHours: 24, + }, + } + h := &Handler{Config: cfg} + rr := httptest.NewRecorder() + h.setAuthCookie(rr, "test-token") + + cookies := rr.Result().Cookies() + requireCookie(t, cookies, "capy_auth", func(c *http.Cookie) { + assert.Equal(t, "test-token", c.Value) + assert.Equal(t, "/", c.Path) + assert.Equal(t, "example.com", c.Domain) + assert.True(t, c.HttpOnly) + assert.True(t, c.Secure) + assert.Equal(t, http.SameSiteLaxMode, c.SameSite) + assert.Equal(t, 24*3600, c.MaxAge) + }) +} + +// TestLogout_Contract verifies logout behavior (cookie clearing). +func TestLogout_Contract(t *testing.T) { + cfg := &config.Config{ + Cookie: config.CookieConfig{ + Domain: "example.com", + Secure: true, + }, + } + h := &Handler{Config: cfg} + + req := httptest.NewRequest("POST", "/auth/logout", nil) + rr := httptest.NewRecorder() + + h.Logout(rr, req) + + assert.Equal(t, http.StatusNoContent, rr.Code) + cookies := rr.Result().Cookies() + requireCookie(t, cookies, "capy_auth", func(c *http.Cookie) { + assert.Equal(t, "", c.Value) + assert.Equal(t, -1, c.MaxAge) // Cleared + assert.Equal(t, "/", c.Path) + assert.Equal(t, "example.com", c.Domain) + }) +} + +// TestSetStateCookie_Contract verifies oauth state cookie attributes. +func TestSetStateCookie_Contract(t *testing.T) { + cfg := &config.Config{ + Cookie: config.CookieConfig{ + Domain: "example.com", + Secure: true, + }, + } + h := &Handler{Config: cfg} + rr := httptest.NewRecorder() + h.setStateCookie(rr, "random-state") + + cookies := rr.Result().Cookies() + requireCookie(t, cookies, "oauth_state", func(c *http.Cookie) { + assert.Equal(t, "random-state", c.Value) + assert.Equal(t, "/v1/auth", c.Path) // Specific path for state + assert.Equal(t, "example.com", c.Domain) + assert.True(t, c.HttpOnly) + assert.True(t, c.Secure) + assert.Equal(t, http.SameSiteLaxMode, c.SameSite) + assert.Equal(t, 300, c.MaxAge) // 5 minutes + }) +} + +// TestVerifyStateCookie_Contract verifies validation and clearing. +func TestVerifyStateCookie_Contract(t *testing.T) { + cfg := &config.Config{ + Cookie: config.CookieConfig{ + Domain: "example.com", + Secure: true, + }, + } + h := &Handler{Config: cfg} + + // Case 1: Valid match + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{Name: "oauth_state", Value: "valid-state"}) + rr := httptest.NewRecorder() + + valid := h.verifyStateCookie(rr, req, "valid-state") + assert.True(t, valid) + + // Should clear cookie + cookies := rr.Result().Cookies() + requireCookie(t, cookies, "oauth_state", func(c *http.Cookie) { + assert.Equal(t, "", c.Value) + assert.Equal(t, -1, c.MaxAge) + }) + + // Case 2: Mismatch + req = httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{Name: "oauth_state", Value: "valid-state"}) + rr = httptest.NewRecorder() + + valid = h.verifyStateCookie(rr, req, "invalid-state") + assert.False(t, valid) + // Even on mismatch, it should arguably clear or at least fail safely. + // Current impl clears it unconditionally. + cookies = rr.Result().Cookies() + requireCookie(t, cookies, "oauth_state", func(c *http.Cookie) { + assert.Equal(t, -1, c.MaxAge) + }) +} + +// TestRefreshToken_Contract verifies refresh flow logic including cookie set. +func TestRefreshToken_Contract(t *testing.T) { + uid := uuid.New() + mockQueries := mocks.NewQuerier(t) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpiryHours: 1, + }, + Cookie: config.CookieConfig{ + Domain: "localhost", + }, + } + h := New(mockQueries, cfg) + + // Mock DB lookup + mockQueries.On("GetUserByID", mock.Anything, uid).Return(database.User{ + Uid: uid, + FirstName: "Refresh", + LastName: "User", + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + + req := httptest.NewRequest("POST", "/auth/refresh", nil) + // Inject claims via context (simulating middleware) + claims := &middleware.UserClaims{ + UserID: uid.String(), + Role: "student", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + ctx := context.WithValue(req.Context(), middleware.UserClaimsKey, claims) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + h.RefreshToken(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + // Verify new cookie set + cookies := rr.Result().Cookies() + requireCookie(t, cookies, "capy_auth", func(c *http.Cookie) { + assert.NotEmpty(t, c.Value) + assert.Equal(t, 3600, c.MaxAge) + }) +} + +func requireCookie(t *testing.T, cookies []*http.Cookie, name string, check func(*http.Cookie)) { + t.Helper() + var found *http.Cookie + for _, c := range cookies { + if c.Name == name { + found = c + break + } + } + if found == nil { + t.Fatalf("Cookie %s not found", name) + } + check(found) +} diff --git a/internal/handler/auth_test.go b/internal/handler/auth_test.go new file mode 100644 index 0000000..0d64c51 --- /dev/null +++ b/internal/handler/auth_test.go @@ -0,0 +1,253 @@ +package handler_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/capyrpi/api/internal/config" + "github.com/capyrpi/api/internal/database" + "github.com/capyrpi/api/internal/database/mocks" + "github.com/capyrpi/api/internal/handler" + "github.com/capyrpi/api/internal/middleware" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGetMe(t *testing.T) { + uid := uuid.New() + + tests := []struct { + name string + claims *middleware.UserClaims + setupMock func(*mocks.Querier) + expectedStatus int + }{ + { + name: "Success", + claims: &middleware.UserClaims{ + UserID: uid.String(), + Role: string(database.UserRoleStudent), + }, + setupMock: func(m *mocks.Querier) { + m.On("GetUserByID", mock.Anything, uid).Return(database.User{ + Uid: uid, + FirstName: "Test", + LastName: "User", + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, nil) + }, + expectedStatus: http.StatusOK, + }, + { + name: "Unauthorized_NoClaims", + claims: nil, + setupMock: func(m *mocks.Querier) {}, + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + tt.setupMock(mockQueries) + + h := handler.New(mockQueries, &config.Config{}) + + req := httptest.NewRequest("GET", "/auth/me", nil) + if tt.claims != nil { + ctx := context.WithValue(req.Context(), middleware.UserClaimsKey, tt.claims) + req = req.WithContext(ctx) + } + + rr := httptest.NewRecorder() + http.HandlerFunc(h.GetMe).ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +} + +func TestBotToken_RoleChecks(t *testing.T) { + uid := uuid.New() + + tests := []struct { + name string + handlerFunc func(*handler.Handler) http.HandlerFunc + method string + path string + tokenRole database.UserRole + dbRole database.UserRole + setupMock func(*mocks.Querier) + expectedStatus int + }{ + { + name: "ListBotTokens_Dev_Success", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.ListBotTokens }, + method: "GET", + path: "/bot/tokens", + tokenRole: database.UserRoleDev, + dbRole: database.UserRoleDev, + setupMock: func(m *mocks.Querier) { + m.On("ListBotTokens", mock.Anything).Return([]database.ListBotTokensRow{}, nil) + }, + expectedStatus: http.StatusOK, + }, + { + name: "ListBotTokens_StudentClaim_DevDB_Success", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.ListBotTokens }, + method: "GET", + path: "/bot/tokens", + tokenRole: database.UserRoleStudent, + dbRole: database.UserRoleDev, + setupMock: func(m *mocks.Querier) { + m.On("ListBotTokens", mock.Anything).Return([]database.ListBotTokensRow{}, nil) + }, + expectedStatus: http.StatusOK, + }, + { + name: "ListBotTokens_Faculty_Forbidden", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.ListBotTokens }, + method: "GET", + path: "/bot/tokens", + tokenRole: database.UserRoleFaculty, + dbRole: database.UserRoleFaculty, + setupMock: func(m *mocks.Querier) {}, + expectedStatus: http.StatusForbidden, + }, + { + name: "ListBotTokens_Student_Forbidden", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.ListBotTokens }, + method: "GET", + path: "/bot/tokens", + tokenRole: database.UserRoleStudent, + dbRole: database.UserRoleStudent, + setupMock: func(m *mocks.Querier) {}, + expectedStatus: http.StatusForbidden, + }, + { + name: "CreateBotToken_Dev_Success", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.CreateBotToken }, + method: "POST", + path: "/bot/tokens", + tokenRole: database.UserRoleDev, + dbRole: database.UserRoleDev, + setupMock: func(m *mocks.Querier) { + m.On("CreateBotToken", mock.Anything, mock.Anything).Return(database.BotToken{ + TokenID: uuid.New(), + Name: "Bot", + CreatedAt: pgtype.Timestamp{Time: time.Now(), Valid: true}, + IsActive: pgtype.Bool{Bool: true, Valid: true}, + }, nil) + }, + expectedStatus: http.StatusCreated, + }, + { + name: "CreateBotToken_Faculty_Forbidden", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.CreateBotToken }, + method: "POST", + path: "/bot/tokens", + tokenRole: database.UserRoleFaculty, + dbRole: database.UserRoleFaculty, + setupMock: func(m *mocks.Querier) {}, + expectedStatus: http.StatusForbidden, + }, + { + name: "CreateBotToken_Student_Forbidden", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.CreateBotToken }, + method: "POST", + path: "/bot/tokens", + tokenRole: database.UserRoleStudent, + dbRole: database.UserRoleStudent, + setupMock: func(m *mocks.Querier) {}, + expectedStatus: http.StatusForbidden, + }, + { + name: "RevokeBotToken_Dev_Success", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.RevokeBotToken }, + method: "DELETE", + path: "/bot/tokens/" + uuid.New().String(), + tokenRole: database.UserRoleDev, + dbRole: database.UserRoleDev, + setupMock: func(m *mocks.Querier) { + m.On("RevokeBotToken", mock.Anything, mock.Anything).Return(nil) + }, + expectedStatus: http.StatusNoContent, + }, + { + name: "RevokeBotToken_Faculty_Forbidden", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.RevokeBotToken }, + method: "DELETE", + path: "/bot/tokens/" + uuid.New().String(), + tokenRole: database.UserRoleFaculty, + dbRole: database.UserRoleFaculty, + setupMock: func(m *mocks.Querier) {}, + expectedStatus: http.StatusForbidden, + }, + { + name: "RevokeBotToken_Student_Forbidden", + handlerFunc: func(h *handler.Handler) http.HandlerFunc { return h.RevokeBotToken }, + method: "DELETE", + path: "/bot/tokens/" + uuid.New().String(), + tokenRole: database.UserRoleStudent, + dbRole: database.UserRoleStudent, + setupMock: func(m *mocks.Querier) {}, + expectedStatus: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + mockQueries.On("GetUserByID", mock.Anything, uid).Return(database.User{ + Uid: uid, + Role: database.NullUserRole{UserRole: tt.dbRole, Valid: true}, + }, nil) + if tt.expectedStatus != http.StatusForbidden { + tt.setupMock(mockQueries) + } + + h := handler.New(mockQueries, &config.Config{}) + + var body *strings.Reader + if tt.method == "POST" { + jsonBody, _ := json.Marshal(map[string]interface{}{ + "name": "Test Bot", + }) + body = strings.NewReader(string(jsonBody)) + } else { + body = strings.NewReader("") + } + + req := httptest.NewRequest(tt.method, tt.path, body) + + claims := &middleware.UserClaims{ + UserID: uid.String(), + Role: string(tt.tokenRole), + } + ctx := context.WithValue(req.Context(), middleware.UserClaimsKey, claims) + + if tt.method == "DELETE" { + rctx := chi.NewRouteContext() + parts := strings.Split(tt.path, "/") + if len(parts) > 0 { + rctx.URLParams.Add("token_id", parts[len(parts)-1]) + } + ctx = context.WithValue(ctx, chi.RouteCtxKey, rctx) + } + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + tt.handlerFunc(h).ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +} diff --git a/internal/handler/events.go b/internal/handler/events.go index 0d247f7..21f6186 100644 --- a/internal/handler/events.go +++ b/internal/handler/events.go @@ -6,6 +6,7 @@ import ( "github.com/capyrpi/api/internal/database" "github.com/capyrpi/api/internal/dto" + "github.com/capyrpi/api/internal/middleware" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" @@ -254,7 +255,26 @@ func (h *Handler) RegisterForEvent(w http.ResponseWriter, r *http.Request) { return } - // TODO: Get user ID from context if not provided (for human auth) + authType := middleware.GetAuthType(r.Context()) + if authType == "human" { + claims, ok := middleware.GetUserClaims(r.Context()) + if !ok { + h.respondError(w, http.StatusUnauthorized, "Not authenticated") + return + } + userUID, err := uuid.Parse(claims.UserID) + if err != nil { + h.respondError(w, http.StatusInternalServerError, "Invalid user ID in token") + return + } + + if req.UID != nil && *req.UID != userUID { + h.respondError(w, http.StatusForbidden, "Cannot register for another user") + return + } + req.UID = &userUID + } + if req.UID == nil { h.respondError(w, http.StatusBadRequest, "uid is required") return @@ -295,16 +315,44 @@ func (h *Handler) UnregisterFromEvent(w http.ResponseWriter, r *http.Request) { // Get UID from query param or context uidStr := r.URL.Query().Get("uid") - if uidStr == "" { - // TODO: Get from auth context - h.respondError(w, http.StatusBadRequest, "uid is required") - return - } + var uid uuid.UUID + + authType := middleware.GetAuthType(r.Context()) + if authType == "human" { + claims, ok := middleware.GetUserClaims(r.Context()) + if !ok { + h.respondError(w, http.StatusUnauthorized, "Not authenticated") + return + } + userUID, err := uuid.Parse(claims.UserID) + if err != nil { + h.respondError(w, http.StatusInternalServerError, "Invalid user ID in token") + return + } - uid, err := uuid.Parse(uidStr) - if err != nil { - h.respondError(w, http.StatusBadRequest, "Invalid user ID") - return + if uidStr != "" { + requestedUID, err := uuid.Parse(uidStr) + if err != nil { + h.respondError(w, http.StatusBadRequest, "Invalid user ID") + return + } + if requestedUID != userUID { + h.respondError(w, http.StatusForbidden, "Cannot unregister for another user") + return + } + } + uid = userUID + } else { + if uidStr == "" { + h.respondError(w, http.StatusBadRequest, "uid is required") + return + } + var err error + uid, err = uuid.Parse(uidStr) + if err != nil { + h.respondError(w, http.StatusBadRequest, "Invalid user ID") + return + } } if err := h.queries.UnregisterFromEvent(r.Context(), database.UnregisterFromEventParams{ diff --git a/internal/handler/events_test.go b/internal/handler/events_test.go new file mode 100644 index 0000000..7e372d5 --- /dev/null +++ b/internal/handler/events_test.go @@ -0,0 +1,252 @@ +package handler_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/capyrpi/api/internal/config" + "github.com/capyrpi/api/internal/database" + "github.com/capyrpi/api/internal/database/mocks" + "github.com/capyrpi/api/internal/dto" + "github.com/capyrpi/api/internal/handler" + "github.com/capyrpi/api/internal/middleware" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestRegisterForEvent(t *testing.T) { + eid := uuid.New() + userUID := uuid.New() + otherUID := uuid.New() + + tests := []struct { + name string + authType string + claims *middleware.UserClaims + requestBody dto.RegisterEventRequest + mockSetup func(*mocks.Querier) + expectedStatus int + }{ + { + name: "Human_Success_ExplicitUID", + authType: "human", + claims: &middleware.UserClaims{ + UserID: userUID.String(), + }, + requestBody: dto.RegisterEventRequest{ + UID: &userUID, + IsAttending: true, + }, + mockSetup: func(m *mocks.Querier) { + m.On("RegisterForEvent", mock.Anything, mock.MatchedBy(func(arg database.RegisterForEventParams) bool { + return arg.Uid == userUID && arg.Eid == eid + })).Return(nil) + }, + expectedStatus: http.StatusCreated, + }, + { + name: "Human_Success_ImplicitUID", + authType: "human", + claims: &middleware.UserClaims{ + UserID: userUID.String(), + }, + requestBody: dto.RegisterEventRequest{ + UID: nil, + IsAttending: true, + }, + mockSetup: func(m *mocks.Querier) { + m.On("RegisterForEvent", mock.Anything, mock.MatchedBy(func(arg database.RegisterForEventParams) bool { + return arg.Uid == userUID && arg.Eid == eid + })).Return(nil) + }, + expectedStatus: http.StatusCreated, + }, + { + name: "Human_Forbidden_MismatchUID", + authType: "human", + claims: &middleware.UserClaims{ + UserID: userUID.String(), + }, + requestBody: dto.RegisterEventRequest{ + UID: &otherUID, + IsAttending: true, + }, + mockSetup: func(m *mocks.Querier) { + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "Bot_Success", + authType: "bot", + claims: nil, + requestBody: dto.RegisterEventRequest{ + UID: &otherUID, + IsAttending: true, + }, + mockSetup: func(m *mocks.Querier) { + m.On("RegisterForEvent", mock.Anything, mock.MatchedBy(func(arg database.RegisterForEventParams) bool { + return arg.Uid == otherUID && arg.Eid == eid + })).Return(nil) + }, + expectedStatus: http.StatusCreated, + }, + { + name: "Bot_MissingUID", + authType: "bot", + claims: nil, + requestBody: dto.RegisterEventRequest{ + UID: nil, + IsAttending: true, + }, + mockSetup: func(m *mocks.Querier) { + }, + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + if tt.mockSetup != nil { + tt.mockSetup(mockQueries) + } + + h := handler.New(mockQueries, &config.Config{}) + r := chi.NewRouter() + r.Post("/events/{eid}/register", h.RegisterForEvent) + + body, _ := json.Marshal(tt.requestBody) + req, _ := http.NewRequest("POST", fmt.Sprintf("/events/%s/register", eid), bytes.NewBuffer(body)) + + ctx := req.Context() + if tt.authType != "" { + ctx = context.WithValue(ctx, middleware.AuthTypeKey, tt.authType) + } + if tt.claims != nil { + ctx = context.WithValue(ctx, middleware.UserClaimsKey, tt.claims) + } + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +} + +func TestUnregisterFromEvent(t *testing.T) { + eid := uuid.New() + userUID := uuid.New() + otherUID := uuid.New() + + tests := []struct { + name string + authType string + claims *middleware.UserClaims + uidParam string + mockSetup func(*mocks.Querier) + expectedStatus int + }{ + { + name: "Human_Success_ExplicitUID", + authType: "human", + claims: &middleware.UserClaims{ + UserID: userUID.String(), + }, + uidParam: userUID.String(), + mockSetup: func(m *mocks.Querier) { + m.On("UnregisterFromEvent", mock.Anything, mock.MatchedBy(func(arg database.UnregisterFromEventParams) bool { + return arg.Uid == userUID && arg.Eid == eid + })).Return(nil) + }, + expectedStatus: http.StatusNoContent, + }, + { + name: "Human_Success_ImplicitUID", + authType: "human", + claims: &middleware.UserClaims{ + UserID: userUID.String(), + }, + uidParam: "", + mockSetup: func(m *mocks.Querier) { + m.On("UnregisterFromEvent", mock.Anything, mock.MatchedBy(func(arg database.UnregisterFromEventParams) bool { + return arg.Uid == userUID && arg.Eid == eid + })).Return(nil) + }, + expectedStatus: http.StatusNoContent, + }, + { + name: "Human_Forbidden_MismatchUID", + authType: "human", + claims: &middleware.UserClaims{ + UserID: userUID.String(), + }, + uidParam: otherUID.String(), + mockSetup: func(m *mocks.Querier) { + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "Bot_Success", + authType: "bot", + claims: nil, + uidParam: otherUID.String(), + mockSetup: func(m *mocks.Querier) { + m.On("UnregisterFromEvent", mock.Anything, mock.MatchedBy(func(arg database.UnregisterFromEventParams) bool { + return arg.Uid == otherUID && arg.Eid == eid + })).Return(nil) + }, + expectedStatus: http.StatusNoContent, + }, + { + name: "Bot_MissingUID", + authType: "bot", + claims: nil, + uidParam: "", + mockSetup: func(m *mocks.Querier) { + }, + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + if tt.mockSetup != nil { + tt.mockSetup(mockQueries) + } + + h := handler.New(mockQueries, &config.Config{}) + r := chi.NewRouter() + r.Delete("/events/{eid}/register", h.UnregisterFromEvent) + + url := fmt.Sprintf("/events/%s/register", eid) + if tt.uidParam != "" { + url += "?uid=" + tt.uidParam + } + req, _ := http.NewRequest("DELETE", url, nil) + + ctx := req.Context() + if tt.authType != "" { + ctx = context.WithValue(ctx, middleware.AuthTypeKey, tt.authType) + } + if tt.claims != nil { + ctx = context.WithValue(ctx, middleware.UserClaimsKey, tt.claims) + } + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +} diff --git a/internal/handler/handler.go b/internal/handler/handler.go index cbfc2e7..6ec7803 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -5,6 +5,9 @@ import ( "log/slog" "net/http" + "github.com/capyrpi/api/internal/auth/adapters" + "github.com/capyrpi/api/internal/auth/ports" + "github.com/capyrpi/api/internal/auth/service" "github.com/capyrpi/api/internal/config" "github.com/capyrpi/api/internal/database" "github.com/capyrpi/api/internal/oauth" @@ -18,15 +21,38 @@ type Handler struct { Config *config.Config googleAuth *oauth.GoogleProvider microsoftAuth *oauth.MicrosoftProvider + authService *service.AuthService } // New creates a new Handler with the given dependencies func New(queries database.Querier, cfg *config.Config) *Handler { + googleProvider := oauth.NewGoogleProvider(cfg.OAuth.Google.ClientID, cfg.OAuth.Google.ClientSecret, cfg.OAuth.Google.RedirectURL) + microsoftProvider := oauth.NewMicrosoftProvider(cfg.OAuth.Microsoft.ClientID, cfg.OAuth.Microsoft.ClientSecret, cfg.OAuth.Microsoft.RedirectURL, cfg.OAuth.Microsoft.TenantID) + + userRepoAdapter := adapters.NewUserRepoAdapter(queries) + botRepoAdapter, ok := userRepoAdapter.(ports.BotRepo) + if !ok { + panic("UserRepoAdapter does not implement BotRepo") + } + + tokenProviderAdapter := adapters.NewJWTAdapter(cfg) + googleAdapter := adapters.NewGoogleOAuthAdapter(googleProvider) + microsoftAdapter := adapters.NewMicrosoftOAuthAdapter(microsoftProvider) + + authService := service.NewAuthService( + userRepoAdapter, + botRepoAdapter, + tokenProviderAdapter, + googleAdapter, + microsoftAdapter, + ) + return &Handler{ queries: queries, Config: cfg, - googleAuth: oauth.NewGoogleProvider(cfg.OAuth.Google.ClientID, cfg.OAuth.Google.ClientSecret, cfg.OAuth.Google.RedirectURL), - microsoftAuth: oauth.NewMicrosoftProvider(cfg.OAuth.Microsoft.ClientID, cfg.OAuth.Microsoft.ClientSecret, cfg.OAuth.Microsoft.RedirectURL, cfg.OAuth.Microsoft.TenantID), + googleAuth: googleProvider, + microsoftAuth: microsoftProvider, + authService: authService, } } diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index ce7d3c7..418f367 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -5,7 +5,9 @@ import ( "net/http" "strings" + "github.com/capyrpi/api/internal/database" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" ) type contextKey string @@ -23,8 +25,12 @@ type UserClaims struct { jwt.RegisteredClaims } +type UserLookup interface { + GetUserByID(ctx context.Context, uid uuid.UUID) (database.User, error) +} + // Auth middleware validates JWT tokens from cookies or Authorization header -func Auth(jwtSecret string) func(http.Handler) http.Handler { +func Auth(jwtSecret string, userLookup UserLookup) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var tokenString string @@ -43,15 +49,6 @@ func Auth(jwtSecret string) func(http.Handler) http.Handler { } } - // Check for bot token - botToken := r.Header.Get("X-Bot-Token") - if botToken != "" { - // Bot authentication will be handled separately - ctx := context.WithValue(r.Context(), AuthTypeKey, "bot") - next.ServeHTTP(w, r.WithContext(ctx)) - return - } - if tokenString == "" { http.Error(w, "Unauthorized", http.StatusUnauthorized) return @@ -68,6 +65,23 @@ func Auth(jwtSecret string) func(http.Handler) http.Handler { return } + uid, err := uuid.Parse(claims.UserID) + if err != nil { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + user, err := userLookup.GetUserByID(r.Context(), uid) + if err != nil { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + if !user.Role.Valid || string(user.Role.UserRole) != claims.Role { + http.Error(w, "Stale token", http.StatusUnauthorized) + return + } + // Add claims to context ctx := context.WithValue(r.Context(), UserClaimsKey, claims) ctx = context.WithValue(ctx, AuthTypeKey, "human") diff --git a/internal/middleware/auth_human_test.go b/internal/middleware/auth_human_test.go index b649206..dad10fd 100644 --- a/internal/middleware/auth_human_test.go +++ b/internal/middleware/auth_human_test.go @@ -1,18 +1,41 @@ package middleware_test import ( + "context" "net/http" "net/http/httptest" "testing" "time" + "github.com/capyrpi/api/internal/database" "github.com/capyrpi/api/internal/middleware" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) +type stubUserLookup struct { + userByID map[uuid.UUID]database.User +} + +func (s stubUserLookup) GetUserByID(_ context.Context, uid uuid.UUID) (database.User, error) { + if user, ok := s.userByID[uid]; ok { + return user, nil + } + return database.User{}, assert.AnError +} + func TestAuth(t *testing.T) { secret := "test-secret" + uid := uuid.New() + lookup := stubUserLookup{ + userByID: map[uuid.UUID]database.User{ + uid: { + Uid: uid, + Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + }, + }, + } tests := []struct { name string @@ -24,7 +47,8 @@ func TestAuth(t *testing.T) { tokenSetup: func() *http.Request { // Generate token claims := middleware.UserClaims{ - UserID: "user-123", + UserID: uid.String(), + Role: string(database.UserRoleStudent), RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), }, @@ -38,6 +62,25 @@ func TestAuth(t *testing.T) { }, expectedStatus: http.StatusOK, }, + { + name: "StaleRoleToken", + tokenSetup: func() *http.Request { + claims := middleware.UserClaims{ + UserID: uid.String(), + Role: string(database.UserRoleDev), + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, _ := token.SignedString([]byte(secret)) + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{Name: "capy_auth", Value: tokenStr}) + return req + }, + expectedStatus: http.StatusUnauthorized, + }, { name: "MissingToken", tokenSetup: func() *http.Request { @@ -58,7 +101,8 @@ func TestAuth(t *testing.T) { name: "ExpiredToken", tokenSetup: func() *http.Request { claims := middleware.UserClaims{ - UserID: "user-123", + UserID: uid.String(), + Role: string(database.UserRoleStudent), RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), }, @@ -76,11 +120,11 @@ func TestAuth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - handler := middleware.Auth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := middleware.Auth(secret, lookup)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // If we get here, check context claims, ok := middleware.GetUserClaims(r.Context()) if ok { - assert.Equal(t, "user-123", claims.UserID) + assert.Equal(t, uid.String(), claims.UserID) assert.Equal(t, "human", middleware.GetAuthType(r.Context())) } w.WriteHeader(http.StatusOK) diff --git a/internal/middleware/auth_m2m.go b/internal/middleware/auth_m2m.go index c2d6aab..c843e9b 100644 --- a/internal/middleware/auth_m2m.go +++ b/internal/middleware/auth_m2m.go @@ -21,9 +21,12 @@ type BotTokenInfo struct { func M2MAuth(queries database.Querier) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get("X-Bot-Token") + token := r.Header.Get("X-API-Key") if token == "" { - http.Error(w, "Missing X-Bot-Token header", http.StatusUnauthorized) + token = r.Header.Get("X-Bot-Token") + } + if token == "" { + http.Error(w, "Missing API key header", http.StatusUnauthorized) return } diff --git a/internal/middleware/auth_m2m_test.go b/internal/middleware/auth_m2m_test.go new file mode 100644 index 0000000..3582392 --- /dev/null +++ b/internal/middleware/auth_m2m_test.go @@ -0,0 +1,64 @@ +package middleware_test + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/capyrpi/api/internal/database/mocks" + "github.com/capyrpi/api/internal/middleware" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestM2MAuth_MissingHeaders(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + handler := middleware.M2MAuth(mockQueries)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/bot/me", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Body.String(), "Missing API key header") +} + +func TestM2MAuth_AcceptsXAPIKeyHeader(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + mockQueries.On("ListBotTokens", mock.Anything).Return(nil, errors.New("db down")).Once() + + handler := middleware.M2MAuth(mockQueries)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/bot/me", nil) + req.Header.Set("X-API-Key", "test-api-key") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "Internal server error") +} + +func TestM2MAuth_AcceptsXBotTokenHeader(t *testing.T) { + mockQueries := mocks.NewQuerier(t) + mockQueries.On("ListBotTokens", mock.Anything).Return(nil, errors.New("db down")).Once() + + handler := middleware.M2MAuth(mockQueries)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/bot/me", nil) + req.Header.Set("X-Bot-Token", "test-bot-token") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "Internal server error") +} diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index d1740a6..36c3b69 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -30,7 +30,7 @@ func CORS(allowedOrigins []string) func(http.Handler) http.Handler { if allowed { w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, X-Bot-Token") + w.Header().Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, X-Bot-Token, X-API-Key") w.Header().Set("Access-Control-Allow-Credentials", "true") } diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go index c097f48..d36e3bb 100644 --- a/internal/middleware/cors_test.go +++ b/internal/middleware/cors_test.go @@ -78,6 +78,9 @@ func TestCORS(t *testing.T) { assert.Equal(t, tt.expectedOrigin, rr.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, tt.expectedCreds, rr.Header().Get("Access-Control-Allow-Credentials")) + if tt.expectedOrigin != "" { + assert.Equal(t, "Accept, Authorization, Content-Type, X-Bot-Token, X-API-Key", rr.Header().Get("Access-Control-Allow-Headers")) + } if tt.method == "OPTIONS" && tt.expectedOrigin != "" { assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", rr.Header().Get("Access-Control-Allow-Methods")) diff --git a/internal/router/router.go b/internal/router/router.go index 97c2cb1..7e6694d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -27,7 +27,34 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed // Swagger UI (public) - Only in non-production environments if h.Config.Env != "production" { - r.Get("/swagger/*", httpSwagger.WrapHandler) + r.Get("/swagger/*", httpSwagger.Handler( + httpSwagger.URL("/swagger/doc.json"), + httpSwagger.BeforeScript(`const UrlMutatorPlugin = (system) => ({ + rootInjects: { + setScheme: (scheme) => { + const jsonSpec = system.getState().toJSON().spec.json; + const schemes = Array.isArray(scheme) ? scheme : [scheme]; + const newJsonSpec = Object.assign({}, jsonSpec, { schemes }); + + return system.specActions.updateJsonSpec(newJsonSpec); + }, + setHost: (host) => { + const jsonSpec = system.getState().toJSON().spec.json; + const newJsonSpec = Object.assign({}, jsonSpec, { host }); + + return system.specActions.updateJsonSpec(newJsonSpec); + } + } +});`), + httpSwagger.Plugins([]string{"UrlMutatorPlugin"}), + httpSwagger.UIConfig(map[string]string{ + "onComplete": `() => { + const loc = window.location; + window.ui.setScheme(loc.protocol.replace(':', '')); + window.ui.setHost(loc.host); + }`, + }), + )) } // API v1 routes @@ -41,7 +68,7 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed // Protected auth routes r.Group(func(r chi.Router) { - r.Use(middleware.Auth(jwtSecret)) + r.Use(middleware.Auth(jwtSecret, queries)) r.Get("/me", h.GetMe) r.Post("/logout", h.Logout) r.Post("/refresh", h.RefreshToken) @@ -50,7 +77,7 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed // Protected routes - require human authentication r.Group(func(r chi.Router) { - r.Use(middleware.Auth(jwtSecret)) + r.Use(middleware.Auth(jwtSecret, queries)) // Users r.Route("/users", func(r chi.Router) { @@ -93,6 +120,12 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed r.Post("/", h.CreateBotToken) r.Delete("/{token_id}", h.RevokeBotToken) }) + + r.Route("/api-keys", func(r chi.Router) { + r.Get("/", h.ListBotTokens) + r.Post("/", h.CreateBotToken) + r.Delete("/{token_id}", h.RevokeBotToken) + }) }) // Bot routes (M2M auth) diff --git a/schema.sql b/schema.sql index 160c511..0824021 100644 --- a/schema.sql +++ b/schema.sql @@ -2,7 +2,7 @@ -- Database Schema for CAPY (Club Assistant in Python) -- 1. ENUMs & Functions -CREATE TYPE user_role AS ENUM ('student', 'alumni', 'faculty', 'external'); +CREATE TYPE user_role AS ENUM ('student', 'alumni', 'faculty', 'dev', 'external'); CREATE OR REPLACE FUNCTION update_modified_column() RETURNS TRIGGER AS $$ @@ -88,4 +88,4 @@ DROP TRIGGER IF EXISTS update_orgs_modtime ON organizations; CREATE TRIGGER update_orgs_modtime BEFORE UPDATE ON organizations FOR EACH ROW EXECUTE FUNCTION update_modified_column(); DROP TRIGGER IF EXISTS update_events_modtime ON events; -CREATE TRIGGER update_events_modtime BEFORE UPDATE ON events FOR EACH ROW EXECUTE FUNCTION update_modified_column(); \ No newline at end of file +CREATE TRIGGER update_events_modtime BEFORE UPDATE ON events FOR EACH ROW EXECUTE FUNCTION update_modified_column(); diff --git a/scripts/create_dev_user/main.go b/scripts/create_dev_user/main.go index 020d933..e7c6893 100644 --- a/scripts/create_dev_user/main.go +++ b/scripts/create_dev_user/main.go @@ -45,7 +45,7 @@ func main() { FirstName: "Dev", LastName: "User", PersonalEmail: pgtype.Text{String: "dev@example.com", Valid: true}, - Role: database.NullUserRole{UserRole: database.UserRoleStudent, Valid: true}, + Role: database.NullUserRole{UserRole: database.UserRoleDev, Valid: true}, }) if err != nil { // Try to find existing if duplicate @@ -60,7 +60,7 @@ func main() { claims := middleware.UserClaims{ UserID: user.Uid.String(), Email: user.PersonalEmail.String, - Role: "student", + Role: string(database.UserRoleDev), RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), diff --git a/server b/server new file mode 100755 index 0000000..ebb890e Binary files /dev/null and b/server differ diff --git a/tests/benchmarks/benchmark_test.go b/tests/benchmarks/benchmark_test.go index 7e46749..d59cb22 100644 --- a/tests/benchmarks/benchmark_test.go +++ b/tests/benchmarks/benchmark_test.go @@ -15,6 +15,16 @@ var benchClient = &http.Client{ }, } +func drainAndCloseBody(b *testing.B, body io.ReadCloser) { + b.Helper() + if _, err := io.Copy(io.Discard, body); err != nil { + b.Fatalf("failed to drain response body: %v", err) + } + if err := body.Close(); err != nil { + b.Fatalf("failed to close response body: %v", err) + } +} + func BenchmarkHealthEndpoint(b *testing.B) { b.ResetTimer() b.ReportAllocs() @@ -27,8 +37,7 @@ func BenchmarkHealthEndpoint(b *testing.B) { if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -53,8 +62,7 @@ func BenchmarkAuthGoogleInitiation(b *testing.B) { if resp.StatusCode != http.StatusTemporaryRedirect && resp.StatusCode != http.StatusFound { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -79,8 +87,7 @@ func BenchmarkAuthMicrosoftInitiation(b *testing.B) { if resp.StatusCode != http.StatusTemporaryRedirect && resp.StatusCode != http.StatusFound { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -97,8 +104,7 @@ func BenchmarkGetMe(b *testing.B) { if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -115,8 +121,7 @@ func BenchmarkGetUser(b *testing.B) { if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -133,8 +138,7 @@ func BenchmarkListOrganizations(b *testing.B) { if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -157,8 +161,7 @@ func BenchmarkCreateOrganization(b *testing.B) { if resp.StatusCode != http.StatusCreated { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -175,8 +178,7 @@ func BenchmarkGetOrganization(b *testing.B) { if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -193,8 +195,7 @@ func BenchmarkListEvents(b *testing.B) { if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -218,8 +219,7 @@ func BenchmarkCreateEvent(b *testing.B) { if resp.StatusCode != http.StatusCreated { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } @@ -236,7 +236,6 @@ func BenchmarkGetEvent(b *testing.B) { if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + drainAndCloseBody(b, resp.Body) } } diff --git a/tests/integration/api_test.go b/tests/integration/api_test.go index e8f4674..65f33ff 100644 --- a/tests/integration/api_test.go +++ b/tests/integration/api_test.go @@ -53,6 +53,7 @@ func TestFullAPI(t *testing.T) { // 3. Generate Auth Token claims := middleware.UserClaims{ UserID: user.Uid.String(), + Role: string(user.Role.UserRole), RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), },