Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions everyrow-mcp/src/everyrow_mcp/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ async def verify_token(self, token: str) -> AccessToken | None:
payload = self._decode_jwt(token, signing_key)

if await self._is_revoked(token):
logger.debug("Token is revoked")
logger.warning("Revoked token presented")
return None

sub = payload.get("sub")
if not sub:
logger.debug("JWT missing required 'sub' claim")
logger.warning("JWT missing required 'sub' claim")
return None
return AccessToken(
token=token,
Expand All @@ -116,7 +116,7 @@ async def verify_token(self, token: str) -> AccessToken | None:
logger.warning("JWKS fetch timed out (10s)")
return None
except pyjwt.PyJWTError:
logger.debug("JWT verification failed")
logger.warning("JWT verification failed")
return None


Expand Down Expand Up @@ -212,7 +212,7 @@ async def _check_rate_limit(self, action: str, client_ip: str) -> None:
rl_key = build_key("ratelimit", action, client_ip)
async with self._redis.pipeline() as pipe:
pipe.incr(rl_key)
pipe.expire(rl_key, settings.registration_rate_window, nx=True)
pipe.expire(rl_key, settings.registration_rate_window)
count, _ = await pipe.execute()
if count > settings.registration_rate_limit:
raise ValueError(f"{action.title()} rate limit exceeded")
Expand All @@ -232,6 +232,7 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None
time=settings.client_registration_ttl,
value=client_info.model_dump_json(),
)
logger.info("Registered new OAuth client client_id=%s", client_info.client_id)

