diff --git a/packages/core/src/__tests__/middleware/auth.test.ts b/packages/core/src/__tests__/middleware/auth.test.ts index 608b216b1..48f669185 100644 --- a/packages/core/src/__tests__/middleware/auth.test.ts +++ b/packages/core/src/__tests__/middleware/auth.test.ts @@ -597,6 +597,170 @@ describe('requireAuth middleware - Error Handling', () => { }) }) +describe('JWT secret externalization (env var)', () => { + it('should generate token with custom secret', async () => { + const customSecret = 'my-custom-secret-key-from-env' + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'admin', customSecret + ) + + expect(token).toBeTruthy() + expect(token.split('.')).toHaveLength(3) + }) + + it('should verify token with matching custom secret', async () => { + const customSecret = 'my-custom-secret-key-from-env' + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'admin', customSecret + ) + + const payload = await AuthManager.verifyToken(token, customSecret) + expect(payload).not.toBeNull() + expect(payload?.userId).toBe('user-123') + expect(payload?.email).toBe('test@example.com') + }) + + it('should reject token when verified with wrong secret', async () => { + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'admin', 'secret-a' + ) + + const payload = await AuthManager.verifyToken(token, 'secret-b') + expect(payload).toBeNull() + }) + + it('should reject custom-secret token when verified with fallback', async () => { + const customSecret = 'my-custom-secret-key-from-env' + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'admin', customSecret + ) + + // Verify without providing secret (uses fallback) + const payload = await AuthManager.verifyToken(token) + expect(payload).toBeNull() + }) + + it('should verify fallback-secret token without providing secret', async () => { + // Generate without custom secret (uses fallback) + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'admin' + ) + + // Verify without custom secret (uses same fallback) + const payload = await AuthManager.verifyToken(token) + expect(payload).not.toBeNull() + expect(payload?.userId).toBe('user-123') + }) +}) + +describe('requireAuth middleware - JWT_SECRET from env', () => { + let mockNext: Next + + beforeEach(() => { + mockNext = vi.fn() + }) + + it('should use JWT_SECRET from env when verifying tokens', async () => { + const envSecret = 'env-jwt-secret-12345' + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'admin', envSecret + ) + + const mockContext: any = { + req: { + header: vi.fn().mockImplementation((name: string) => { + if (name === 'Authorization') return `Bearer ${token}` + return undefined + }), + raw: { headers: new Headers() } + }, + set: vi.fn(), + json: vi.fn(), + redirect: vi.fn(), + env: { JWT_SECRET: envSecret } + } + + const middleware = requireAuth() + await middleware(mockContext as Context, mockNext) + + expect(mockContext.set).toHaveBeenCalledWith('user', expect.objectContaining({ + userId: 'user-123', + email: 'test@example.com', + role: 'admin' + })) + expect(mockNext).toHaveBeenCalled() + }) + + it('should reject token signed with different secret than env JWT_SECRET', async () => { + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'admin', 'wrong-secret' + ) + + const mockContext: any = { + req: { + header: vi.fn().mockImplementation((name: string) => { + if (name === 'Authorization') return `Bearer ${token}` + return undefined + }), + raw: { headers: new Headers() } + }, + set: vi.fn(), + json: vi.fn().mockReturnValue({ error: 'Invalid or expired token' }), + redirect: vi.fn(), + env: { JWT_SECRET: 'correct-env-secret' } + } + + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + const middleware = requireAuth() + await middleware(mockContext as Context, mockNext) + + expect(mockContext.json).toHaveBeenCalledWith( + { error: 'Invalid or expired token' }, + 401 + ) + expect(mockNext).not.toHaveBeenCalled() + + consoleSpy.mockRestore() + }) +}) + +describe('optionalAuth middleware - JWT_SECRET from env', () => { + let mockNext: Next + + beforeEach(() => { + mockNext = vi.fn() + }) + + it('should use JWT_SECRET from env when verifying tokens', async () => { + const envSecret = 'env-jwt-secret-12345' + const token = await AuthManager.generateToken( + 'user-123', 'test@example.com', 'user', envSecret + ) + + const mockContext: any = { + req: { + header: vi.fn().mockImplementation((name: string) => { + if (name === 'Authorization') return `Bearer ${token}` + return undefined + }), + raw: { headers: new Headers() } + }, + set: vi.fn(), + env: { JWT_SECRET: envSecret } + } + + const middleware = optionalAuth() + await middleware(mockContext as Context, mockNext) + + expect(mockContext.set).toHaveBeenCalledWith('user', expect.objectContaining({ + userId: 'user-123', + role: 'user' + })) + expect(mockNext).toHaveBeenCalled() + }) +}) + describe('requireRole middleware - Browser Redirects', () => { let mockContext: any let mockNext: Next diff --git a/packages/core/src/__tests__/middleware/cors.test.ts b/packages/core/src/__tests__/middleware/cors.test.ts new file mode 100644 index 000000000..3e77da699 --- /dev/null +++ b/packages/core/src/__tests__/middleware/cors.test.ts @@ -0,0 +1,220 @@ +import { describe, it, expect } from 'vitest' +import { Hono } from 'hono' +import { cors } from 'hono/cors' + +/** + * Tests for the CORS origin allowlist implementation. + * + * The CORS middleware is configured inline in routes/api.ts using hono/cors. + * We recreate the same origin callback logic here to test it in isolation. + */ + +// Replicate the exact origin callback from routes/api.ts +function createCorsOriginCallback() { + return (origin: string, c: any): string | null => { + const allowed = (c.env as any)?.CORS_ORIGINS as string | undefined + if (!allowed) return null + const list = allowed.split(',').map((s: string) => s.trim()) + return list.includes(origin) ? origin : null + } +} + +describe('CORS origin allowlist', () => { + function createApp(corsOrigins?: string) { + const app = new Hono<{ Bindings: { CORS_ORIGINS?: string } }>() + + app.use( + '*', + cors({ + origin: createCorsOriginCallback(), + allowMethods: ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization', 'X-API-Key'], + }) + ) + + app.get('/api/test', (c) => c.json({ ok: true })) + + return app + } + + describe('when CORS_ORIGINS is not set', () => { + it('should not include Access-Control-Allow-Origin header', async () => { + const app = createApp() + const res = await app.request('/api/test', { + headers: { Origin: 'https://evil.com' }, + }) + + expect(res.status).toBe(200) + expect(res.headers.get('Access-Control-Allow-Origin')).toBeNull() + }) + + it('should reject preflight requests from any origin', async () => { + const app = createApp() + const res = await app.request('/api/test', { + method: 'OPTIONS', + headers: { + Origin: 'https://evil.com', + 'Access-Control-Request-Method': 'POST', + }, + }) + + expect(res.headers.get('Access-Control-Allow-Origin')).toBeNull() + }) + }) + + describe('when CORS_ORIGINS is set to a single origin', () => { + const ORIGINS = 'https://myapp.com' + + it('should allow requests from the configured origin', async () => { + const app = createApp(ORIGINS) + const res = await app.request( + '/api/test', + { headers: { Origin: 'https://myapp.com' } }, + { CORS_ORIGINS: ORIGINS } + ) + + expect(res.headers.get('Access-Control-Allow-Origin')).toBe( + 'https://myapp.com' + ) + }) + + it('should reject requests from a non-configured origin', async () => { + const app = createApp(ORIGINS) + const res = await app.request( + '/api/test', + { headers: { Origin: 'https://evil.com' } }, + { CORS_ORIGINS: ORIGINS } + ) + + expect(res.headers.get('Access-Control-Allow-Origin')).toBeNull() + }) + }) + + describe('when CORS_ORIGINS is set to multiple origins', () => { + const ORIGINS = 'https://app1.com, https://app2.com, http://localhost:8787' + + it('should allow requests from any listed origin', async () => { + const app = createApp(ORIGINS) + + const res1 = await app.request( + '/api/test', + { headers: { Origin: 'https://app1.com' } }, + { CORS_ORIGINS: ORIGINS } + ) + expect(res1.headers.get('Access-Control-Allow-Origin')).toBe( + 'https://app1.com' + ) + + const res2 = await app.request( + '/api/test', + { headers: { Origin: 'https://app2.com' } }, + { CORS_ORIGINS: ORIGINS } + ) + expect(res2.headers.get('Access-Control-Allow-Origin')).toBe( + 'https://app2.com' + ) + + const res3 = await app.request( + '/api/test', + { headers: { Origin: 'http://localhost:8787' } }, + { CORS_ORIGINS: ORIGINS } + ) + expect(res3.headers.get('Access-Control-Allow-Origin')).toBe( + 'http://localhost:8787' + ) + }) + + it('should reject requests from unlisted origins', async () => { + const app = createApp(ORIGINS) + const res = await app.request( + '/api/test', + { headers: { Origin: 'https://evil.com' } }, + { CORS_ORIGINS: ORIGINS } + ) + + expect(res.headers.get('Access-Control-Allow-Origin')).toBeNull() + }) + }) + + describe('preflight requests', () => { + const ORIGINS = 'https://myapp.com' + + it('should handle OPTIONS preflight for allowed origin', async () => { + const app = createApp(ORIGINS) + const res = await app.request( + '/api/test', + { + method: 'OPTIONS', + headers: { + Origin: 'https://myapp.com', + 'Access-Control-Request-Method': 'POST', + 'Access-Control-Request-Headers': 'Content-Type', + }, + }, + { CORS_ORIGINS: ORIGINS } + ) + + expect(res.headers.get('Access-Control-Allow-Origin')).toBe( + 'https://myapp.com' + ) + expect(res.headers.get('Access-Control-Allow-Methods')).toContain('POST') + }) + + it('should include X-API-Key in allowed headers', async () => { + const app = createApp(ORIGINS) + const res = await app.request( + '/api/test', + { + method: 'OPTIONS', + headers: { + Origin: 'https://myapp.com', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-API-Key', + }, + }, + { CORS_ORIGINS: ORIGINS } + ) + + expect(res.headers.get('Access-Control-Allow-Headers')).toContain( + 'X-API-Key' + ) + }) + }) + + describe('allowed methods', () => { + const ORIGINS = 'https://myapp.com' + + it('should allow GET, POST, PUT, DELETE, OPTIONS methods', async () => { + const app = createApp(ORIGINS) + const res = await app.request( + '/api/test', + { + method: 'OPTIONS', + headers: { + Origin: 'https://myapp.com', + 'Access-Control-Request-Method': 'DELETE', + }, + }, + { CORS_ORIGINS: ORIGINS } + ) + + const allowedMethods = res.headers.get('Access-Control-Allow-Methods') + expect(allowedMethods).toContain('GET') + expect(allowedMethods).toContain('POST') + expect(allowedMethods).toContain('PUT') + expect(allowedMethods).toContain('DELETE') + expect(allowedMethods).toContain('OPTIONS') + }) + }) + + describe('same-origin requests', () => { + it('should work normally without Origin header (same-origin)', async () => { + const app = createApp('https://myapp.com') + const res = await app.request('/api/test', {}, { CORS_ORIGINS: 'https://myapp.com' }) + + expect(res.status).toBe(200) + const body = await res.json() + expect(body).toEqual({ ok: true }) + }) + }) +}) diff --git a/packages/core/src/__tests__/middleware/rate-limit.test.ts b/packages/core/src/__tests__/middleware/rate-limit.test.ts new file mode 100644 index 000000000..6ebe5076b --- /dev/null +++ b/packages/core/src/__tests__/middleware/rate-limit.test.ts @@ -0,0 +1,260 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { Context, Next } from 'hono' +import { rateLimit } from '../../middleware/rate-limit' + +describe('rateLimit middleware', () => { + let mockNext: Next + let mockKv: any + + beforeEach(() => { + mockNext = vi.fn() + mockKv = { + get: vi.fn().mockResolvedValue(null), + put: vi.fn().mockResolvedValue(undefined), + } + }) + + function createMockContext(overrides: { + ip?: string + kv?: any | null + } = {}): any { + const headers: Record = {} + return { + req: { + header: vi.fn().mockImplementation((name: string) => { + if (name === 'cf-connecting-ip') return overrides.ip ?? '1.2.3.4' + return undefined + }), + }, + env: { + CACHE_KV: overrides.kv === null ? undefined : (overrides.kv ?? mockKv), + }, + header: vi.fn().mockImplementation((key: string, value: string) => { + headers[key] = value + }), + json: vi.fn().mockImplementation((body: any, status?: number) => ({ body, status })), + _headers: headers, + } + } + + describe('when KV binding is not available', () => { + it('should skip rate limiting and call next', async () => { + const ctx = createMockContext({ kv: null }) + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'test' }) + + await middleware(ctx as Context, mockNext) + + expect(mockNext).toHaveBeenCalled() + expect(ctx.json).not.toHaveBeenCalled() + }) + }) + + describe('when under the rate limit', () => { + it('should allow the request and call next', async () => { + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockNext).toHaveBeenCalled() + expect(ctx.json).not.toHaveBeenCalled() + }) + + it('should set rate limit headers', async () => { + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith('X-RateLimit-Limit', '5') + expect(ctx.header).toHaveBeenCalledWith('X-RateLimit-Remaining', '4') + expect(ctx.header).toHaveBeenCalledWith( + 'X-RateLimit-Reset', + expect.any(String) + ) + }) + + it('should store entry in KV with TTL', async () => { + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockKv.put).toHaveBeenCalledWith( + 'ratelimit:login:1.2.3.4', + expect.any(String), + expect.objectContaining({ expirationTtl: expect.any(Number) }) + ) + }) + + it('should decrement remaining count for subsequent requests', async () => { + const now = Date.now() + mockKv.get.mockResolvedValue({ count: 3, resetAt: now + 30000 }) + + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockNext).toHaveBeenCalled() + // count goes from 3 to 4, so remaining = 5 - 4 = 1 + expect(ctx.header).toHaveBeenCalledWith('X-RateLimit-Remaining', '1') + }) + }) + + describe('when rate limit is exceeded', () => { + it('should return 429 Too Many Requests', async () => { + const now = Date.now() + mockKv.get.mockResolvedValue({ count: 5, resetAt: now + 30000 }) + + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockNext).not.toHaveBeenCalled() + expect(ctx.json).toHaveBeenCalledWith( + { error: 'Too many requests. Please try again later.' }, + 429 + ) + }) + + it('should set Retry-After header', async () => { + const now = Date.now() + mockKv.get.mockResolvedValue({ count: 5, resetAt: now + 30000 }) + + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith('Retry-After', expect.any(String)) + expect(ctx.header).toHaveBeenCalledWith('X-RateLimit-Remaining', '0') + }) + + it('should still store the updated count in KV', async () => { + const now = Date.now() + mockKv.get.mockResolvedValue({ count: 5, resetAt: now + 30000 }) + + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockKv.put).toHaveBeenCalled() + const storedEntry = JSON.parse(mockKv.put.mock.calls[0][1]) + expect(storedEntry.count).toBe(6) + }) + }) + + describe('window expiration', () => { + it('should reset count when window has expired', async () => { + const pastTime = Date.now() - 10000 + mockKv.get.mockResolvedValue({ count: 100, resetAt: pastTime }) + + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + // Window expired, so count resets to 1 (new window) + expect(mockNext).toHaveBeenCalled() + expect(ctx.header).toHaveBeenCalledWith('X-RateLimit-Remaining', '4') + }) + }) + + describe('IP address extraction', () => { + it('should use cf-connecting-ip header', async () => { + const ctx = createMockContext({ ip: '10.0.0.1' }) + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockKv.put).toHaveBeenCalledWith( + 'ratelimit:login:10.0.0.1', + expect.any(String), + expect.any(Object) + ) + }) + + it('should fall back to x-forwarded-for', async () => { + const ctx = createMockContext() + ctx.req.header = vi.fn().mockImplementation((name: string) => { + if (name === 'x-forwarded-for') return '192.168.1.1' + return undefined + }) + + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + await middleware(ctx as Context, mockNext) + + expect(mockKv.put).toHaveBeenCalledWith( + 'ratelimit:login:192.168.1.1', + expect.any(String), + expect.any(Object) + ) + }) + + it('should fall back to "unknown" when no IP headers present', async () => { + const ctx = createMockContext() + ctx.req.header = vi.fn().mockReturnValue(undefined) + + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + await middleware(ctx as Context, mockNext) + + expect(mockKv.put).toHaveBeenCalledWith( + 'ratelimit:login:unknown', + expect.any(String), + expect.any(Object) + ) + }) + }) + + describe('key prefix isolation', () => { + it('should use different KV keys for different prefixes', async () => { + const ctx1 = createMockContext() + const ctx2 = createMockContext() + + const loginLimiter = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + const registerLimiter = rateLimit({ max: 3, windowMs: 60000, keyPrefix: 'register' }) + + await loginLimiter(ctx1 as Context, mockNext) + await registerLimiter(ctx2 as Context, mockNext) + + const keys = mockKv.put.mock.calls.map((c: any[]) => c[0]) + expect(keys).toContain('ratelimit:login:1.2.3.4') + expect(keys).toContain('ratelimit:register:1.2.3.4') + }) + }) + + describe('error handling', () => { + it('should gracefully continue on KV get error', async () => { + mockKv.get.mockRejectedValue(new Error('KV unavailable')) + + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockNext).toHaveBeenCalled() + expect(consoleSpy).toHaveBeenCalledWith( + 'Rate limiter error (non-fatal):', + expect.any(Error) + ) + consoleSpy.mockRestore() + }) + + it('should gracefully continue on KV put error', async () => { + mockKv.put.mockRejectedValue(new Error('KV write failed')) + + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + const ctx = createMockContext() + const middleware = rateLimit({ max: 5, windowMs: 60000, keyPrefix: 'login' }) + + await middleware(ctx as Context, mockNext) + + expect(mockNext).toHaveBeenCalled() + consoleSpy.mockRestore() + }) + }) +}) diff --git a/packages/core/src/__tests__/middleware/security-headers.test.ts b/packages/core/src/__tests__/middleware/security-headers.test.ts new file mode 100644 index 000000000..1f106d9d1 --- /dev/null +++ b/packages/core/src/__tests__/middleware/security-headers.test.ts @@ -0,0 +1,152 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { Context, Next } from 'hono' +import { securityHeadersMiddleware } from '../../middleware/security-headers' + +describe('securityHeadersMiddleware', () => { + let mockNext: Next + let headers: Record + + function createMockContext(env: Record = {}): any { + headers = {} + return { + env, + header: vi.fn().mockImplementation((key: string, value: string) => { + headers[key] = value + }), + } + } + + beforeEach(() => { + mockNext = vi.fn() + }) + + it('should call next before setting headers', async () => { + const ctx = createMockContext() + const middleware = securityHeadersMiddleware() + + let nextCalledBeforeHeaders = false + ;(mockNext as any).mockImplementation(() => { + // At this point, no headers should be set yet + nextCalledBeforeHeaders = Object.keys(headers).length === 0 + }) + + await middleware(ctx as Context, mockNext) + + expect(mockNext).toHaveBeenCalled() + expect(nextCalledBeforeHeaders).toBe(true) + }) + + describe('standard security headers', () => { + it('should set X-Content-Type-Options to nosniff', async () => { + const ctx = createMockContext() + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith('X-Content-Type-Options', 'nosniff') + }) + + it('should set X-Frame-Options to SAMEORIGIN', async () => { + const ctx = createMockContext() + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith('X-Frame-Options', 'SAMEORIGIN') + }) + + it('should set Referrer-Policy to strict-origin-when-cross-origin', async () => { + const ctx = createMockContext() + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith( + 'Referrer-Policy', + 'strict-origin-when-cross-origin' + ) + }) + + it('should set Permissions-Policy to restrict camera, microphone, and geolocation', async () => { + const ctx = createMockContext() + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith( + 'Permissions-Policy', + 'camera=(), microphone=(), geolocation=()' + ) + }) + }) + + describe('HSTS (Strict-Transport-Security)', () => { + it('should set HSTS when ENVIRONMENT is not set', async () => { + const ctx = createMockContext({}) + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith( + 'Strict-Transport-Security', + 'max-age=31536000; includeSubDomains' + ) + }) + + it('should set HSTS when ENVIRONMENT is "production"', async () => { + const ctx = createMockContext({ ENVIRONMENT: 'production' }) + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith( + 'Strict-Transport-Security', + 'max-age=31536000; includeSubDomains' + ) + }) + + it('should NOT set HSTS when ENVIRONMENT is "development"', async () => { + const ctx = createMockContext({ ENVIRONMENT: 'development' }) + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + const hstsCall = (ctx.header as any).mock.calls.find( + (call: any[]) => call[0] === 'Strict-Transport-Security' + ) + expect(hstsCall).toBeUndefined() + }) + + it('should set HSTS for staging environment', async () => { + const ctx = createMockContext({ ENVIRONMENT: 'staging' }) + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledWith( + 'Strict-Transport-Security', + 'max-age=31536000; includeSubDomains' + ) + }) + }) + + describe('all headers together', () => { + it('should set exactly 5 headers in production', async () => { + const ctx = createMockContext({ ENVIRONMENT: 'production' }) + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledTimes(5) + }) + + it('should set exactly 4 headers in development (no HSTS)', async () => { + const ctx = createMockContext({ ENVIRONMENT: 'development' }) + const middleware = securityHeadersMiddleware() + + await middleware(ctx as Context, mockNext) + + expect(ctx.header).toHaveBeenCalledTimes(4) + }) + }) +})