Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/deploy-mcp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ jobs:
run: uv sync --directory everyrow-mcp
- name: Run ruff
run: uv run --directory everyrow-mcp ruff check
- name: Run basedpyright
run: uv run --directory everyrow-mcp basedpyright --project .
- name: Run tests
run: uv run --directory everyrow-mcp pytest tests

Expand Down
2 changes: 1 addition & 1 deletion everyrow-mcp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dev = [
]

[tool.basedpyright]
venvPath = "."
venvPath = ".."
venv = ".venv"
include = ["src", "tests"]
typeCheckingMode = "standard"
Expand Down
4 changes: 2 additions & 2 deletions everyrow-mcp/src/everyrow_mcp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def http_lifespan(_server: FastMCP):
across sessions. Process exit handles cleanup.
"""
redis_client = get_redis_client()
await redis_client.ping()
await redis_client.ping() # pyright: ignore[reportGeneralTypeIssues]

def _http_client_factory() -> AuthenticatedClient:
access_token = get_access_token()
Expand All @@ -73,7 +73,7 @@ def _http_client_factory() -> AuthenticatedClient:
async def no_auth_http_lifespan(_server: FastMCP):
"""HTTP no-auth mode: singleton client from API key, verify Redis."""
redis_client = get_redis_client()
await redis_client.ping()
await redis_client.ping() # pyright: ignore[reportGeneralTypeIssues]

with _create_sdk_client() as client:
response = await get_billing(client=client)
Expand Down
6 changes: 5 additions & 1 deletion everyrow-mcp/src/everyrow_mcp/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ async def authorize(
state = secrets.token_urlsafe(32)
supabase_verifier = secrets.token_urlsafe(32)

assert client.client_id is not None
pending = PendingAuth(
client_id=client.client_id,
params=params,
Expand All @@ -364,10 +365,11 @@ async def handle_start(self, request: Request) -> RedirectResponse:
request, "start", request.path_params.get("state")
)

state = request.path_params.get("state", "")
response = RedirectResponse(url=pending.supabase_redirect_url, status_code=302)
response.set_cookie(
key="mcp_auth_state",
value=request.path_params.get("state"),
value=state,
max_age=settings.pending_auth_ttl,
httponly=True,
samesite="lax",
Expand Down Expand Up @@ -472,6 +474,7 @@ async def exchange_authorization_code(
client: OAuthClientInformationFull,
authorization_code: EveryRowAuthorizationCode,
) -> OAuthToken:
assert client.client_id is not None
return await self._issue_token_response(
access_token=authorization_code.supabase_access_token,
client_id=client.client_id,
Expand Down Expand Up @@ -519,6 +522,7 @@ async def exchange_refresh_token(
value=refresh_token.model_dump_json(),
)
raise
assert client.client_id is not None
return await self._issue_token_response(
access_token=supa_tokens.access_token,
client_id=client.client_id,
Expand Down
5 changes: 3 additions & 2 deletions everyrow-mcp/src/everyrow_mcp/http_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def configure_http_mode(
)

redis_client = get_redis_client()
auth_provider: EveryRowAuthProvider | None = None
if no_auth:
lifespan = no_auth_http_lifespan
else:
Expand All @@ -65,7 +66,7 @@ def configure_http_mode(
mcp.settings.port = port

_register_widgets(mcp, mcp_server_url)
_register_routes(mcp, redis_client, auth_provider if not no_auth else None)
_register_routes(mcp, redis_client, auth_provider)
_add_middleware(mcp, redis_client, rate_limit=not no_auth)


Expand Down Expand Up @@ -103,7 +104,7 @@ def _register_routes(

async def _health(_request: Request) -> Response:
try:
await redis.ping()
await redis.ping() # pyright: ignore[reportGeneralTypeIssues]
except Exception:
return JSONResponse(
{"status": "unhealthy", "redis": "unreachable"}, status_code=503
Expand Down
11 changes: 6 additions & 5 deletions everyrow-mcp/src/everyrow_mcp/result_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import logging
import math
from typing import Any

import pandas as pd
from mcp.types import TextContent
Expand All @@ -26,7 +27,7 @@
logger = logging.getLogger(__name__)


def _sanitize_records(records: list[dict]) -> list[dict]:
def _sanitize_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Replace NaN/Inf float values with None for valid JSON serialization.

pandas ``to_dict(orient="records")`` preserves ``float('nan')`` which
Expand Down Expand Up @@ -54,9 +55,9 @@ def _estimate_tokens(text: str) -> int:


def clamp_page_to_budget(
preview_records: list[dict],
preview_records: list[dict[str, Any]],
page_size: int,
) -> tuple[list[dict], int]:
) -> tuple[list[dict[str, Any]], int]:
estimated = _estimate_tokens(json.dumps(preview_records))
if estimated <= settings.token_budget:
return preview_records, page_size
Expand Down Expand Up @@ -86,7 +87,7 @@ def clamp_page_to_budget(
def _build_result_response(
task_id: str,
csv_url: str,
preview_records: list[dict],
preview_records: list[dict[str, Any]],
total: int,
columns: list[str],
offset: int,
Expand Down Expand Up @@ -242,7 +243,7 @@ async def try_store_result(
columns = list(df.columns)

# Store base metadata
meta = {"total": total, "columns": columns}
meta: dict[str, Any] = {"total": total, "columns": columns}
if session_url:
meta["session_url"] = session_url
await redis_store.store_result_meta(task_id, json.dumps(meta))
Expand Down
4 changes: 1 addition & 3 deletions everyrow-mcp/src/everyrow_mcp/tool_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ async def create_tool_response(
return [main_content]


_UI_EXCLUDE = frozenset(
{"is_terminal", "is_screen", "task_type", "error", "started_at"}
)
_UI_EXCLUDE: set[str] = {"is_terminal", "is_screen", "task_type", "error", "started_at"}


class TaskState(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion everyrow-mcp/src/everyrow_mcp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ async def everyrow_results_http(
# ── Fetch from API ────────────────────────────────────────────
try:
df, session_id = await _fetch_task_result(client, task_id)
session_url = get_session_url(session_id) if session_id else ""
session_url = get_session_url(UUID(session_id)) if session_id else ""
except TaskNotReady as e:
return [
TextContent(
Expand Down
5 changes: 3 additions & 2 deletions everyrow-mcp/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Set env vars for HttpSettings before any everyrow imports
import os
from collections.abc import AsyncGenerator

os.environ.setdefault("EVERYROW_API_KEY", "test-api-key")
os.environ.setdefault("SUPABASE_URL", "https://test.supabase.co")
Expand Down Expand Up @@ -64,7 +65,7 @@ def _redis_server():


@pytest.fixture
async def fake_redis(_redis_server) -> aioredis.Redis:
async def fake_redis(_redis_server) -> AsyncGenerator[aioredis.Redis, None]:
"""A real Redis client, flushed after each test."""
r = aioredis.Redis(host="localhost", port=_REDIS_PORT, decode_responses=True)
await r.flushdb()
Expand Down Expand Up @@ -102,7 +103,7 @@ def override_settings(**overrides):


@pytest.fixture
async def everyrow_client():
async def everyrow_client() -> AsyncGenerator[object, None]:
"""Provide a real everyrow SDK client for integration tests."""
with create_client() as client:
yield client
Expand Down
21 changes: 12 additions & 9 deletions everyrow-mcp/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from mcp.server.auth.provider import AccessToken, AuthorizationParams
from mcp.shared.auth import OAuthClientInformationFull
from pydantic import AnyUrl

from everyrow_mcp.auth import (
EveryRowAuthorizationCode,
Expand Down Expand Up @@ -53,6 +54,7 @@ def mock_redis():
async def _setex(*args, name=None, time=None, value=None): # noqa: ARG001
key = name if name is not None else args[0]
val = value if value is not None else args[2] if len(args) > 2 else None
assert val is not None
store[key] = val

async def _exists(key):
Expand Down Expand Up @@ -89,7 +91,7 @@ def verifier(rsa_keypair, mock_redis):

def _make_jwt(
private_key,
claims: dict | None = None,
claims: dict[str, str | int] | None = None,
*,
remove_claims: list[str] | None = None,
) -> str:
Expand Down Expand Up @@ -362,6 +364,7 @@ async def _set(key, value):
async def _setex(*args, name=None, time=None, value=None): # noqa: ARG001
key = name if name is not None else args[0]
val = value if value is not None else args[2] if len(args) > 2 else None
assert val is not None
store[key] = val

async def _get(key):
Expand Down Expand Up @@ -414,7 +417,7 @@ def test_client():
"""A minimal OAuthClientInformationFull for tests."""
return OAuthClientInformationFull(
client_id="test-client-id",
redirect_uris=["https://example.com/callback"],
redirect_uris=[AnyUrl("https://example.com/callback")],
)


Expand All @@ -427,7 +430,7 @@ async def test_auth_code_consumed_atomically(self, provider, test_client):
auth_code_obj = EveryRowAuthorizationCode(
code=auth_code_str,
client_id="test-client-id",
redirect_uri="https://example.com/callback",
redirect_uri=AnyUrl("https://example.com/callback"),
redirect_uri_provided_explicitly=True,
code_challenge="test-challenge",
scopes=["read"],
Expand Down Expand Up @@ -563,7 +566,7 @@ async def test_redirect_uri_mismatch_rejected(self, provider, test_client):
params = AuthorizationParams(
state="s1",
scopes=["read"],
redirect_uri="https://evil.example.com/callback",
redirect_uri=AnyUrl("https://evil.example.com/callback"),
code_challenge="challenge",
redirect_uri_provided_explicitly=True,
)
Expand All @@ -576,7 +579,7 @@ async def test_matching_redirect_uri_accepted(self, provider, test_client):
params = AuthorizationParams(
state="s1",
scopes=["read"],
redirect_uri="https://example.com/callback",
redirect_uri=AnyUrl("https://example.com/callback"),
code_challenge="challenge",
redirect_uri_provided_explicitly=True,
)
Expand Down Expand Up @@ -653,7 +656,7 @@ async def test_auth_code_client_id_mismatch(self, provider):
auth_code_obj = EveryRowAuthorizationCode(
code=auth_code_str,
client_id="other-client-id",
redirect_uri="https://example.com/callback",
redirect_uri=AnyUrl("https://example.com/callback"),
redirect_uri_provided_explicitly=True,
code_challenge="test-challenge",
scopes=["read"],
Expand All @@ -669,7 +672,7 @@ async def test_auth_code_client_id_mismatch(self, provider):

wrong_client = OAuthClientInformationFull(
client_id="wrong-client-id",
redirect_uris=["https://example.com/callback"],
redirect_uris=[AnyUrl("https://example.com/callback")],
)
result = await provider.load_authorization_code(wrong_client, auth_code_str)
assert result is None
Expand All @@ -692,7 +695,7 @@ async def test_refresh_token_client_id_mismatch(self, provider):

wrong_client = OAuthClientInformationFull(
client_id="wrong-client-id",
redirect_uris=["https://example.com/callback"],
redirect_uris=[AnyUrl("https://example.com/callback")],
)
result = await provider.load_refresh_token(wrong_client, "rt-mismatch")
assert result is None
Expand Down Expand Up @@ -731,7 +734,7 @@ async def test_auth_code_expired_rejected(self, provider, test_client):
auth_code_obj = EveryRowAuthorizationCode(
code=auth_code_str,
client_id="test-client-id",
redirect_uri="https://example.com/callback",
redirect_uri=AnyUrl("https://example.com/callback"),
redirect_uri_provided_explicitly=True,
code_challenge="test-challenge",
scopes=["read"],
Expand Down
23 changes: 16 additions & 7 deletions everyrow-mcp/tests/test_http_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
import subprocess
import sys
import time
from collections.abc import Generator
from contextlib import asynccontextmanager
from typing import Any

import httpx
import pytest
import redis
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent

# Skip all tests unless opted in
pytestmark = pytest.mark.skipif(
Expand Down Expand Up @@ -77,7 +80,7 @@ def _flush_redis_db(db: int = REDIS_TEST_DB) -> None:


@pytest.fixture(scope="session")
def mcp_server() -> str:
def mcp_server() -> Generator[str, None, None]:
"""Start the MCP server subprocess on a random port with Redis DB 15.

