Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions src/agent_session_analytics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,7 @@ def cmd_benchmark(args):
# - ingest_logs, ingest_git_history, ingest_git_history_all_projects
# - correlate_git_with_sessions, ingest_bus_events
# - find_related_sessions (requires valid session_id)
# - upload_entries, get_sync_status (remote sync tools - modify DB or require client context)
# - upload_entries, get_sync_status, finalize_sync (remote sync tools - modify DB or require client context)

benchmarks = []
for tool_name, tool_func in tool_functions.items():
Expand Down Expand Up @@ -1530,11 +1530,23 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None:
req = urllib.request.Request(
remote_url,
data=json.dumps(mcp_request).encode("utf-8"),
headers={"Content-Type": "application/json"},
headers={
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
method="POST",
)
with urllib.request.urlopen(req, timeout=60) as resp:
result = json.loads(resp.read().decode("utf-8"))
with urllib.request.urlopen(req, timeout=120) as resp:
raw_response = resp.read().decode("utf-8")
# Parse SSE format: "event: message\ndata: {...}"
result = None
for line in raw_response.split("\n"):
if line.startswith("data: "):
result = json.loads(line[6:])
break
if result is None:
# Try parsing as plain JSON (fallback)
result = json.loads(raw_response)
if "result" in result:
content = result["result"].get("content", [])
if content and content[0].get("type") == "text":
Expand Down Expand Up @@ -1606,12 +1618,19 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None:
server_latest = server_sessions.get(session_id)
if server_latest:
# Parse server timestamp and filter
from datetime import timezone

server_ts = datetime.fromisoformat(server_latest.replace("Z", "+00:00"))
if server_ts.tzinfo is None:
server_ts = server_ts.replace(tzinfo=timezone.utc)
for project_path, entry in entries:
entry_ts_str = entry.get("timestamp")
if entry_ts_str:
try:
entry_ts = datetime.fromisoformat(entry_ts_str.replace("Z", "+00:00"))
# Ensure both are timezone-aware for comparison
if entry_ts.tzinfo is None:
entry_ts = entry_ts.replace(tzinfo=timezone.utc)
if entry_ts > server_ts:
entries_to_send.append((project_path, entry))
except ValueError:
Expand Down Expand Up @@ -1648,10 +1667,15 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None:
total_added = 0
total_skipped = 0
total_errors = 0
entries_uploaded = 0
total_entries = len(entries_to_send)
start_time = datetime.now()

# Upload in batches per project
batch_size = args.batch_size
for project_path, entries in by_project.items():
if not args.json:
print(f" {project_path}: {len(entries)} entries", flush=True)
for i in range(0, len(entries), batch_size):
batch = entries[i : i + batch_size]
result = mcp_call("upload_entries", {"entries": batch, "project_path": project_path})
Expand All @@ -1660,9 +1684,26 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None:
total_added += result.get("events_added", 0)
total_skipped += result.get("events_skipped", 0)
total_errors += result.get("parse_errors", 0)

if not args.json:
print(f" {project_path}: {len(entries)} entries")
entries_uploaded += len(batch)

# Progress update with time estimate
if not args.json and total_entries > 0:
elapsed = (datetime.now() - start_time).total_seconds()
pct = entries_uploaded / total_entries * 100
if entries_uploaded > 0 and elapsed > 0:
rate = entries_uploaded / elapsed
remaining = (total_entries - entries_uploaded) / rate if rate > 0 else 0
print(
f" [{pct:5.1f}%] {entries_uploaded}/{total_entries} "
f"({rate:.0f}/s, ~{remaining:.0f}s remaining)",
flush=True,
)

# Finalize sync by updating session statistics once
if entries_to_send and not args.json:
print("Finalizing sync (updating session stats)...", flush=True)
finalize_result = mcp_call("finalize_sync", {})
sessions_updated = finalize_result.get("sessions_updated", 0) if finalize_result else 0

output = {
"status": "ok",
Expand All @@ -1671,6 +1712,7 @@ def mcp_call(method_name: str, arguments: dict) -> dict | None:
"entries_sent": len(entries_to_send),
"events_added": total_added,
"events_skipped": total_skipped,
"sessions_updated": sessions_updated,
"local_parse_errors": local_parse_errors,
"remote_parse_errors": total_errors,
"remote_url": remote_url,
Expand Down
1 change: 1 addition & 0 deletions src/agent_session_analytics/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ For setups where the database lives on a central server (e.g., via Tailscale):
|------|---------|
| `get_sync_status(session_ids?)` | Get latest timestamp per session for incremental sync |
| `upload_entries(entries, project_path)` | Upload raw JSONL entries from remote clients |
| `finalize_sync()` | Update session statistics after batch uploads complete |

**CLI usage:**
```bash
Expand Down
33 changes: 28 additions & 5 deletions src/agent_session_analytics/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_sync_status(session_ids: list[str] | None = None) -> dict:


@mcp.tool()
def upload_entries(entries: list[dict], project_path: str) -> dict:
def upload_entries(entries: list[dict], project_path: str, update_stats: bool = False) -> dict:
"""Upload raw JSONL entries from a remote client.

For multi-machine setups where session JSONL files live on client machines.
Expand All @@ -115,6 +115,7 @@ def upload_entries(entries: list[dict], project_path: str) -> dict:
Args:
entries: List of raw JSONL entry dicts (as read from session files)
project_path: Project path identifier (typically the directory name)
update_stats: Update session stats after insert (default: False, call finalize_sync at end)
"""
# Parse entries server-side using the same logic as local ingestion
all_events = []
Expand All @@ -131,8 +132,10 @@ def upload_entries(entries: list[dict], project_path: str) -> dict:
# Insert with deduplication (INSERT OR IGNORE on uuid)
events_added = storage.add_events_batch(all_events) if all_events else 0

# Update session statistics
sessions_updated = ingest.update_session_stats(storage)
# Update session statistics only if requested (expensive operation)
sessions_updated = 0
if update_stats:
sessions_updated = ingest.update_session_stats(storage)

return {
"status": "ok",
Expand All @@ -145,6 +148,19 @@ def upload_entries(entries: list[dict], project_path: str) -> dict:
}


@mcp.tool()
def finalize_sync() -> dict:
"""Finalize a sync operation by updating session statistics.

Call this once after all upload_entries batches are complete.
"""
sessions_updated = ingest.update_session_stats(storage)
return {
"status": "ok",
"sessions_updated": sessions_updated,
}


@mcp.tool()
def get_tool_frequency(days: int = 7, project: str | None = None, expand: bool = True) -> dict:
"""Get tool usage frequency counts.
Expand Down Expand Up @@ -721,10 +737,12 @@ class TailscaleAuthMiddleware:
(Tailscale-User-Login) into requests. This middleware rejects requests
that don't have these headers.

Localhost connections are trusted and bypass auth.
Set AGENT_SESSION_ANALYTICS_AUTH_DISABLED=1 to disable (for testing/local dev).
"""

TAILSCALE_USER_HEADER = b"tailscale-user-login"
TRUSTED_IPS = ("127.0.0.1", "::1")

def __init__(self, app):
self.app = app
Expand All @@ -734,13 +752,18 @@ async def __call__(self, scope, receive, send):
await self.app(scope, receive, send)
return

# Trust localhost connections
client_ip = scope.get("client", ("", 0))[0]
if client_ip in self.TRUSTED_IPS:
await self.app(scope, receive, send)
return

headers = dict(scope.get("headers", []))
tailscale_user = headers.get(self.TAILSCALE_USER_HEADER)

if not tailscale_user:
logger.warning(
f"Rejected unauthenticated request to {scope.get('path', '/')} "
f"from {scope.get('client', ('unknown',))[0]}"
f"Rejected unauthenticated request to {scope.get('path', '/')} from {client_ip}"
)
await self._send_unauthorized(send)
return
Expand Down
51 changes: 49 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
analyze_trends,
classify_sessions,
detect_parallel_sessions,
finalize_sync,
find_related_sessions,
get_compaction_events,
get_error_details,
Expand Down Expand Up @@ -452,6 +453,14 @@ def test_upload_entries_empty():
assert result["events_parsed"] == 0


def test_finalize_sync():
"""Test that finalize_sync updates session statistics."""
result = finalize_sync.fn()
assert result["status"] == "ok"
assert "sessions_updated" in result
assert isinstance(result["sessions_updated"], int)


# --- Tailscale Auth Middleware Tests ---


Expand Down Expand Up @@ -524,13 +533,13 @@ async def receive():

@pytest.mark.asyncio
async def test_rejects_request_without_tailscale_header(self, mock_app, capture_response):
"""Requests without Tailscale-User-Login header get 401."""
"""Requests without Tailscale-User-Login header get 401 (non-localhost)."""
middleware = TailscaleAuthMiddleware(mock_app)
scope = {
"type": "http",
"path": "/mcp",
"headers": [],
"client": ("127.0.0.1", 12345),
"client": ("192.168.1.100", 12345), # Non-localhost to test auth rejection
}

async def receive():
Expand All @@ -556,3 +565,41 @@ async def receive():
await middleware(scope, receive, capture_response)

assert mock_app.called is True

@pytest.mark.asyncio
async def test_allows_localhost_without_tailscale_header(self, mock_app, capture_response):
"""Localhost requests are trusted and bypass auth."""
middleware = TailscaleAuthMiddleware(mock_app)
scope = {
"type": "http",
"path": "/mcp",
"headers": [], # No Tailscale header
"client": ("127.0.0.1", 12345),
}

async def receive():
return {"type": "http.request", "body": b""}

await middleware(scope, receive, capture_response)

assert mock_app.called is True # Localhost bypasses auth
assert capture_response.status == 200

@pytest.mark.asyncio
async def test_allows_ipv6_localhost_without_tailscale_header(self, mock_app, capture_response):
"""IPv6 localhost (::1) requests are trusted and bypass auth."""
middleware = TailscaleAuthMiddleware(mock_app)
scope = {
"type": "http",
"path": "/mcp",
"headers": [], # No Tailscale header
"client": ("::1", 12345),
}

async def receive():
return {"type": "http.request", "body": b""}

await middleware(scope, receive, capture_response)

assert mock_app.called is True # IPv6 localhost bypasses auth
assert capture_response.status == 200