Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion cashu/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,30 @@ class MintBackends(MintSettings):

class MintLimits(MintSettings):
mint_rate_limit: bool = Field(
default=False, title="Rate limit", description="IP-based rate limiter."
default=True,
title="Rate limit",
description="IP-based rate limiter.",
)
mint_rate_limit_proxy_trust: bool = Field(
default=True,
title="Trust proxy headers for rate limiting",
description=(
"Extract client IP from proxy headers (X-Forwarded-For,"
" CF-Connecting-IP) for rate limiting. Enable this if the mint"
" is behind a reverse proxy (Caddy, nginx) or CDN (Cloudflare)."
" Disable if the mint is directly exposed to the internet to"
" prevent clients from spoofing their IP via headers."
),
)
mint_forwarded_allow_ips: str = Field(
default="127.0.0.1",
title="Forwarded-allow IPs",
description=(
"Comma-separated list of proxy IPs to trust for X-Forwarded-For"
" headers at the uvicorn level, or '*' to trust all."
" Only relevant when mint_rate_limit_proxy_trust is enabled."
" Set to '*' if your proxy's IP is dynamic or unknown."
),
)
mint_global_rate_limit_per_minute: int = Field(
default=60,
Expand Down
33 changes: 30 additions & 3 deletions cashu/mint/limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,38 @@


def _rate_limit_exceeded_handler(request: Request, exc: Exception) -> JSONResponse:
remote_address = get_remote_address(request)
remote_address = _get_client_ip(request)
logger.warning(
f"Rate limit {settings.mint_global_rate_limit_per_minute}/minute exceeded: {remote_address}"
f"Rate limit {settings.mint_global_rate_limit_per_minute}/minute exceeded:"
f" {remote_address}"
)
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Rate limit exceeded."},
)


def _get_client_ip(request: Request) -> str:
"""Extract the client IP from the request, checking proxy headers first
if configured to trust them.

Header priority (when proxy trust is enabled):
1. CF-Connecting-IP – set by Cloudflare
2. X-Forwarded-For – set by most reverse proxies (first entry)
3. request.client – direct connection IP (fallback)
"""
if settings.mint_rate_limit_proxy_trust:
cf_ip = request.headers.get("cf-connecting-ip")
if cf_ip:
return cf_ip.strip()
xff = request.headers.get("x-forwarded-for")
if xff:
return xff.split(",")[0].strip()
return get_remote_address(request)


