Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backend/api-gateway/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
<artifactId>spring-cloud-starter-config</artifactId>
</dependency>

<!-- Redis for rate limiting -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

<!-- JWT for admin route authorization -->
<dependency>
<groupId>io.jsonwebtoken</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.finpay.gateway.config;

import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;

/**
* Configurable rate limiting properties.
*
* rate-limiting:
* enabled: true
* default-rate: 100 # requests per window
* default-window-seconds: 60
* auth-rate: 20 # stricter for auth endpoints
* auth-window-seconds: 60
* admin-rate: 200 # higher for admin endpoints
* admin-window-seconds: 60
*/
@Configuration
@ConfigurationProperties(prefix = "rate-limiting")
@Getter
@Setter
public class RateLimitProperties {
private boolean enabled = true;

/** Default requests allowed per window for general API calls. */
private int defaultRate = 100;
private int defaultWindowSeconds = 60;

/** Stricter limit for auth endpoints (login/register). */
private int authRate = 20;
private int authWindowSeconds = 60;

/** Higher limit for admin endpoints. */
private int adminRate = 200;
private int adminWindowSeconds = 60;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package com.finpay.gateway.filter;

import com.finpay.gateway.config.RateLimitProperties;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.util.List;

/**
* Redis-backed sliding-window rate limiter.
*
* Runs before all other gateway filters. Uses a Lua script to
* atomically increment a per-client counter in Redis and check
* against the configured limit. Different rate tiers apply to
* auth, admin, and general API endpoints.
*
* The client key is derived from the X-User-Id header (set by
* AdminAuthFilter for authenticated users) or the client IP
* for anonymous requests.
*/
@Component
@Order(Ordered.HIGHEST_PRECEDENCE)
@Slf4j
public class RateLimitFilter implements Filter {

private final StringRedisTemplate redisTemplate;
private final RateLimitProperties properties;
private final DefaultRedisScript<Long> rateLimitScript;

public RateLimitFilter(StringRedisTemplate redisTemplate, RateLimitProperties properties) {
this.redisTemplate = redisTemplate;
this.properties = properties;

// Lua script: sliding-window counter using a sorted set
// Returns 1 if allowed, 0 if rate limit exceeded
String lua = """
local key = KEYS[1]
local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local window_start = now - window
redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
local count = redis.call('ZCARD', key)
if count < limit then
redis.call('ZADD', key, now, now .. '-' .. math.random(1000000))
redis.call('EXPIRE', key, window)
return 1
end
return 0
""";
this.rateLimitScript = new DefaultRedisScript<>(lua, Long.class);
}

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {

if (!properties.isEnabled()) {
chain.doFilter(request, response);
return;
}

HttpServletRequest httpRequest = (HttpServletRequest) request;
HttpServletResponse httpResponse = (HttpServletResponse) response;

String path = httpRequest.getRequestURI();
String clientKey = resolveClientKey(httpRequest);
RateTier tier = resolveTier(path);

String redisKey = "gateway:ratelimit:" + tier.name().toLowerCase() + ":" + clientKey;

try {
long nowMillis = System.currentTimeMillis();
Long allowed = redisTemplate.execute(
rateLimitScript,
List.of(redisKey),
String.valueOf(tier.windowSeconds()),
String.valueOf(tier.maxRequests()),
String.valueOf(nowMillis)
);

if (allowed != null && allowed == 0L) {
log.warn("Rate limit exceeded for client: {} on path: {} (tier: {})", clientKey, path, tier.name());
httpResponse.setStatus(429);
httpResponse.setContentType("application/json");
httpResponse.setHeader("Retry-After", String.valueOf(tier.windowSeconds()));
httpResponse.getWriter().write(
"{\"status\":429,\"error\":\"Too Many Requests\",\"message\":\"Rate limit exceeded. Try again later.\"}");
return;
}
} catch (Exception e) {
// If Redis is down, allow the request (fail-open) to avoid blocking all traffic
log.warn("Rate limiter unavailable, allowing request: {}", e.getMessage());
}

chain.doFilter(request, response);
}

private String resolveClientKey(HttpServletRequest request) {
// Prefer authenticated user ID
String userId = request.getHeader("X-User-Id");
if (userId != null && !userId.isBlank()) {
return "user:" + userId;
}
// Fall back to client IP
String forwarded = request.getHeader("X-Forwarded-For");
if (forwarded != null && !forwarded.isBlank()) {
return "ip:" + forwarded.split(",")[0].trim();
}
return "ip:" + request.getRemoteAddr();
}

private RateTier resolveTier(String path) {
if (path.startsWith("/api/v1/auth/")) {
return new RateTier("AUTH", properties.getAuthRate(), properties.getAuthWindowSeconds());
}
if (path.startsWith("/api/v1/admin/")) {
return new RateTier("ADMIN", properties.getAdminRate(), properties.getAdminWindowSeconds());
}
return new RateTier("DEFAULT", properties.getDefaultRate(), properties.getDefaultWindowSeconds());
}

private record RateTier(String name, int maxRequests, int windowSeconds) {}
}
15 changes: 15 additions & 0 deletions backend/api-gateway/src/main/resources/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ spring:
enabled: true
lower-case-service-id: true

data:
redis:
host: ${REDIS_HOST:localhost}
port: ${REDIS_PORT:6379}

# Routes are defined programmatically in GatewayRoutesConfig.java
# CORS is handled by CorsConfig.java

Expand Down Expand Up @@ -89,3 +94,13 @@ springdoc:
- name: Notification Service
url: /api-docs/notification-service
urls-primary-name: Auth Service

# Redis-backed rate limiting
rate-limiting:
enabled: true
default-rate: 100
default-window-seconds: 60
auth-rate: 20
auth-window-seconds: 60
admin-rate: 200
admin-window-seconds: 60
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package com.finpay.gateway.filter;

import com.finpay.gateway.config.RateLimitProperties;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentMatchers;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
@DisplayName("RateLimitFilter Unit Tests")
class RateLimitFilterTest {

@Mock private StringRedisTemplate redisTemplate;
@Spy private RateLimitProperties properties;

@InjectMocks
private RateLimitFilter rateLimitFilter;

@BeforeEach
void setUp() {
properties.setEnabled(true);
properties.setDefaultRate(100);
properties.setDefaultWindowSeconds(60);
properties.setAuthRate(20);
properties.setAuthWindowSeconds(60);
}

@Nested
@DisplayName("Rate Limiting")
class RateLimitingTests {

@Test
@DisplayName("should allow request when under rate limit")
void shouldAllowRequestUnderLimit() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/v1/users/me");
request.setRemoteAddr("192.168.1.1");
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();

when(redisTemplate.execute(ArgumentMatchers.<RedisScript<Long>>any(), anyList(), any(), any(), any()))
.thenReturn(1L);

rateLimitFilter.doFilter(request, response, filterChain);

assertThat(response.getStatus()).isNotEqualTo(429);
}

@Test
@DisplayName("should block request when rate limit exceeded")
void shouldBlockWhenRateLimitExceeded() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/v1/users/me");
request.setRemoteAddr("192.168.1.1");
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();

when(redisTemplate.execute(ArgumentMatchers.<RedisScript<Long>>any(), anyList(), any(), any(), any()))
.thenReturn(0L);

rateLimitFilter.doFilter(request, response, filterChain);

assertThat(response.getStatus()).isEqualTo(429);
assertThat(response.getContentType()).isEqualTo("application/json");
assertThat(response.getHeader("Retry-After")).isEqualTo("60");
}

@Test
@DisplayName("should allow request when rate limiting is disabled")
void shouldAllowWhenDisabled() throws Exception {
properties.setEnabled(false);

MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/v1/users/me");
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();

rateLimitFilter.doFilter(request, response, filterChain);

assertThat(response.getStatus()).isNotEqualTo(429);
verifyNoInteractions(redisTemplate);
}

@Test
@DisplayName("should fail open when Redis is unavailable")
void shouldFailOpenWhenRedisDown() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/v1/users/me");
request.setRemoteAddr("192.168.1.1");
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();

when(redisTemplate.execute(ArgumentMatchers.<RedisScript<Long>>any(), anyList(), any(), any(), any()))
.thenThrow(new RuntimeException("Redis connection refused"));

rateLimitFilter.doFilter(request, response, filterChain);

assertThat(response.getStatus()).isNotEqualTo(429);
}