@staticmethod
def _supabase_redirect_url(supabase_verifier: str) -> str:
Expand Down Expand Up @@ -361,11 +362,16 @@ async def authorize(
return f"{settings.mcp_server_url}/auth/start/{state}"

async def handle_start(self, request: Request) -> RedirectResponse:
state = request.path_params["state"]
pending = await self._validate_auth_request(
request, "start", request.path_params.get("state")
request, "start", state, consume=True
)
# Re-store so the callback can still find it
await self._redis.setex(
name=build_key("pending", state),
time=settings.pending_auth_ttl,
value=pending.model_dump_json(),
)

state = request.path_params.get("state", "")
response = RedirectResponse(url=pending.supabase_redirect_url, status_code=302)
response.set_cookie(
key="__Host-mcp_auth_state",
Expand Down Expand Up @@ -434,9 +440,9 @@ async def load_authorization_code(
if code_obj.expires_at and code_obj.expires_at < time.time():
return None
if code_obj.client_id != client.client_id:
# Re-store so the legitimate client can still use it.
# Re-store so the legitimate client can still use it (NX prevents overwrite).
remaining = max(1, int((code_obj.expires_at or 0) - time.time()))
await self._redis.setex(key, remaining, code_data_encrypted)
await self._redis.set(key, code_data_encrypted, ex=remaining, nx=True)
return None
return code_obj

Expand Down Expand Up @@ -476,6 +482,7 @@ async def exchange_authorization_code(
authorization_code: EveryRowAuthorizationCode,
) -> OAuthToken:
assert client.client_id is not None
logger.info("Token exchange successful user=%s", authorization_code.client_id)
return await self._issue_token_response(
access_token=authorization_code.supabase_access_token,
client_id=client.client_id,
Expand All @@ -499,8 +506,10 @@ async def load_refresh_token(
return None
rt = EveryRowRefreshToken.model_validate_json(decrypt_value(data_encrypted))
if rt.client_id != client.client_id:
# Re-store so the legitimate client can still use it.
await self._redis.setex(key, settings.refresh_token_ttl, data_encrypted)
# Re-store so the legitimate client can still use it (NX prevents overwrite).
await self._redis.set(
key, data_encrypted, ex=settings.refresh_token_ttl, nx=True
)
return None
return rt

Expand All @@ -524,6 +533,7 @@ async def exchange_refresh_token(
)
raise
assert client.client_id is not None
logger.info("Token refresh successful user=%s", client.client_id)
return await self._issue_token_response(
access_token=supa_tokens.access_token,
client_id=client.client_id,
Expand Down
45 changes: 42 additions & 3 deletions everyrow-mcp/src/everyrow_mcp/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import logging
from functools import lru_cache
from urllib.parse import urlparse

from pydantic import Field, PositiveInt, field_validator
from pydantic import Field, PositiveInt, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

logger = logging.getLogger(__name__)


class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
Expand Down Expand Up @@ -108,6 +112,15 @@ class Settings(BaseSettings):
description="Maximum response size when fetching CSV from a URL (50 MB).",
)

upload_rate_limit: PositiveInt = Field(
default=20,
description="Max uploads per user per rate window",
)
upload_rate_window: PositiveInt = Field(
default=3600,
description="Upload rate limit sliding window in seconds (1 hour)",
)

everyrow_api_key: str | None = Field(default=None, repr=False)

@property
Expand All @@ -120,8 +133,34 @@ def is_stdio(self) -> bool:

@field_validator("mcp_server_url", "supabase_url")
@classmethod
def _strip_url_slashes(cls, v: str) -> str:
return v.rstrip("/")
def _validate_url(cls, v: str) -> str:
v = v.rstrip("/")
if not v:
return v
parsed = urlparse(v)
host = (parsed.hostname or "").lower()
is_local = host in ("localhost", "127.0.0.1", "::1")
if not is_local and parsed.scheme != "https":
raise ValueError(
f"Non-localhost URLs must use https:// (got {parsed.scheme}://)"
)
return v

@model_validator(mode="after")
def _require_redis_ssl_for_remote(self) -> Settings:
host = (self.redis_host or "").lower()
is_local = host in ("localhost", "127.0.0.1", "::1", "")
if not is_local and not self.redis_ssl:
if self.is_http:
raise ValueError(
f"Redis host {self.redis_host} is remote but redis_ssl=False. "
"Enable redis_ssl for non-localhost Redis in HTTP mode."
)
logger.warning(
"Redis host %s is remote but redis_ssl=False — traffic is unencrypted.",
self.redis_host,
)
return self


@lru_cache
Expand Down
42 changes: 28 additions & 14 deletions everyrow-mcp/src/everyrow_mcp/http_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from __future__ import annotations

import logging
import time as _time
from urllib.parse import urlparse

from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.server import lifespan_wrapper
Expand Down Expand Up @@ -70,16 +72,20 @@ def configure_http_mode(
mcp.settings.host = host
mcp.settings.port = port

if not no_auth and not settings.upload_secret:
if not settings.upload_secret:
raise RuntimeError(
"UPLOAD_SECRET must be set in HTTP mode for HMAC signing. "
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
if settings.is_http and not no_auth and not settings.redis_password:
logger.warning(
if not no_auth and not settings.redis_password:
raise RuntimeError(
"REDIS_PASSWORD is not set — Redis is unauthenticated. "
"Set REDIS_PASSWORD for production deployments."
)
if no_auth and not settings.redis_password:
logger.warning(
"REDIS_PASSWORD is not set — acceptable for local development only."
)

_register_widgets(mcp, mcp_server_url)
_register_routes(mcp, redis_client, auth_provider if not no_auth else None)
Expand Down Expand Up @@ -164,23 +170,31 @@ def _ui_csp(connect_domains: list[str]) -> dict[str, str | list[str]]:


class _RequestLoggingMiddleware(BaseHTTPMiddleware):
"""Log every inbound request and its response status at DEBUG level."""
"""Log inbound requests at INFO level with method, path, status, and timing."""

async def dispatch(self, request, call_next):
has_auth = "authorization" in request.headers
logger.debug(
"INCOMING %s %s | Host: %s | Auth: %s",
request.method,
request.url.path,
request.headers.get("host", "?"),
"present" if has_auth else "none",
)
# Skip health check requests — k8s probes hit these every ~10s.
if request.url.path == "/health":
return await call_next(request)

start = _time.monotonic()
response = await call_next(request)
logger.debug(
"RESPONSE %s %s -> %s",
elapsed_ms = (_time.monotonic() - start) * 1000

# Extract user_id from the access token if available.
try:
access_token = get_access_token()
user_id = access_token.client_id if access_token else None
except Exception:
user_id = None

logger.info(
"HTTP %s %s -> %d (%.0fms) user=%s",
request.method,
request.url.path,
response.status_code,
elapsed_ms,
user_id or "anon",
)
return response

Expand Down
78 changes: 14 additions & 64 deletions everyrow-mcp/src/everyrow_mcp/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from __future__ import annotations

import logging
import threading
import time

from redis.asyncio import Redis
from redis.exceptions import RedisError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
Expand Down Expand Up @@ -46,11 +44,9 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
"""Redis-based fixed-window rate limiter per client IP.

Returns 429 with ``Retry-After`` header when the limit is exceeded.
Falls back to an in-memory counter when Redis is unavailable.
Redis is a hard dependency in HTTP mode — failures propagate.
"""

_MEM_CLEANUP_INTERVAL = 100 # clean up stale entries every N requests

def __init__(
self,
app,
Expand All @@ -63,42 +59,6 @@ def __init__(
self._redis = redis
self._max_requests = max_requests
self._window_seconds = window_seconds
# In-memory fallback: {key: (count, window_start)}
self._mem_counters: dict[str, tuple[int, float]] = {}
self._mem_lock = threading.Lock()
self._mem_request_count = 0

def _check_in_memory(self, ip: str) -> bool:
"""In-memory fixed-window rate check. Returns True if the request should be blocked."""
now = time.time()
window_start = (int(now) // self._window_seconds) * self._window_seconds
key = f"{ip}:{window_start}"

with self._mem_lock:
self._mem_request_count += 1
if self._mem_request_count % self._MEM_CLEANUP_INTERVAL == 0:
self._cleanup_mem_counters(now)

count, ws = self._mem_counters.get(key, (0, window_start))
count += 1
self._mem_counters[key] = (count, ws)
return count > self._max_requests

_MAX_MEM_ENTRIES = 50_000 # hard cap to prevent unbounded memory growth

def _cleanup_mem_counters(self, now: float) -> None:
"""Evict stale entries. Must be called under _mem_lock."""
cutoff = now - self._window_seconds * 2
stale = [k for k, (_, ws) in self._mem_counters.items() if ws < cutoff]
for k in stale:
del self._mem_counters[k]
# Hard cap: if still too many entries, evict oldest
if len(self._mem_counters) > self._MAX_MEM_ENTRIES:
sorted_keys = sorted(
self._mem_counters, key=lambda k: self._mem_counters[k][1]
)
for k in sorted_keys[: len(self._mem_counters) - self._MAX_MEM_ENTRIES]:
del self._mem_counters[k]

async def dispatch(self, request: Request, call_next) -> Response:
if request.url.path == "/health":
Expand All @@ -113,30 +73,20 @@ async def dispatch(self, request: Request, call_next) -> Response:
window_id = str(int(time.time()) // self._window_seconds)
key = build_key("rate", client_ip, window_id)

try:
async with self._redis.pipeline() as pipe:
pipe.incr(key)
pipe.expire(key, self._window_seconds, nx=True)
count, _ = await pipe.execute()

if count > self._max_requests:
ttl = await self._redis.ttl(key)
retry_after = max(ttl, 1)
return JSONResponse(
{"detail": "Rate limit exceeded"},
status_code=429,
headers={"Retry-After": str(retry_after)},
)
except (RedisError, OSError):
logger.warning(
"Rate-limit check failed (Redis unavailable), using in-memory fallback"
async with self._redis.pipeline() as pipe:
pipe.incr(key)
pipe.expire(key, self._window_seconds)
count, _ = await pipe.execute()

if count > self._max_requests:
logger.warning("Rate limit exceeded for IP %s", client_ip)
ttl = await self._redis.ttl(key)
retry_after = max(ttl, 1)
return JSONResponse(
{"detail": "Rate limit exceeded"},
status_code=429,
headers={"Retry-After": str(retry_after)},
)
if self._check_in_memory(client_ip):
return JSONResponse(
{"detail": "Rate limit exceeded"},
status_code=429,
headers={"Retry-After": str(self._window_seconds)},
)

return await call_next(request)

Expand Down
13 changes: 13 additions & 0 deletions everyrow-mcp/src/everyrow_mcp/redis_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def encrypt_value(value: str) -> str:
"""Encrypt a string value for Redis storage. No-op without UPLOAD_SECRET."""
f = _get_fernet()
if f is None:
if settings.is_http:
raise RuntimeError(
"UPLOAD_SECRET must be set in HTTP mode — cannot store sensitive values in plaintext."
)
return value
return f.encrypt(value.encode()).decode()

Expand All @@ -79,6 +83,10 @@ def decrypt_value(value: str) -> str:
"""Decrypt a string value from Redis. No-op without UPLOAD_SECRET."""
f = _get_fernet()
if f is None:
if settings.is_http:
raise RuntimeError(
"UPLOAD_SECRET must be set in HTTP mode — cannot read encrypted values without the key."
)
return value
return f.decrypt(value.encode()).decode()

Expand Down Expand Up @@ -278,6 +286,11 @@ async def store_upload_meta(upload_id: str, meta_json: str, ttl: int) -> None:
await get_redis_client().setex(build_key("upload", upload_id), ttl, meta_json)


async def get_upload_meta(upload_id: str) -> str | None:
"""Read upload metadata without consuming it."""
return await get_redis_client().get(build_key("upload", upload_id))


async def pop_upload_meta(upload_id: str) -> str | None:
"""Atomically get and delete upload metadata (prevents replay)."""
key = build_key("upload", upload_id)
Expand Down
Loading