diff --git a/src/agent_session_analytics/cli.py b/src/agent_session_analytics/cli.py index 2eb762d..baa6282 100644 --- a/src/agent_session_analytics/cli.py +++ b/src/agent_session_analytics/cli.py @@ -1473,7 +1473,7 @@ def cmd_benchmark(args): # - ingest_logs, ingest_git_history, ingest_git_history_all_projects # - correlate_git_with_sessions, ingest_bus_events # - find_related_sessions (requires valid session_id) - # - upload_entries, get_sync_status (remote sync tools - modify DB or require client context) + # - upload_entries, get_sync_status, finalize_sync (remote sync tools - modify DB or require client context) benchmarks = [] for tool_name, tool_func in tool_functions.items(): @@ -1530,11 +1530,23 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None: req = urllib.request.Request( remote_url, data=json.dumps(mcp_request).encode("utf-8"), - headers={"Content-Type": "application/json"}, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, method="POST", ) - with urllib.request.urlopen(req, timeout=60) as resp: - result = json.loads(resp.read().decode("utf-8")) + with urllib.request.urlopen(req, timeout=120) as resp: + raw_response = resp.read().decode("utf-8") + # Parse SSE format: "event: message\ndata: {...}" + result = None + for line in raw_response.split("\n"): + if line.startswith("data: "): + result = json.loads(line[6:]) + break + if result is None: + # Try parsing as plain JSON (fallback) + result = json.loads(raw_response) if "result" in result: content = result["result"].get("content", []) if content and content[0].get("type") == "text": @@ -1606,12 +1618,19 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None: server_latest = server_sessions.get(session_id) if server_latest: # Parse server timestamp and filter + from datetime import timezone + server_ts = datetime.fromisoformat(server_latest.replace("Z", "+00:00")) + if server_ts.tzinfo is None: + server_ts = server_ts.replace(tzinfo=timezone.utc) for project_path, entry in entries: entry_ts_str = entry.get("timestamp") if entry_ts_str: try: entry_ts = datetime.fromisoformat(entry_ts_str.replace("Z", "+00:00")) + # Ensure both are timezone-aware for comparison + if entry_ts.tzinfo is None: + entry_ts = entry_ts.replace(tzinfo=timezone.utc) if entry_ts > server_ts: entries_to_send.append((project_path, entry)) except ValueError: @@ -1648,10 +1667,15 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None: total_added = 0 total_skipped = 0 total_errors = 0 + entries_uploaded = 0 + total_entries = len(entries_to_send) + start_time = datetime.now() # Upload in batches per project batch_size = args.batch_size for project_path, entries in by_project.items(): + if not args.json: + print(f" {project_path}: {len(entries)} entries", flush=True) for i in range(0, len(entries), batch_size): batch = entries[i : i + batch_size] result = mcp_call("upload_entries", {"entries": batch, "project_path": project_path}) @@ -1660,9 +1684,26 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None: total_added += result.get("events_added", 0) total_skipped += result.get("events_skipped", 0) total_errors += result.get("parse_errors", 0) - - if not args.json: - print(f" {project_path}: {len(entries)} entries") + entries_uploaded += len(batch) + + # Progress update with time estimate + if not args.json and total_entries > 0: + elapsed = (datetime.now() - start_time).total_seconds() + pct = entries_uploaded / total_entries * 100 + if entries_uploaded > 0 and elapsed > 0: + rate = entries_uploaded / elapsed + remaining = (total_entries - entries_uploaded) / rate if rate > 0 else 0 + print( + f" [{pct:5.1f}%] {entries_uploaded}/{total_entries} " + f"({rate:.0f}/s, ~{remaining:.0f}s remaining)", + flush=True, + ) + + # Finalize sync by updating session statistics once + if entries_to_send and not args.json: + print("Finalizing sync (updating session stats)...", flush=True) + finalize_result = mcp_call("finalize_sync", {}) + sessions_updated = finalize_result.get("sessions_updated", 0) if finalize_result else 0 output = { "status": "ok", @@ -1671,6 +1712,7 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None: "entries_sent": len(entries_to_send), "events_added": total_added, "events_skipped": total_skipped, + "sessions_updated": sessions_updated, "local_parse_errors": local_parse_errors, "remote_parse_errors": total_errors, "remote_url": remote_url, diff --git a/src/agent_session_analytics/guide.md b/src/agent_session_analytics/guide.md index abe6bde..6a2a831 100644 --- a/src/agent_session_analytics/guide.md +++ b/src/agent_session_analytics/guide.md @@ -26,6 +26,7 @@ For setups where the database lives on a central server (e.g., via Tailscale): |------|---------| | `get_sync_status(session_ids?)` | Get latest timestamp per session for incremental sync | | `upload_entries(entries, project_path)` | Upload raw JSONL entries from remote clients | +| `finalize_sync()` | Update session statistics after batch uploads complete | **CLI usage:** ```bash diff --git a/src/agent_session_analytics/server.py b/src/agent_session_analytics/server.py index 004c5bf..5896d51 100644 --- a/src/agent_session_analytics/server.py +++ b/src/agent_session_analytics/server.py @@ -106,7 +106,7 @@ def get_sync_status(session_ids: list[str] | None = None) -> dict: @mcp.tool() -def upload_entries(entries: list[dict], project_path: str) -> dict: +def upload_entries(entries: list[dict], project_path: str, update_stats: bool = False) -> dict: """Upload raw JSONL entries from a remote client. For multi-machine setups where session JSONL files live on client machines. @@ -115,6 +115,7 @@ def upload_entries(entries: list[dict], project_path: str) -> dict: Args: entries: List of raw JSONL entry dicts (as read from session files) project_path: Project path identifier (typically the directory name) + update_stats: Update session stats after insert (default: False, call finalize_sync at end) """ # Parse entries server-side using the same logic as local ingestion all_events = [] @@ -131,8 +132,10 @@ def upload_entries(entries: list[dict], project_path: str) -> dict: # Insert with deduplication (INSERT OR IGNORE on uuid) events_added = storage.add_events_batch(all_events) if all_events else 0 - # Update session statistics - sessions_updated = ingest.update_session_stats(storage) + # Update session statistics only if requested (expensive operation) + sessions_updated = 0 + if update_stats: + sessions_updated = ingest.update_session_stats(storage) return { "status": "ok", @@ -145,6 +148,19 @@ def upload_entries(entries: list[dict], project_path: str) -> dict: } +@mcp.tool() +def finalize_sync() -> dict: + """Finalize a sync operation by updating session statistics. + + Call this once after all upload_entries batches are complete. + """ + sessions_updated = ingest.update_session_stats(storage) + return { + "status": "ok", + "sessions_updated": sessions_updated, + } + + @mcp.tool() def get_tool_frequency(days: int = 7, project: str | None = None, expand: bool = True) -> dict: """Get tool usage frequency counts. @@ -721,10 +737,12 @@ class TailscaleAuthMiddleware: (Tailscale-User-Login) into requests. This middleware rejects requests that don't have these headers. + Localhost connections are trusted and bypass auth. Set AGENT_SESSION_ANALYTICS_AUTH_DISABLED=1 to disable (for testing/local dev). """ TAILSCALE_USER_HEADER = b"tailscale-user-login" + TRUSTED_IPS = ("127.0.0.1", "::1") def __init__(self, app): self.app = app @@ -734,13 +752,18 @@ async def __call__(self, scope, receive, send): await self.app(scope, receive, send) return + # Trust localhost connections + client_ip = scope.get("client", ("", 0))[0] + if client_ip in self.TRUSTED_IPS: + await self.app(scope, receive, send) + return + headers = dict(scope.get("headers", [])) tailscale_user = headers.get(self.TAILSCALE_USER_HEADER) if not tailscale_user: logger.warning( - f"Rejected unauthenticated request to {scope.get('path', '/')} " - f"from {scope.get('client', ('unknown',))[0]}" + f"Rejected unauthenticated request to {scope.get('path', '/')} from {client_ip}" ) await self._send_unauthorized(send) return diff --git a/tests/test_server.py b/tests/test_server.py index 04b60de..cbd11dd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,6 +8,7 @@ analyze_trends, classify_sessions, detect_parallel_sessions, + finalize_sync, find_related_sessions, get_compaction_events, get_error_details, @@ -452,6 +453,14 @@ def test_upload_entries_empty(): assert result["events_parsed"] == 0 +def test_finalize_sync(): + """Test that finalize_sync updates session statistics.""" + result = finalize_sync.fn() + assert result["status"] == "ok" + assert "sessions_updated" in result + assert isinstance(result["sessions_updated"], int) + + # --- Tailscale Auth Middleware Tests --- @@ -524,13 +533,13 @@ async def receive(): @pytest.mark.asyncio async def test_rejects_request_without_tailscale_header(self, mock_app, capture_response): - """Requests without Tailscale-User-Login header get 401.""" + """Requests without Tailscale-User-Login header get 401 (non-localhost).""" middleware = TailscaleAuthMiddleware(mock_app) scope = { "type": "http", "path": "/mcp", "headers": [], - "client": ("127.0.0.1", 12345), + "client": ("192.168.1.100", 12345), # Non-localhost to test auth rejection } async def receive(): @@ -556,3 +565,41 @@ async def receive(): await middleware(scope, receive, capture_response) assert mock_app.called is True + + @pytest.mark.asyncio + async def test_allows_localhost_without_tailscale_header(self, mock_app, capture_response): + """Localhost requests are trusted and bypass auth.""" + middleware = TailscaleAuthMiddleware(mock_app) + scope = { + "type": "http", + "path": "/mcp", + "headers": [], # No Tailscale header + "client": ("127.0.0.1", 12345), + } + + async def receive(): + return {"type": "http.request", "body": b""} + + await middleware(scope, receive, capture_response) + + assert mock_app.called is True # Localhost bypasses auth + assert capture_response.status == 200 + + @pytest.mark.asyncio + async def test_allows_ipv6_localhost_without_tailscale_header(self, mock_app, capture_response): + """IPv6 localhost (::1) requests are trusted and bypass auth.""" + middleware = TailscaleAuthMiddleware(mock_app) + scope = { + "type": "http", + "path": "/mcp", + "headers": [], # No Tailscale header + "client": ("::1", 12345), + } + + async def receive(): + return {"type": "http.request", "body": b""} + + await middleware(scope, receive, capture_response) + + assert mock_app.called is True # IPv6 localhost bypasses auth + assert capture_response.status == 200