diff --git a/everyrow-mcp/deploy/.dockerignore b/everyrow-mcp/deploy/.dockerignore new file mode 100644 index 00000000..9996afc3 --- /dev/null +++ b/everyrow-mcp/deploy/.dockerignore @@ -0,0 +1,15 @@ +.git +.github +.venv +__pycache__ +*.pyc +.env +*.env +!.env.example +.claude/ +.vscode/ +*.egg-info +dist/ +build/ +node_modules/ +docs-site/ diff --git a/everyrow-mcp/deploy/.env.example b/everyrow-mcp/deploy/.env.example index a13cc95b..9462b58e 100644 --- a/everyrow-mcp/deploy/.env.example +++ b/everyrow-mcp/deploy/.env.example @@ -1,4 +1,5 @@ EVERYROW_API_KEY=sk-cho-your-api-key-here SUPABASE_URL=https://your-project.supabase.co SUPABASE_ANON_KEY=sb_publishable_your-anon-key-here -REDIS_ENCRYPTION_KEY=generate-with-python-cryptography-fernet +MCP_SERVER_URL=https://your-tunnel-url.example.com +REDIS_PASSWORD=change-me-to-a-strong-random-password diff --git a/everyrow-mcp/deploy/Dockerfile b/everyrow-mcp/deploy/Dockerfile index 29945939..5f7b855a 100644 --- a/everyrow-mcp/deploy/Dockerfile +++ b/everyrow-mcp/deploy/Dockerfile @@ -16,10 +16,16 @@ RUN uv sync --package everyrow-mcp --no-dev --no-sources --no-editable # Stage 2: Slim runtime FROM python:3.13-slim +RUN groupadd -r mcp && useradd -r -g mcp -d /app -s /sbin/nologin mcp + ENV PATH="/app/.venv/bin:$PATH" EXPOSE 8000 -CMD ["everyrow-mcp", "--http", "--port", "8000", "--host", "0.0.0.0"] - WORKDIR /app COPY --link --from=build /app/.venv .venv +RUN chown -R mcp:mcp /app + +USER mcp + +STOPSIGNAL SIGTERM +CMD ["everyrow-mcp", "--http", "--port", "8000", "--host", "0.0.0.0"] diff --git a/everyrow-mcp/deploy/docker-compose.yaml b/everyrow-mcp/deploy/docker-compose.yaml index afd88b6b..b2e6cea0 100644 --- a/everyrow-mcp/deploy/docker-compose.yaml +++ b/everyrow-mcp/deploy/docker-compose.yaml @@ -1,20 +1,25 @@ services: redis: image: redis:7-alpine - ports: - - "6379:6379" + command: redis-server --requirepass "${REDIS_PASSWORD:?Set REDIS_PASSWORD}" + # No ports: — only reachable by other services on the Docker network. healthcheck: - test: ["CMD", "redis-cli", "ping"] + test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"] interval: 5s timeout: 3s retries: 3 + deploy: + resources: + limits: + memory: 256M + restart: unless-stopped mcp-server: build: context: ../.. dockerfile: everyrow-mcp/deploy/Dockerfile ports: - - "8000:8000" + - "127.0.0.1:8000:8000" depends_on: redis: condition: service_healthy @@ -25,4 +30,11 @@ services: REDIS_HOST: redis REDIS_PORT: "6379" REDIS_DB: "13" + REDIS_PASSWORD: "${REDIS_PASSWORD}" EVERYROW_API_URL: "https://everyrow.io/api/v0" + TRUST_PROXY_HEADERS: "true" # Behind Cloudflare tunnel + deploy: + resources: + limits: + memory: 512M + restart: unless-stopped diff --git a/everyrow-mcp/src/everyrow_mcp/app.py b/everyrow-mcp/src/everyrow_mcp/app.py index 2b84549f..919ad788 100644 --- a/everyrow-mcp/src/everyrow_mcp/app.py +++ b/everyrow-mcp/src/everyrow_mcp/app.py @@ -17,6 +17,8 @@ def _clear_task_state() -> None: + if settings.is_http: + return if TASK_STATE_FILE.exists(): TASK_STATE_FILE.unlink() diff --git a/everyrow-mcp/src/everyrow_mcp/auth.py b/everyrow-mcp/src/everyrow_mcp/auth.py index b0d92d5b..cbff8728 100644 --- a/everyrow-mcp/src/everyrow_mcp/auth.py +++ b/everyrow-mcp/src/everyrow_mcp/auth.py @@ -11,6 +11,7 @@ import httpx import jwt as pyjwt +import pydantic from jwt import PyJWKClient from mcp.server.auth.provider import ( AccessToken, @@ -78,7 +79,7 @@ def _decode_jwt(self, token: str, signing_key) -> dict[str, Any]: return pyjwt.decode( token, signing_key.key, - algorithms=["RS256", "ES256"], + algorithms=["RS256"], issuer=self._issuer, audience=self._audience, options={"require": ["exp", "sub", "iss", "aud"]}, @@ -107,7 +108,7 @@ async def verify_token(self, token: str) -> AccessToken | None: logger.warning("JWKS fetch timed out (10s)") return None except pyjwt.PyJWTError: - logger.debug("JWT verification failed", exc_info=True) + logger.debug("JWT verification failed") return None @@ -191,7 +192,9 @@ def _UNSAFE_decode_server_jwt(token: str) -> dict[str, Any]: endpoint over HTTPS and was never exposed to the client. NEVER use this for tokens received from end users. """ - return pyjwt.decode(token, options={"verify_signature": False}) + return pyjwt.decode( + token, options={"verify_signature": False}, algorithms=["RS256"] + ) @staticmethod def _client_ip(request: Request) -> str: @@ -199,10 +202,10 @@ def _client_ip(request: Request) -> str: async def _check_rate_limit(self, action: str, client_ip: str) -> None: rl_key = build_key("ratelimit", action, client_ip) - pipe = self._redis.pipeline() - pipe.incr(rl_key) - pipe.expire(rl_key, settings.registration_rate_window) - count, _ = await pipe.execute() + async with self._redis.pipeline() as pipe: + pipe.incr(rl_key) + pipe.expire(rl_key, settings.registration_rate_window, nx=True) + count, _ = await pipe.execute() if count > settings.registration_rate_limit: raise ValueError(f"{action.title()} rate limit exceeded") @@ -246,9 +249,10 @@ def _supabase_redirect_url(supabase_verifier: str) -> str: def _validate_redirect_url( client: OAuthClientInformationFull, params: AuthorizationParams ) -> None: - if client.redirect_uris: - if str(params.redirect_uri) not in [str(u) for u in client.redirect_uris]: - raise ValueError("redirect_uri does not match any registered URI") + if not client.redirect_uris: + raise ValueError("Client must register at least one redirect_uri") + if str(params.redirect_uri) not in [str(u) for u in client.redirect_uris]: + raise ValueError("redirect_uri does not match any registered URI") async def _validate_auth_request( self, request: Request, action: str, state: str | None, *, consume: bool = False @@ -272,15 +276,14 @@ async def _validate_auth_request( async def _validate_client(self, pending: PendingAuth) -> None: client_info = await self.get_client(pending.client_id) - if client_info is None or ( - pending.params.redirect_uri - and client_info.redirect_uris - and str(pending.params.redirect_uri) + if client_info is None: + raise HTTPException(status_code=400, detail="Invalid client") + if pending.params.redirect_uri and ( + not client_info.redirect_uris + or str(pending.params.redirect_uri) not in [str(u) for u in client_info.redirect_uris] ): - raise HTTPException( - status_code=400, detail="Invalid client or redirect_uri" - ) + raise HTTPException(status_code=400, detail="Invalid redirect_uri") async def _validate_supabase_code( self, code: str, supabase_code_verifier: str @@ -289,10 +292,10 @@ async def _validate_supabase_code( return await self._exchange_supabase_code( code=code, code_verifier=supabase_code_verifier ) - except Exception: - logger.exception("Failed to exchange Supabase code") + except Exception as exc: + logger.error("Failed to exchange Supabase code: %s", type(exc).__name__) raise HTTPException( - status_code=500, detail="Failed to authenticate with Supabase" + status_code=500, detail="Authentication failed. Please try again." ) async def _validate_callback_request( @@ -359,7 +362,7 @@ async def handle_start(self, request: Request) -> RedirectResponse: value=request.path_params.get("state"), max_age=settings.pending_auth_ttl, httponly=True, - samesite="lax", + samesite="strict", secure=True, path="/auth/callback", ) @@ -398,7 +401,7 @@ async def handle_callback(self, request: Request) -> RedirectResponse: "mcp_auth_state", path="/auth/callback", httponly=True, - samesite="lax", + samesite="strict", secure=True, ) return response @@ -411,11 +414,18 @@ async def load_authorization_code( if len(authorization_code) > 256: return None - code_data = await self._redis.getdel(build_key("authcode", authorization_code)) + key = build_key("authcode", authorization_code) + # GETDEL atomically consumes the code — no race between concurrent requests. + code_data = await self._redis.getdel(key) if code_data is None: return None code_obj = EveryRowAuthorizationCode.model_validate_json(code_data) + if code_obj.expires_at and code_obj.expires_at < time.time(): + return None if code_obj.client_id != client.client_id: + # Re-store so the legitimate client can still use it. + remaining = max(1, int((code_obj.expires_at or 0) - time.time())) + await self._redis.setex(key, remaining, code_data) return None return code_obj @@ -470,12 +480,16 @@ async def load_refresh_token( if len(refresh_token) > 256: return None - data = await self._redis.getdel(build_key("refresh", refresh_token)) + key = build_key("refresh", refresh_token) + # GET first, verify client_id, then DELETE — same pattern as + # load_authorization_code to avoid cross-client DoS. + data = await self._redis.get(key) if data is None: return None rt = EveryRowRefreshToken.model_validate_json(data) if rt.client_id != client.client_id: return None + await self._redis.delete(key) return rt async def exchange_refresh_token( @@ -485,9 +499,18 @@ async def exchange_refresh_token( scopes: list[str], ) -> OAuthToken: final_scopes = self._validate_scopes(scopes, refresh_token) - supa_tokens = await self._refresh_supabase_token( - refresh_token.supabase_refresh_token - ) + try: + supa_tokens = await self._refresh_supabase_token( + refresh_token.supabase_refresh_token + ) + except Exception: + # Re-store the consumed refresh token so the user isn't locked out. + await self._redis.setex( + name=build_key("refresh", refresh_token.token), + time=settings.refresh_token_ttl, + value=refresh_token.model_dump_json(), + ) + raise return await self._issue_token_response( access_token=supa_tokens.access_token, client_id=client.client_id, @@ -500,9 +523,11 @@ async def revoke_token(self, token: AccessToken | EveryRowRefreshToken) -> None: await self._redis.delete(build_key("refresh", token.token)) elif isinstance(token, AccessToken): fp = SupabaseTokenVerifier._token_fingerprint(token.token) + remaining = max(0, (token.expires_at or 0) - int(time.time())) + 60 + ttl = remaining if remaining > 60 else self._token_verifier._revocation_ttl await self._redis.setex( name=build_key("revoked", fp), - time=self._token_verifier._revocation_ttl, + time=ttl, value="1", ) @@ -519,10 +544,14 @@ async def _supabase_token_request( ) resp.raise_for_status() data = resp.json() - return SupabaseTokenResponse( - access_token=data["access_token"], - refresh_token=data["refresh_token"], - ) + try: + return SupabaseTokenResponse.model_validate(data) + except pydantic.ValidationError: + logger.error( + "Supabase token response missing required fields: %s", + sorted(data.keys()), + ) + raise ValueError("Invalid token response from identity provider") async def _exchange_supabase_code( self, code: str, code_verifier: str diff --git a/everyrow-mcp/src/everyrow_mcp/config.py b/everyrow-mcp/src/everyrow_mcp/config.py index 7046f65f..f01cd1de 100644 --- a/everyrow-mcp/src/everyrow_mcp/config.py +++ b/everyrow-mcp/src/everyrow_mcp/config.py @@ -27,6 +27,12 @@ class Settings(BaseSettings): ) redis_sentinel_master_name: str | None = Field(default=None) + trust_proxy_headers: bool = Field( + default=False, + description="Trust X-Forwarded-For and CF-Connecting-IP headers for client IP. " + "Enable only when behind a trusted reverse proxy (e.g. Cloudflare).", + ) + # HTTP-only settings — unused in stdio mode mcp_server_url: str = Field(default="") supabase_url: str = Field(default="") diff --git a/everyrow-mcp/src/everyrow_mcp/http_config.py b/everyrow-mcp/src/everyrow_mcp/http_config.py index caff69e7..478b8009 100644 --- a/everyrow-mcp/src/everyrow_mcp/http_config.py +++ b/everyrow-mcp/src/everyrow_mcp/http_config.py @@ -35,6 +35,19 @@ def configure_http_mode( mcp_server_url: str, ) -> None: """Configure the MCP server for HTTP transport.""" + if not no_auth: + missing = [] + if not settings.supabase_url: + missing.append("SUPABASE_URL") + if not settings.supabase_anon_key: + missing.append("SUPABASE_ANON_KEY") + if not settings.mcp_server_url: + missing.append("MCP_SERVER_URL") + if missing: + raise RuntimeError( + f"HTTP auth mode requires these environment variables: {', '.join(missing)}" + ) + redis_client = get_redis_client() if no_auth: lifespan = no_auth_http_lifespan @@ -52,7 +65,7 @@ def configure_http_mode( mcp.settings.port = port _register_widgets(mcp, mcp_server_url) - _register_routes(mcp, auth_provider if not no_auth else None) + _register_routes(mcp, redis_client, auth_provider if not no_auth else None) _add_middleware(mcp, redis_client, rate_limit=not no_auth) @@ -79,6 +92,7 @@ def _results_ui_http() -> str: def _register_routes( mcp: FastMCP, + redis: Redis, auth_provider: EveryRowAuthProvider | None, ) -> None: """Register REST endpoints for widget polling, CSV download, health, and auth.""" @@ -88,6 +102,12 @@ def _register_routes( ) async def _health(_request: Request) -> Response: + try: + await redis.ping() + except Exception: + return JSONResponse( + {"status": "unhealthy", "redis": "unreachable"}, status_code=503 + ) return JSONResponse({"status": "ok"}) mcp.custom_route("/health", ["GET"])(_health) diff --git a/everyrow-mcp/src/everyrow_mcp/middleware.py b/everyrow-mcp/src/everyrow_mcp/middleware.py index a9469310..dfe2ecc5 100644 --- a/everyrow-mcp/src/everyrow_mcp/middleware.py +++ b/everyrow-mcp/src/everyrow_mcp/middleware.py @@ -6,27 +6,31 @@ import time from redis.asyncio import Redis +from redis.exceptions import RedisError from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response +from everyrow_mcp.config import settings from everyrow_mcp.redis_store import build_key logger = logging.getLogger(__name__) def get_client_ip(request: Request) -> str | None: - """Extract client IP, preferring proxy headers when behind a reverse proxy. + """Extract client IP, preferring proxy headers only when trusted. - Priority: CF-Connecting-IP (Cloudflare) > X-Forwarded-For > request.client. - Returns None if the IP cannot be determined. + Only reads CF-Connecting-IP / X-Forwarded-For when + ``settings.trust_proxy_headers`` is True (i.e. running behind a known + reverse proxy like Cloudflare). Otherwise uses the direct connection IP. """ - cf_ip = request.headers.get("cf-connecting-ip") - if cf_ip: - return cf_ip.strip() - forwarded = request.headers.get("x-forwarded-for") - if forwarded: - return forwarded.split(",")[0].strip() + if settings.trust_proxy_headers: + cf_ip = request.headers.get("cf-connecting-ip") + if cf_ip: + return cf_ip.strip() + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() return request.client.host if request.client else None @@ -52,10 +56,15 @@ def __init__( self._window_seconds = window_seconds async def dispatch(self, request: Request, call_next) -> Response: + if request.url.path == "/health": + return await call_next(request) + client_ip = get_client_ip(request) if client_ip is None: - logger.warning("Could not determine client IP, skipping rate limit") - return await call_next(request) + logger.warning( + "Could not determine client IP, using shared fallback bucket" + ) + client_ip = "__unknown__" window_id = str(int(time.time()) // self._window_seconds) key = build_key("rate", client_ip, window_id) @@ -73,7 +82,7 @@ async def dispatch(self, request: Request, call_next) -> Response: status_code=429, headers={"Retry-After": str(retry_after)}, ) - except Exception: - logger.warning("Rate-limit check failed (Redis unavailable)", exc_info=True) + except (RedisError, OSError): + logger.warning("Rate-limit check failed (Redis unavailable)") return await call_next(request) diff --git a/everyrow-mcp/src/everyrow_mcp/models.py b/everyrow-mcp/src/everyrow_mcp/models.py index ba30ad81..32a780ea 100644 --- a/everyrow-mcp/src/everyrow_mcp/models.py +++ b/everyrow-mcp/src/everyrow_mcp/models.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any, Literal +from uuid import UUID from jsonschema import SchemaError from jsonschema.validators import validator_for @@ -25,6 +26,10 @@ "object": dict, } +MAX_INLINE_ROWS = 50_000 +MAX_INLINE_DATA_BYTES = 10 * 1024 * 1024 # 10 MB +MAX_SCHEMA_PROPERTIES = 50 + def _validate_response_schema(schema: dict[str, Any] | None) -> dict[str, Any] | None: """Validate response_schema is a JSON Schema object schema.""" @@ -51,6 +56,12 @@ def _validate_response_schema(schema: dict[str, Any] | None) -> dict[str, Any] | "response_schema must include a non-empty top-level 'properties' object" ) + if len(properties) > MAX_SCHEMA_PROPERTIES: + raise ValueError( + f"response_schema has {len(properties)} properties " + f"(max {MAX_SCHEMA_PROPERTIES})" + ) + for field_name, field_def in properties.items(): if not isinstance(field_def, dict): raise ValueError( @@ -124,6 +135,8 @@ def _check_exactly_one( class _SingleSourceInput(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") + input_csv: str | None = Field( default=None, description="Absolute path to CSV file (local/stdio mode only).", @@ -140,6 +153,21 @@ def validate_input_csv(cls, v: str | None) -> str | None: validate_csv_path(v) return v + @field_validator("data") + @classmethod + def validate_data_size( + cls, v: str | list[dict[str, Any]] | None + ) -> str | list[dict[str, Any]] | None: + if v is None: + return v + if isinstance(v, str) and len(v) > MAX_INLINE_DATA_BYTES: + raise ValueError( + f"Inline data exceeds {MAX_INLINE_DATA_BYTES // (1024 * 1024)} MB limit" + ) + if isinstance(v, list) and len(v) > MAX_INLINE_ROWS: + raise ValueError(f"Inline data has {len(v)} rows (max {MAX_INLINE_ROWS})") + return v + @model_validator(mode="after") def check_input_source(self): _check_exactly_one( @@ -348,6 +376,15 @@ def validate_response_schema( return _validate_response_schema(v) +def _validate_task_id(v: str) -> str: + """Validate task_id is a valid UUID.""" + try: + UUID(v) + except ValueError as exc: + raise ValueError("task_id must be a valid UUID") from exc + return v + + class ProgressInput(BaseModel): """Input for checking task progress.""" @@ -355,6 +392,11 @@ class ProgressInput(BaseModel): task_id: str = Field(..., description="The task ID returned by the operation tool.") + @field_validator("task_id") + @classmethod + def validate_task_id(cls, v: str) -> str: + return _validate_task_id(v) + class CancelInput(BaseModel): """Input for cancelling a running task.""" @@ -363,6 +405,11 @@ class CancelInput(BaseModel): task_id: str = Field(..., description="The task ID to cancel.") + @field_validator("task_id") + @classmethod + def validate_task_id(cls, v: str) -> str: + return _validate_task_id(v) + def _validate_output_path(v: str | None) -> str | None: """Validate output_path ends in .csv and parent directory exists.""" @@ -389,6 +436,11 @@ class StdioResultsInput(BaseModel): description="Full absolute path to the output CSV file (must end in .csv).", ) + @field_validator("task_id") + @classmethod + def validate_task_id(cls, v: str) -> str: + return _validate_task_id(v) + @field_validator("output_path") @classmethod def validate_output(cls, v: str) -> str: @@ -403,6 +455,12 @@ class HttpResultsInput(BaseModel): model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") task_id: str = Field(..., description="The task ID of the completed task.") + + @field_validator("task_id") + @classmethod + def validate_task_id(cls, v: str) -> str: + return _validate_task_id(v) + output_path: str | None = Field( default=None, description="Full absolute path to the output CSV file (must end in .csv). " diff --git a/everyrow-mcp/src/everyrow_mcp/redis_store.py b/everyrow-mcp/src/everyrow_mcp/redis_store.py index 18057080..aa7111bf 100644 --- a/everyrow-mcp/src/everyrow_mcp/redis_store.py +++ b/everyrow-mcp/src/everyrow_mcp/redis_store.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging +import re from enum import StrEnum -from functools import lru_cache from pathlib import Path from redis.asyncio import Redis, Sentinel @@ -33,9 +33,12 @@ class Transport(StrEnum): # ── Redis infrastructure ────────────────────────────────────── +_KEY_UNSAFE = re.compile(r"[^a-zA-Z0-9._\-]") + + def build_key(*parts: str) -> str: - """Build a namespaced Redis key, sanitising embedded colons.""" - sanitized = [p.replace(":", "_") for p in parts] + """Build a namespaced Redis key, sanitising user-controlled characters.""" + sanitized = [_KEY_UNSAFE.sub("_", p) for p in parts] return "mcp:" + ":".join(sanitized) @@ -90,16 +93,27 @@ def create_redis_client( return client -@lru_cache +_redis_client: Redis | None = None + + def get_redis_client() -> Redis: - return create_redis_client( - host=settings.redis_host, - port=settings.redis_port, - db=settings.redis_db, - password=settings.redis_password, - sentinel_endpoints=settings.redis_sentinel_endpoints, - sentinel_master_name=settings.redis_sentinel_master_name, - ) + global _redis_client # noqa: PLW0603 + if _redis_client is None: + _redis_client = create_redis_client( + host=settings.redis_host, + port=settings.redis_port, + db=settings.redis_db, + password=settings.redis_password, + sentinel_endpoints=settings.redis_sentinel_endpoints, + sentinel_master_name=settings.redis_sentinel_master_name, + ) + return _redis_client + + +def set_redis_client(client: Redis | None) -> None: + """Override the Redis client (for testing).""" + global _redis_client # noqa: PLW0603 + _redis_client = client async def get_result_meta(task_id: str) -> str | None: @@ -133,7 +147,18 @@ async def store_result_page( # ── CSV result storage ──────────────────────────────────────── +MAX_CSV_CACHE_BYTES = 50 * 1024 * 1024 # 50 MB — skip Redis cache for oversized results + + async def store_result_csv(task_id: str, csv_text: str) -> None: + if len(csv_text) > MAX_CSV_CACHE_BYTES: + logger.warning( + "Skipping Redis cache for task %s: CSV is %d bytes (limit %d)", + task_id, + len(csv_text), + MAX_CSV_CACHE_BYTES, + ) + return await get_redis_client().setex( name=build_key("result", task_id, "csv"), time=CSV_CACHE_TTL, value=csv_text ) diff --git a/everyrow-mcp/src/everyrow_mcp/result_store.py b/everyrow-mcp/src/everyrow_mcp/result_store.py index b68c8b03..4dced821 100644 --- a/everyrow-mcp/src/everyrow_mcp/result_store.py +++ b/everyrow-mcp/src/everyrow_mcp/result_store.py @@ -61,12 +61,19 @@ def clamp_page_to_budget( if estimated <= settings.token_budget: return preview_records, page_size + # Pre-compute per-row token sizes and build a prefix sum so the binary + # search doesn't need to re-serialize on every iteration. + # Overhead per-row is ~2 tokens for the JSON array wrapper/commas. + row_sizes = [_estimate_tokens(json.dumps(r)) + 2 for r in preview_records] + prefix = [0] * (len(row_sizes) + 1) + for i, s in enumerate(row_sizes): + prefix[i + 1] = prefix[i] + s + lo, hi = 1, len(preview_records) best = 1 while lo <= hi: mid = (lo + hi) // 2 - candidate = preview_records[:mid] - if _estimate_tokens(json.dumps(candidate)) <= settings.token_budget: + if prefix[mid] <= settings.token_budget: best = mid lo = mid + 1 else: diff --git a/everyrow-mcp/src/everyrow_mcp/routes.py b/everyrow-mcp/src/everyrow_mcp/routes.py index 83650bd9..16a8d64b 100644 --- a/everyrow-mcp/src/everyrow_mcp/routes.py +++ b/everyrow-mcp/src/everyrow_mcp/routes.py @@ -18,7 +18,25 @@ logger = logging.getLogger(__name__) -_CORS = {"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET"} + +def _cors_headers() -> dict[str, str]: + origin = settings.mcp_server_url or "http://localhost:8000" + return { + "Access-Control-Allow-Origin": origin, + "Access-Control-Allow-Methods": "GET", + "Access-Control-Allow-Headers": "Authorization", + } + + +def _validate_uuid(task_id: str) -> JSONResponse | None: + """Return a 400 response if task_id is not a valid UUID, else None.""" + try: + UUID(task_id) + except ValueError: + return JSONResponse( + {"error": "Invalid task ID"}, status_code=400, headers=_cors_headers() + ) + return None async def _validate_poll_token(task_id: str, request: Request) -> JSONResponse | None: @@ -26,24 +44,30 @@ async def _validate_poll_token(task_id: str, request: Request) -> JSONResponse | expected = await redis_store.get_poll_token(task_id) provided = request.query_params.get("token", "") if not expected or not secrets.compare_digest(provided, expected): - return JSONResponse({"error": "Unauthorized"}, status_code=403, headers=_CORS) + return JSONResponse( + {"error": "Unauthorized"}, status_code=403, headers=_cors_headers() + ) return None async def api_progress(request: Request) -> Response: """REST endpoint for the session widget to poll task progress.""" + cors = _cors_headers() if request.method == "OPTIONS": - return Response(status_code=204, headers=_CORS) + return Response(status_code=204, headers=cors) task_id = request.path_params["task_id"] + if err := _validate_uuid(task_id): + return err + if err := await _validate_poll_token(task_id, request): return err api_key = await redis_store.get_task_token(task_id) if not api_key: - return JSONResponse({"error": "Unknown task"}, status_code=404, headers=_CORS) + return JSONResponse({"error": "Unknown task"}, status_code=404, headers=cors) try: client = AuthenticatedClient( @@ -65,36 +89,42 @@ async def api_progress(request: Request) -> Response: await redis_store.pop_task_token(task_id) return JSONResponse( - ts.model_dump(mode="json", exclude=_UI_EXCLUDE), headers=_CORS + ts.model_dump(mode="json", exclude=_UI_EXCLUDE), headers=cors ) except Exception: logger.exception("Progress poll failed for task %s", task_id) return JSONResponse( - {"error": "Internal server error"}, status_code=500, headers=_CORS + {"error": "Internal server error"}, status_code=500, headers=cors ) async def api_download(request: Request) -> Response: """REST endpoint to download task results as CSV.""" + cors = _cors_headers() if request.method == "OPTIONS": - return Response(status_code=204, headers=_CORS) + return Response(status_code=204, headers=cors) task_id = request.path_params["task_id"] + if err := _validate_uuid(task_id): + return err + if err := await _validate_poll_token(task_id, request): return err csv_text = await redis_store.get_result_csv(task_id) if csv_text is None: return JSONResponse( - {"error": "Results not found or expired"}, status_code=404, headers=_CORS + {"error": "Results not found or expired"}, status_code=404, headers=cors ) + safe_prefix = "".join(c for c in task_id[:8] if c.isalnum() or c == "-") return Response( content=csv_text, media_type="text/csv", headers={ - **_CORS, - "Content-Disposition": f'attachment; filename="results_{task_id[:8]}.csv"', + **cors, + "Content-Disposition": f'attachment; filename="results_{safe_prefix}.csv"', + "Referrer-Policy": "no-referrer", }, ) diff --git a/everyrow-mcp/src/everyrow_mcp/templates.py b/everyrow-mcp/src/everyrow_mcp/templates.py index 80817d08..eec85255 100644 --- a/everyrow-mcp/src/everyrow_mcp/templates.py +++ b/everyrow-mcp/src/everyrow_mcp/templates.py @@ -776,6 +776,7 @@ const app=new App({name:"EveryRow Session",version:"1.0.0"}); const el=document.getElementById("c"); let pollUrl=null,pollTimer=null,sessionUrl="",wasDone=false; +function esc(s){const d=document.createElement("div");d.textContent=String(s);return d.innerHTML;} app.ontoolresult=({content})=>{ const t=content?.find(c=>c.type==="text");if(!t)return; @@ -805,7 +806,7 @@ h+=`