def get_remote_address_excluding_local(request: Request) -> str:
remote_address = get_remote_address(request)
remote_address = _get_client_ip(request)
if remote_address == "127.0.0.1":
return ""
return remote_address
Expand Down Expand Up @@ -76,6 +96,13 @@ def get_ws_remote_address(ws: WebSocket) -> str:
Returns:
str: The ip address for the current websocket.
"""
if settings.mint_rate_limit_proxy_trust:
cf_ip = ws.headers.get("cf-connecting-ip")
if cf_ip:
return cf_ip.strip()
xff = ws.headers.get("x-forwarded-for")
if xff:
return xff.split(",")[0].strip()
if not ws.client or not ws.client.host:
return "127.0.0.1"

Expand Down
2 changes: 2 additions & 0 deletions cashu/mint/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def main(
host=host,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
proxy_headers=settings.mint_rate_limit_proxy_trust,
forwarded_allow_ips=settings.mint_forwarded_allow_ips,
**d, # type: ignore
)

Expand Down
211 changes: 211 additions & 0 deletions tests/mint/test_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from unittest.mock import patch

import pytest
from starlette.requests import Request
from starlette.websockets import WebSocket

from cashu.core.settings import settings
from cashu.mint.limit import (
_get_client_ip,
_rate_limit_exceeded_handler,
assert_limit,
get_remote_address_excluding_local,
get_ws_remote_address,
limit_websocket,
)


def test_get_client_ip_proxy_trust_enabled():
settings.mint_rate_limit_proxy_trust = True

# Test CF-Connecting-IP
scope = {
"type": "http",
"headers": [(b"cf-connecting-ip", b"203.0.113.1")],
"client": ("127.0.0.1", 8000),
}
request = Request(scope)
assert _get_client_ip(request) == "203.0.113.1"

# Test X-Forwarded-For
scope = {
"type": "http",
"headers": [(b"x-forwarded-for", b"203.0.113.2, 198.51.100.1")],
"client": ("127.0.0.1", 8000),
}
request = Request(scope)
assert _get_client_ip(request) == "203.0.113.2"

# Test fallback to client IP
scope = {
"type": "http",
"headers": [],
"client": ("203.0.113.3", 8000),
}
request = Request(scope)
assert _get_client_ip(request) == "203.0.113.3"


def test_get_client_ip_proxy_trust_disabled():
settings.mint_rate_limit_proxy_trust = False

# Test headers are ignored
scope = {
"type": "http",
"headers": [
(b"cf-connecting-ip", b"203.0.113.1"),
(b"x-forwarded-for", b"203.0.113.2"),
],
"client": ("203.0.113.3", 8000),
}
request = Request(scope)
assert _get_client_ip(request) == "203.0.113.3"


def test_get_ws_remote_address_proxy_trust_enabled():
settings.mint_rate_limit_proxy_trust = True

# Test CF-Connecting-IP
scope = {
"type": "websocket",
"headers": [(b"cf-connecting-ip", b"203.0.113.1")],
"client": ("127.0.0.1", 8000),
}

async def dummy_receive():
pass

async def dummy_send(msg):
pass

ws = WebSocket(scope, dummy_receive, dummy_send)
assert get_ws_remote_address(ws) == "203.0.113.1"

# Test X-Forwarded-For
scope = {
"type": "websocket",
"headers": [(b"x-forwarded-for", b"203.0.113.2, 198.51.100.1")],
"client": ("127.0.0.1", 8000),
}
ws = WebSocket(scope, dummy_receive, dummy_send)
assert get_ws_remote_address(ws) == "203.0.113.2"

# Test fallback to client IP
scope = {
"type": "websocket",
"headers": [],
"client": ("203.0.113.3", 8000),
}
ws = WebSocket(scope, dummy_receive, dummy_send)
assert get_ws_remote_address(ws) == "203.0.113.3"


def test_get_ws_remote_address_proxy_trust_disabled():
settings.mint_rate_limit_proxy_trust = False

# Test headers are ignored
scope = {
"type": "websocket",
"headers": [
(b"cf-connecting-ip", b"203.0.113.1"),
(b"x-forwarded-for", b"203.0.113.2"),
],
"client": ("203.0.113.3", 8000),
}

async def dummy_receive():
pass

async def dummy_send(msg):
pass

ws = WebSocket(scope, dummy_receive, dummy_send)
assert get_ws_remote_address(ws) == "203.0.113.3"

# Test no client host
scope = {
"type": "websocket",
"headers": [],
"client": None,
}
ws = WebSocket(scope, dummy_receive, dummy_send)
assert get_ws_remote_address(ws) == "127.0.0.1"


def test_rate_limit_exceeded_handler():
settings.mint_rate_limit_proxy_trust = True
scope = {
"type": "http",
"headers": [(b"cf-connecting-ip", b"203.0.113.1")],
"client": ("127.0.0.1", 8000),
}
request = Request(scope)
response = _rate_limit_exceeded_handler(request, Exception("Test"))
assert response.status_code == 429
assert response.body == b'{"detail":"Rate limit exceeded."}'


def test_get_remote_address_excluding_local():
settings.mint_rate_limit_proxy_trust = True
# Test remote
scope = {
"type": "http",
"headers": [(b"cf-connecting-ip", b"203.0.113.1")],
"client": ("127.0.0.1", 8000),
}
request = Request(scope)
assert get_remote_address_excluding_local(request) == "203.0.113.1"

# Test local
scope = {
"type": "http",
"headers": [],
"client": ("127.0.0.1", 8000),
}
request = Request(scope)
assert get_remote_address_excluding_local(request) == ""


def test_limit_websocket():
settings.mint_rate_limit_proxy_trust = True

async def dummy_receive():
pass

async def dummy_send(msg):
pass

# Local shouldn't limit
scope_local = {
"type": "websocket",
"headers": [],
"client": ("127.0.0.1", 8000),
}
ws_local = WebSocket(scope_local, dummy_receive, dummy_send)

# This shouldn't raise exception
limit_websocket(ws_local)

# Remote should limit
scope_remote = {
"type": "websocket",
"headers": [],
"client": ("203.0.113.1", 8000),
}
ws_remote = WebSocket(scope_remote, dummy_receive, dummy_send)

with patch("cashu.mint.limit.assert_limit") as mock_assert:
limit_websocket(ws_remote)
mock_assert.assert_called_once_with("203.0.113.1")


def test_assert_limit():
# It uses a global slowapi Limiter
with patch("cashu.mint.limit.limiter._limiter.hit") as mock_hit:
mock_hit.return_value = False
with pytest.raises(Exception, match="Rate limit exceeded"):
assert_limit("1.2.3.4", limit=10)

mock_hit.return_value = True
# Shouldn't raise
assert_limit("1.2.3.4", limit=10)
10 changes: 5 additions & 5 deletions tests/mint/test_mint_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def hit(self, item, identifier):


def test_get_ws_remote_address_defaults_to_localhost():
ws_without_client = SimpleNamespace(client=None)
ws_without_host = SimpleNamespace(client=SimpleNamespace(host=None))
ws_remote = SimpleNamespace(client=SimpleNamespace(host="198.51.100.5"))
ws_without_client = SimpleNamespace(client=None, headers={})
ws_without_host = SimpleNamespace(client=SimpleNamespace(host=None), headers={})
ws_remote = SimpleNamespace(client=SimpleNamespace(host="198.51.100.5"), headers={})

assert limit.get_ws_remote_address(ws_without_client) == "127.0.0.1"
assert limit.get_ws_remote_address(ws_without_host) == "127.0.0.1"
Expand All @@ -71,8 +71,8 @@ def test_limit_websocket_skips_localhost_and_limits_remote(monkeypatch):
limit, "assert_limit", lambda identifier: called.append(identifier)
)

local_ws = SimpleNamespace(client=SimpleNamespace(host="127.0.0.1"))
remote_ws = SimpleNamespace(client=SimpleNamespace(host="203.0.113.7"))
local_ws = SimpleNamespace(client=SimpleNamespace(host="127.0.0.1"), headers={})
remote_ws = SimpleNamespace(client=SimpleNamespace(host="203.0.113.7"), headers={})

limit.limit_websocket(local_ws)
limit.limit_websocket(remote_ws)
Expand Down
Loading