Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion eval_protocol/cli_commands/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ def logs_command(args):
print("Press Ctrl+C to stop the server")
print("-" * 50)

# setup Elasticsearch
from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup

elasticsearch_config = ElasticsearchSetup().setup_elasticsearch()

try:
serve_logs(port=args.port)
serve_logs(port=args.port, elasticsearch_config=elasticsearch_config)
return 0
except KeyboardInterrupt:
print("\n🛑 Server stopped by user")
Expand Down
9 changes: 7 additions & 2 deletions eval_protocol/logging/elasticsearch_direct_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,16 @@ def emit(self, record: logging.LogRecord) -> None:
print(f"Error preparing log for Elasticsearch: {e}")

def _get_rollout_id(self, record: logging.LogRecord) -> str:
"""Get the rollout ID from environment variables."""
"""Get the rollout ID from record extra data or environment variables."""
# Check if rollout_id is provided in the extra data first
if hasattr(record, "rollout_id") and record.rollout_id is not None: # type: ignore
return str(record.rollout_id) # type: ignore

# Fall back to environment variable
rollout_id = os.getenv("EP_ROLLOUT_ID")
if rollout_id is None:
raise ValueError(
"EP_ROLLOUT_ID environment variable is not set but needed for ElasticsearchDirectHttpHandler"
"EP_ROLLOUT_ID environment variable is not set and no rollout_id provided in extra data for ElasticsearchDirectHttpHandler"
)
return rollout_id

Expand Down
13 changes: 11 additions & 2 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,18 @@ def _get_status() -> Dict[str, Any]:
terminated = bool(status.get("terminated", False))
if terminated:
break
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code == 404:
# 404 means server doesn't implement /status endpoint, stop polling
logger.info(
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
)
break
else:
raise
except Exception:
# transient errors; continue polling
pass
# For all other exceptions, raise them
raise

await asyncio.sleep(poll_interval)
else:
Expand Down
45 changes: 45 additions & 0 deletions eval_protocol/utils/logs_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
Pydantic models for the logs server API.

This module contains data models that match the TypeScript schemas in eval-protocol.ts
to ensure consistent data structure between Python backend and TypeScript frontend.
"""

from typing import Any, List, Optional
from pydantic import BaseModel, ConfigDict, Field


class LogEntry(BaseModel):
    """
    Represents a single log entry from Elasticsearch.

    This model matches the LogEntrySchema in eval-protocol.ts to ensure
    consistent data structure between Python backend and TypeScript frontend.
    Instances are typically built by unpacking an Elasticsearch hit's
    ``_source`` dict, so the field aliases mirror the stored document keys.
    """

    # Stored/serialized under the Elasticsearch-conventional "@timestamp" key;
    # populate_by_name (below) additionally allows constructing via "timestamp".
    timestamp: str = Field(..., alias="@timestamp", description="ISO 8601 timestamp of the log entry")
    level: str = Field(..., description="Log level (DEBUG, INFO, WARNING, ERROR)")
    message: str = Field(..., description="The log message")
    logger_name: str = Field(..., description="Name of the logger that created this entry")
    rollout_id: str = Field(..., description="ID of the rollout this log belongs to")
    # The three status_* fields are optional; responses exclude them when None
    # (the endpoint uses response_model_exclude_none).
    status_code: Optional[int] = Field(None, description="Optional status code")
    status_message: Optional[str] = Field(None, description="Optional status message")
    status_details: Optional[List[Any]] = Field(None, description="Optional status details")

    # Accept both the field name ("timestamp") and the alias ("@timestamp") as input.
    model_config = ConfigDict(populate_by_name=True)


class LogsResponse(BaseModel):
    """
    Response model for the get_logs endpoint.

    This model matches the LogsResponseSchema in eval-protocol.ts to ensure
    consistent data structure between Python backend and TypeScript frontend.
    """

    logs: List[LogEntry] = Field(..., description="Array of log entries")
    # NOTE(review): `total` is the Elasticsearch match count, which may exceed
    # len(logs) when a `limit` or level filter was applied by the endpoint.
    total: int = Field(..., description="Total number of logs available")
    rollout_id: str = Field(..., description="The rollout ID these logs belong to")
    filtered_by_level: Optional[str] = Field(None, description="Log level filter applied")

    # NOTE(review): an empty ConfigDict() is equivalent to pydantic's defaults;
    # kept explicitly, presumably for symmetry with LogEntry — confirm intent.
    model_config = ConfigDict()
143 changes: 129 additions & 14 deletions eval_protocol/utils/logs_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
import time
from contextlib import asynccontextmanager
from queue import Queue
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import psutil
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware

from eval_protocol.dataset_logger import default_logger
from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE
from eval_protocol.event_bus import event_bus
from eval_protocol.models import Status
from eval_protocol.utils.vite_server import ViteServer
from eval_protocol.logging.elasticsearch_client import ElasticsearchClient
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
from eval_protocol.utils.logs_models import LogEntry, LogsResponse

if TYPE_CHECKING:
from eval_protocol.models import EvaluationRow
Expand Down Expand Up @@ -71,8 +75,11 @@ async def _start_broadcast_loop(self):
while True:
try:
# Wait for a message to be queued
message = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)
await self._send_text_to_all_connections(message)
message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)

# Regular string message for all connections
await self._send_text_to_all_connections(str(message_data))

except Exception as e:
logger.error(f"Error in broadcast loop: {e}")
await asyncio.sleep(0.1)
Expand Down Expand Up @@ -238,8 +245,8 @@ class LogsServer(ViteServer):
Enhanced server for serving Vite-built SPA with file watching and WebSocket support.

This server extends ViteServer to add:
- WebSocket connections for real-time updates
- Live log streaming
- WebSocket connections for real-time evaluation row updates
- REST API for log querying
"""