Yields the base URL (e.g. http://127.0.0.1:PORT).
Expand Down Expand Up @@ -166,7 +169,7 @@ async def poll_via_mcp(
result = await session.call_tool(
"everyrow_progress", {"params": {"task_id": task_id}}
)
texts = [b.text for b in result.content if hasattr(b, "text")]
texts = [b.text for b in result.content if isinstance(b, TextContent)]
human_text = texts[-1] if texts else ""
print(f" Poll {i + 1}: {human_text.splitlines()[0]}")

Expand All @@ -180,10 +183,10 @@ async def poll_via_mcp(
raise TimeoutError(f"Task {task_id} did not complete within {max_polls} polls")


def parse_widget_json(content_blocks) -> dict:
def parse_widget_json(content_blocks) -> dict[str, Any]:
"""Parse the first TextContent block as JSON (the widget data)."""
for block in content_blocks:
if hasattr(block, "text"):
if isinstance(block, TextContent):
try:
return json.loads(block.text)
except json.JSONDecodeError:
Expand Down Expand Up @@ -256,7 +259,10 @@ async def test_agent_submit_poll_results(
)

# Fail fast on tool errors
first_text = submit_result.content[0].text if submit_result.content else ""
first_block = submit_result.content[0] if submit_result.content else None
first_text = (
first_block.text if isinstance(first_block, TextContent) else ""
)
assert not first_text.startswith("Error"), f"Tool error: {first_text}"

# Parse widget JSON from the first content block
Expand All @@ -276,7 +282,9 @@ async def test_agent_submit_poll_results(
"everyrow_results",
{"params": {"task_id": task_id, "page_size": 1}},
)
results_texts = [b.text for b in results_resp.content if hasattr(b, "text")]
results_texts = [
b.text for b in results_resp.content if isinstance(b, TextContent)
]
results_widget = parse_widget_json(results_resp.content)
print(f"\nResults page 1 widget: {json.dumps(results_widget, indent=2)}")

Expand Down Expand Up @@ -306,7 +314,7 @@ async def test_agent_submit_poll_results(
)
results_widget2 = parse_widget_json(results_resp2.content)
results_texts2 = [
b.text for b in results_resp2.content if hasattr(b, "text")
b.text for b in results_resp2.content if isinstance(b, TextContent)
]
print(f"\nResults page 2 widget: {json.dumps(results_widget2, indent=2)}")

Expand All @@ -328,6 +336,7 @@ async def test_agent_submit_poll_results(
reader = csv.DictReader(io.StringIO(csv_response.text))
rows = list(reader)
assert len(rows) == 2, f"Expected 2 CSV rows, got {len(rows)}"
assert reader.fieldnames is not None
assert "headquarters" in reader.fieldnames, (
f"Expected 'headquarters' column, got columns: {reader.fieldnames}"
)
Expand Down
Loading