From 3ae2cb047bb72599024145f3c39e9f11e1e2d70c Mon Sep 17 00:00:00 2001 From: Chibey-max Date: Fri, 24 Apr 2026 00:36:19 +0100 Subject: [PATCH] feat(#259): add API rate limiting - Implement token-bucket rate limiter using governor crate - Add per-IP rate limit (100 req/min) with fallback when no API key - Add per-API-key rate limit (1000 req/min) via x-api-key header - Return 429 Too Many Requests with JSON error body on limit exceeded - Use DashMap for thread-safe concurrent limiter storage --- api-server/Cargo.toml | 5 +++ api-server/src/main.rs | 12 ++++-- api-server/src/rate_limit.rs | 76 ++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 api-server/src/rate_limit.rs diff --git a/api-server/Cargo.toml b/api-server/Cargo.toml index c008484..13611cf 100644 --- a/api-server/Cargo.toml +++ b/api-server/Cargo.toml @@ -14,3 +14,8 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" utoipa = { version = "4", features = ["axum_extras"] } utoipa-swagger-ui = { version = "7", features = ["axum"] } +governor = "0.6" +dashmap = "6" +thiserror = "1" +tower = "0.4" +once_cell = "1" diff --git a/api-server/src/main.rs b/api-server/src/main.rs index b6b0ada..1c61ef0 100644 --- a/api-server/src/main.rs +++ b/api-server/src/main.rs @@ -1,8 +1,11 @@ -use axum::{routing::get, routing::post, Router}; +use axum::{extract::ConnectInfo, routing::get, routing::post, Router}; +use axum::middleware; +use std::net::SocketAddr; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; mod handlers; +mod rate_limit; mod schemas; #[derive(OpenApi)] @@ -62,10 +65,13 @@ async fn main() { .route("/swap/{swap_id}/reveal", post(handlers::reveal_key)) .route("/swap/{swap_id}/cancel", post(handlers::cancel_swap)) .route("/swap/{swap_id}/cancel-expired", post(handlers::cancel_expired_swap)) - .route("/swap/{swap_id}", get(handlers::get_swap)); + .route("/swap/{swap_id}", get(handlers::get_swap)) + .layer(middleware::from_fn(rate_limit::layer)); let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap(); println!("Swagger UI -> http://localhost:8080/docs"); println!("OpenAPI JSON -> http://localhost:8080/openapi.json"); - axum::serve(listener, app).await.unwrap(); + axum::serve(listener, app.into_make_service_with_connect_info::()) + .await + .unwrap(); } diff --git a/api-server/src/rate_limit.rs b/api-server/src/rate_limit.rs new file mode 100644 index 0000000..7bfbef6 --- /dev/null +++ b/api-server/src/rate_limit.rs @@ -0,0 +1,76 @@ +use std::sync::Arc; +use std::net::SocketAddr; +use axum::{ + extract::{ConnectInfo, Request}, + http::StatusCode, + middleware::Next, + response::{IntoResponse, Response}, +}; +use dashmap::DashMap; +use governor::{ + clock::DefaultClock, + state::{InMemoryState, NotKeyed}, + Quota, RateLimiter, +}; +use once_cell::sync::Lazy; + +/// Per-IP rate limit: 100 requests per minute. +static IP_LIMITERS: Lazy>>> = + Lazy::new(DashMap::new); + +/// Per-API-key rate limit: 1000 requests per minute. +static KEY_LIMITERS: Lazy>>> = + Lazy::new(DashMap::new); + +/// Custom rate-limit-exceeded response. +#[derive(Debug)] +pub struct RateLimitExceeded; + +impl IntoResponse for RateLimitExceeded { + fn into_response(self) -> Response { + let body = serde_json::json!({ "error": "Rate limit exceeded" }); + (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response() + } +} + +/// Axum middleware that enforces token-bucket rate limits. +/// Checks per-API-key first (if `x-api-key` header present), then falls back to per-IP. +pub async fn layer( + ConnectInfo(addr): ConnectInfo, + req: Request, + next: Next, +) -> Result { + // Prefer API-key based limiting + if let Some(api_key) = req.headers().get("x-api-key").and_then(|v| v.to_str().ok()) { + let limiter = KEY_LIMITERS + .entry(api_key.to_string()) + .or_insert_with(|| { + Arc::new(RateLimiter::direct(Quota::per_minute( + std::num::NonZeroU32::new(1000).unwrap(), + ))) + }) + .clone(); + + if limiter.check().is_err() { + return Err(RateLimitExceeded); + } + } else { + // Fall back to IP-based limiting + let ip = addr.ip().to_string(); + let limiter = IP_LIMITERS + .entry(ip) + .or_insert_with(|| { + Arc::new(RateLimiter::direct(Quota::per_minute( + std::num::NonZeroU32::new(100).unwrap(), + ))) + }) + .clone(); + + if limiter.check().is_err() { + return Err(RateLimitExceeded); + } + } + + Ok(next.run(req).await) +} +