def __init__(
Expand All @@ -250,17 +257,49 @@ def __init__(
host: str = "localhost",
port: Optional[int] = 8000,
index_file: str = "index.html",
elasticsearch_config: Optional[ElasticsearchConfig] = None,
):
# Initialize WebSocket manager
self.websocket_manager = WebSocketManager()

super().__init__(build_dir, host, port if port is not None else 8000, index_file)
# Initialize Elasticsearch client if config is provided
self.elasticsearch_client: Optional[ElasticsearchClient] = None
if elasticsearch_config:
self.elasticsearch_client = ElasticsearchClient(elasticsearch_config)

self.app = FastAPI(title="Logs Server")

# Add WebSocket endpoint and API routes
self._setup_websocket_routes()
self._setup_api_routes()

super().__init__(build_dir, host, port if port is not None else 8000, index_file, self.app)

# Add CORS middleware to allow frontend access
allowed_origins = [
"http://localhost:5173", # Vite dev server
"http://127.0.0.1:5173", # Vite dev server (alternative)
f"http://{host}:{port}", # Server's own origin
f"http://localhost:{port}", # Server on localhost
]

self.app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Initialize evaluation watcher
self.evaluation_watcher = EvaluationWatcher(self.websocket_manager)

# Add WebSocket endpoint
self._setup_websocket_routes()
# Log all registered routes for debugging
logger.info("Registered routes:")
for route in self.app.routes:
path = getattr(route, "path", "UNKNOWN")
methods = getattr(route, "methods", {"UNKNOWN"})
logger.info(f" {methods} {path}")

# Subscribe to events and start listening for cross-process events
event_bus.subscribe(self._handle_event)
Expand All @@ -275,14 +314,17 @@ async def websocket_endpoint(websocket: WebSocket):
await self.websocket_manager.connect(websocket)
try:
while True:
# Keep connection alive
# Keep connection alive (for evaluation row updates)
await websocket.receive_text()
except WebSocketDisconnect:
self.websocket_manager.disconnect(websocket)
except Exception as e:
logger.error(f"WebSocket error: {e}")
self.websocket_manager.disconnect(websocket)

def _setup_api_routes(self):
"""Set up API routes."""

@self.app.get("/api/status")
async def status():
"""Get server status including active connections."""
Expand All @@ -295,8 +337,75 @@ async def status():
# LogsServer inherits from ViteServer which doesn't expose watch_paths
# Expose an empty list to satisfy consumers and type checker
"watch_paths": [],
"elasticsearch_enabled": self.elasticsearch_client is not None,
}

