Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 74 additions & 3 deletions src/session_analytics/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,10 +646,81 @@ def get_session_efficiency(
return {"status": "ok", **result}


class TailscaleAuthMiddleware:
"""ASGI middleware that requires Tailscale identity headers.

When running behind `tailscale serve`, Tailscale injects identity headers
(Tailscale-User-Login) into requests. This middleware rejects requests
that don't have these headers.

Set SESSION_ANALYTICS_AUTH_DISABLED=1 to disable (for testing/local dev).
"""

TAILSCALE_USER_HEADER = b"tailscale-user-login"

def __init__(self, app):
self.app = app

async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return

headers = dict(scope.get("headers", []))
tailscale_user = headers.get(self.TAILSCALE_USER_HEADER)

if not tailscale_user:
logger.warning(
f"Rejected unauthenticated request to {scope.get('path', '/')} "
f"from {scope.get('client', ('unknown',))[0]}"
)
await self._send_unauthorized(send)
return

user = tailscale_user.decode("utf-8", errors="replace")
logger.debug(f"Authenticated request from {user}")
await self.app(scope, receive, send)

async def _send_unauthorized(self, send):
"""Send a 401 Unauthorized response."""
body = b'{"error": "Unauthorized", "message": "Tailscale identity required"}'
await send(
{
"type": "http.response.start",
"status": 401,
"headers": [
(b"content-type", b"application/json"),
(b"content-length", str(len(body)).encode()),
],
}
)
await send(
{
"type": "http.response.body",
"body": body,
"more_body": False,
}
)


def create_app():
"""Create the ASGI app for uvicorn."""
# stateless_http=True allows resilience to server restarts
return mcp.http_app(stateless_http=True)
"""Create the ASGI app for uvicorn.

Set SESSION_ANALYTICS_AUTH_DISABLED=1 to disable auth (for testing/local dev).
"""
app = mcp.http_app(stateless_http=True)

auth_disabled = os.environ.get("SESSION_ANALYTICS_AUTH_DISABLED", "").lower() in (
"1",
"true",
)
if not auth_disabled:
app = TailscaleAuthMiddleware(app)
logger.info("Tailscale auth enabled - requests require identity headers")
else:
logger.warning("Tailscale auth DISABLED - all requests allowed")

return app


def main():
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Pytest configuration and shared fixtures."""

import os
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
Expand All @@ -9,6 +10,12 @@
from session_analytics.storage import Event, Session, SQLiteStorage


def pytest_configure(config):
"""Set up test environment before any imports happen."""
# Disable Tailscale auth for tests
os.environ["SESSION_ANALYTICS_AUTH_DISABLED"] = "1"


@pytest.fixture
def storage():
"""Create a temporary storage instance for testing.
Expand Down
109 changes: 109 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Tests for the MCP server."""

import pytest

from session_analytics.server import (
TailscaleAuthMiddleware,
analyze_failures,
analyze_trends,
classify_sessions,
Expand Down Expand Up @@ -396,3 +399,109 @@ def test_get_large_tool_results():
assert "min_size_kb" in result
assert "large_results" in result
assert isinstance(result["large_results"], list)


# --- Tailscale Auth Middleware Tests ---


class TestTailscaleAuthMiddleware:
"""Tests for TailscaleAuthMiddleware."""

@pytest.fixture
def mock_app(self):
"""Mock ASGI app that tracks calls."""

async def app(scope, receive, send):
app.called = True
app.scope = scope
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"application/json")],
}
)
await send(
{
"type": "http.response.body",
"body": b'{"status": "ok"}',
"more_body": False,
}
)

app.called = False
app.scope = None
return app

@pytest.fixture
def capture_response(self):
"""Capture ASGI response for assertions."""

class ResponseCapture:
def __init__(self):
self.status = None
self.headers = []
self.body = b""

async def __call__(self, message):
if message["type"] == "http.response.start":
self.status = message["status"]
self.headers = message.get("headers", [])
elif message["type"] == "http.response.body":
self.body += message.get("body", b"")

return ResponseCapture()

@pytest.mark.asyncio
async def test_allows_request_with_tailscale_header(self, mock_app, capture_response):
"""Requests with Tailscale-User-Login header are allowed."""
middleware = TailscaleAuthMiddleware(mock_app)
scope = {
"type": "http",
"path": "/mcp",
"headers": [(b"tailscale-user-login", b"user@example.com")],
"client": ("127.0.0.1", 12345),
}

async def receive():
return {"type": "http.request", "body": b""}

await middleware(scope, receive, capture_response)

assert mock_app.called is True
assert capture_response.status == 200

@pytest.mark.asyncio
async def test_rejects_request_without_tailscale_header(self, mock_app, capture_response):
"""Requests without Tailscale-User-Login header get 401."""
middleware = TailscaleAuthMiddleware(mock_app)
scope = {
"type": "http",
"path": "/mcp",
"headers": [],
"client": ("127.0.0.1", 12345),
}

async def receive():
return {"type": "http.request", "body": b""}

await middleware(scope, receive, capture_response)

assert mock_app.called is False
assert capture_response.status == 401
assert b"Unauthorized" in capture_response.body

@pytest.mark.asyncio
async def test_passes_through_non_http_requests(self, mock_app, capture_response):
"""Non-HTTP requests (websocket, lifespan) pass through without auth."""
middleware = TailscaleAuthMiddleware(mock_app)
scope = {
"type": "lifespan",
}

async def receive():
return {"type": "lifespan.startup"}

await middleware(scope, receive, capture_response)

assert mock_app.called is True