@Test
@DisplayName("should use user ID as key when X-User-Id header is present")
void shouldUseUserIdForAuthenticatedRequests() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/v1/wallets/me");
request.addHeader("X-User-Id", "user-123");
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();

when(redisTemplate.execute(ArgumentMatchers.<RedisScript<Long>>any(), anyList(), any(), any(), any()))
.thenReturn(1L);

rateLimitFilter.doFilter(request, response, filterChain);

verify(redisTemplate).execute(
ArgumentMatchers.<RedisScript<Long>>any(),
eq(List.of("gateway:ratelimit:default:user:user-123")),
any(), any(), any()
);
}

@Test
@DisplayName("should apply auth tier rate limit for auth endpoints")
void shouldApplyAuthTierForAuthEndpoints() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/api/v1/auth/login");
request.setRemoteAddr("10.0.0.1");
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();

when(redisTemplate.execute(ArgumentMatchers.<RedisScript<Long>>any(), anyList(), any(), any(), any()))
.thenReturn(1L);

rateLimitFilter.doFilter(request, response, filterChain);

verify(redisTemplate).execute(
ArgumentMatchers.<RedisScript<Long>>any(),
eq(List.of("gateway:ratelimit:auth:ip:10.0.0.1")),
any(), any(), any()
);
}
}
}
6 changes: 6 additions & 0 deletions backend/auth-service/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@
<artifactId>spring-kafka</artifactId>
</dependency>

<!-- Redis for token blocklist and user session caching -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ public ResponseEntity<Map<String, String>> logout(
HttpServletRequest request,
HttpServletResponse response) {
String refreshToken = extractRefreshTokenFromCookie(request);
String accessToken = extractAccessTokenFromCookie(request);
if (accessToken == null) {
accessToken = extractTokenFromHeader(request);
}
if (refreshToken != null) {
authService.logout(refreshToken);
authService.logout(refreshToken, accessToken);
}
cookieService.clearAuthCookies(response);
return ResponseEntity.ok(Map.of("message", "Successfully logged out"));
Expand Down
Loading
Loading