@self.app.get("/api/logs/{rollout_id}", response_model=LogsResponse, response_model_exclude_none=True)
async def get_logs(
    rollout_id: str,
    level: Optional[str] = Query(None, description="Filter by log level (DEBUG, INFO, WARNING, ERROR)"),
    limit: int = Query(100, description="Maximum number of log entries to return"),
) -> LogsResponse:
    """Get logs for a specific rollout ID from Elasticsearch.

    Returns a LogsResponse whose entries are sorted most-recent-first.
    Raises HTTP 503 when the server was started without an Elasticsearch
    config, and HTTP 500 on any failure while querying or parsing results.
    """
    # The ES client is only created when an elasticsearch_config was passed
    # to the server; without it this endpoint cannot serve anything.
    if not self.elasticsearch_client:
        raise HTTPException(status_code=503, detail="Elasticsearch is not configured for this logs server")

    try:
        # Search for logs by rollout_id.
        # NOTE(review): `limit` caps the hits fetched from ES *before* the
        # level filter below, so a level-filtered response may contain fewer
        # than `limit` entries even when more matching ones exist — confirm
        # this is the intended semantics.
        search_results = self.elasticsearch_client.search_by_match("rollout_id", rollout_id, size=limit)

        if not search_results or "hits" not in search_results:
            # Return empty response using Pydantic model
            return LogsResponse(
                logs=[],
                total=0,
                rollout_id=rollout_id,
                filtered_by_level=level,
            )

        log_entries = []
        for hit in search_results["hits"]["hits"]:
            log_data = hit["_source"]

            # Filter by level if specified (exact, case-sensitive match).
            if level and log_data.get("level") != level:
                continue

            # Create LogEntry using Pydantic model for validation
            try:
                log_entry = LogEntry(
                    **log_data  # Use ** to unpack the dict, Pydantic will handle field mapping
                )
                log_entries.append(log_entry)
            except Exception as e:
                # Log the error but continue processing other entries —
                # one malformed document should not fail the whole request.
                logger.warning(f"Failed to parse log entry: {e}, data: {log_data}")
                continue

        # Sort by timestamp (most recent first). This is a lexicographic
        # string sort; it yields chronological order assuming all entries use
        # a uniform ISO 8601 format and timezone — TODO confirm the ES handler
        # always writes "@timestamp" that way.
        log_entries.sort(key=lambda x: x.timestamp, reverse=True)

        # Get total count. Note this is the total number of rollout_id
        # matches in ES, not the post-level-filter count.
        total_hits = search_results["hits"]["total"]
        if isinstance(total_hits, dict):
            # Elasticsearch 7+ format: {"value": <int>, "relation": ...}
            total_count = total_hits["value"]
        else:
            # Elasticsearch 6 format: plain integer
            total_count = total_hits

        # Return response using Pydantic model
        return LogsResponse(
            logs=log_entries,
            total=total_count,
            rollout_id=rollout_id,
            filtered_by_level=level,
        )

    except Exception as e:
        # Blanket catch so the client gets a structured 500 rather than a
        # raw stack trace; the original error is preserved in the server log.
        logger.error(f"Error retrieving logs for rollout {rollout_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve logs: {str(e)}")

def _handle_event(self, event_type: str, data: Any) -> None:
"""Handle events from the event bus."""
if event_type in [LOG_EVENT_TYPE]:
Expand Down Expand Up @@ -353,7 +462,12 @@ def run(self):
asyncio.run(self.run_async())


def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[str] = None) -> FastAPI:
def create_app(
host: str = "localhost",
port: int = 8000,
build_dir: Optional[str] = None,
elasticsearch_config: Optional[ElasticsearchConfig] = None,
) -> FastAPI:
"""
Factory function to create a FastAPI app instance and start the server with async loops.

Expand All @@ -364,6 +478,7 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st
host: Host to bind to
port: Port to bind to
build_dir: Optional custom build directory path
elasticsearch_config: Optional Elasticsearch configuration for log querying

Returns:
FastAPI app instance with server running in background
Expand All @@ -373,17 +488,17 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st
os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "vite-app", "dist")
)

server = LogsServer(host=host, port=port, build_dir=build_dir)
server = LogsServer(host=host, port=port, build_dir=build_dir, elasticsearch_config=elasticsearch_config)
server.start_loops()
return server.app


# For backward compatibility and direct usage
def serve_logs(port: Optional[int] = None):
def serve_logs(port: Optional[int] = None, elasticsearch_config: Optional[ElasticsearchConfig] = None):
"""
Convenience function to create and run a LogsServer.
"""
server = LogsServer(port=port)
server = LogsServer(port=port, elasticsearch_config=elasticsearch_config)
server.run()


Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/utils/vite_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def __init__(
host: str = "localhost",
port: int = 8000,
index_file: str = "index.html",
lifespan: Optional[Callable[[FastAPI], Any]] = None,
app: Optional[FastAPI] = None,
):
self.build_dir = Path(build_dir)
self.host = host
self.port = port
self.index_file = index_file
self.app = FastAPI(title="Vite SPA Server", lifespan=lifespan)
self.app = app if app is not None else FastAPI(title="Vite SPA Server")

# Validate build directory exists
if not self.build_dir.exists():
Expand Down
Loading
Loading