diff --git a/eval_protocol/cli_commands/logs.py b/eval_protocol/cli_commands/logs.py index 92b1be58..414826bc 100644 --- a/eval_protocol/cli_commands/logs.py +++ b/eval_protocol/cli_commands/logs.py @@ -19,8 +19,13 @@ def logs_command(args): print("Press Ctrl+C to stop the server") print("-" * 50) + # setup Elasticsearch + from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup + + elasticsearch_config = ElasticsearchSetup().setup_elasticsearch() + try: - serve_logs(port=args.port) + serve_logs(port=args.port, elasticsearch_config=elasticsearch_config) return 0 except KeyboardInterrupt: print("\nš Server stopped by user") diff --git a/eval_protocol/logging/elasticsearch_direct_http_handler.py b/eval_protocol/logging/elasticsearch_direct_http_handler.py index 015445ce..8b729dfa 100644 --- a/eval_protocol/logging/elasticsearch_direct_http_handler.py +++ b/eval_protocol/logging/elasticsearch_direct_http_handler.py @@ -51,11 +51,16 @@ def emit(self, record: logging.LogRecord) -> None: print(f"Error preparing log for Elasticsearch: {e}") def _get_rollout_id(self, record: logging.LogRecord) -> str: - """Get the rollout ID from environment variables.""" + """Get the rollout ID from record extra data or environment variables.""" + # Check if rollout_id is provided in the extra data first + if hasattr(record, "rollout_id") and record.rollout_id is not None: # type: ignore + return str(record.rollout_id) # type: ignore + + # Fall back to environment variable rollout_id = os.getenv("EP_ROLLOUT_ID") if rollout_id is None: raise ValueError( - "EP_ROLLOUT_ID environment variable is not set but needed for ElasticsearchDirectHttpHandler" + "EP_ROLLOUT_ID environment variable is not set and no rollout_id provided in extra data for ElasticsearchDirectHttpHandler" ) return rollout_id diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index f073597a..09110a90 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -194,9 +194,18 @@ def _get_status() -> Dict[str, Any]: terminated = bool(status.get("terminated", False)) if terminated: break + except requests.exceptions.HTTPError as e: + if e.response is not None and e.response.status_code == 404: + # 404 means server doesn't implement /status endpoint, stop polling + logger.info( + f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}" + ) + break + else: + raise except Exception: - # transient errors; continue polling - pass + # For all other exceptions, raise them + raise await asyncio.sleep(poll_interval) else: diff --git a/eval_protocol/utils/logs_models.py b/eval_protocol/utils/logs_models.py new file mode 100644 index 00000000..c5e91b12 --- /dev/null +++ b/eval_protocol/utils/logs_models.py @@ -0,0 +1,45 @@ +""" +Pydantic models for the logs server API. + +This module contains data models that match the TypeScript schemas in eval-protocol.ts +to ensure consistent data structure between Python backend and TypeScript frontend. +""" + +from typing import Any, List, Optional +from pydantic import BaseModel, ConfigDict, Field + + +class LogEntry(BaseModel): + """ + Represents a single log entry from Elasticsearch. + + This model matches the LogEntrySchema in eval-protocol.ts to ensure + consistent data structure between Python backend and TypeScript frontend. + """ + + timestamp: str = Field(..., alias="@timestamp", description="ISO 8601 timestamp of the log entry") + level: str = Field(..., description="Log level (DEBUG, INFO, WARNING, ERROR)") + message: str = Field(..., description="The log message") + logger_name: str = Field(..., description="Name of the logger that created this entry") + rollout_id: str = Field(..., description="ID of the rollout this log belongs to") + status_code: Optional[int] = Field(None, description="Optional status code") + status_message: Optional[str] = Field(None, description="Optional status message") + status_details: Optional[List[Any]] = Field(None, description="Optional status details") + + model_config = ConfigDict(populate_by_name=True) + + +class LogsResponse(BaseModel): + """ + Response model for the get_logs endpoint. + + This model matches the LogsResponseSchema in eval-protocol.ts to ensure + consistent data structure between Python backend and TypeScript frontend. + """ + + logs: List[LogEntry] = Field(..., description="Array of log entries") + total: int = Field(..., description="Total number of logs available") + rollout_id: str = Field(..., description="The rollout ID these logs belong to") + filtered_by_level: Optional[str] = Field(None, description="Log level filter applied") + + model_config = ConfigDict() diff --git a/eval_protocol/utils/logs_server.py b/eval_protocol/utils/logs_server.py index 73781b21..f5358fbe 100644 --- a/eval_protocol/utils/logs_server.py +++ b/eval_protocol/utils/logs_server.py @@ -6,17 +6,21 @@ import time from contextlib import asynccontextmanager from queue import Queue -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import psutil import uvicorn -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query +from fastapi.middleware.cors import CORSMiddleware from eval_protocol.dataset_logger import default_logger from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE from eval_protocol.event_bus import event_bus from eval_protocol.models import Status from eval_protocol.utils.vite_server import ViteServer +from eval_protocol.logging.elasticsearch_client import ElasticsearchClient +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig +from eval_protocol.utils.logs_models import LogEntry, LogsResponse if TYPE_CHECKING: from eval_protocol.models import EvaluationRow @@ -71,8 +75,11 @@ async def _start_broadcast_loop(self): while True: try: # Wait for a message to be queued - message = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get) - await self._send_text_to_all_connections(message) + message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get) + + # Regular string message for all connections + await self._send_text_to_all_connections(str(message_data)) + except Exception as e: logger.error(f"Error in broadcast loop: {e}") await asyncio.sleep(0.1) @@ -238,8 +245,8 @@ class LogsServer(ViteServer): Enhanced server for serving Vite-built SPA with file watching and WebSocket support. This server extends ViteServer to add: - - WebSocket connections for real-time updates - - Live log streaming + - WebSocket connections for real-time evaluation row updates + - REST API for log querying """ def __init__( @@ -250,17 +257,49 @@ def __init__( host: str = "localhost", port: Optional[int] = 8000, index_file: str = "index.html", + elasticsearch_config: Optional[ElasticsearchConfig] = None, ): # Initialize WebSocket manager self.websocket_manager = WebSocketManager() - super().__init__(build_dir, host, port if port is not None else 8000, index_file) + # Initialize Elasticsearch client if config is provided + self.elasticsearch_client: Optional[ElasticsearchClient] = None + if elasticsearch_config: + self.elasticsearch_client = ElasticsearchClient(elasticsearch_config) + + self.app = FastAPI(title="Logs Server") + + # Add WebSocket endpoint and API routes + self._setup_websocket_routes() + self._setup_api_routes() + + super().__init__(build_dir, host, port if port is not None else 8000, index_file, self.app) + + # Add CORS middleware to allow frontend access + allowed_origins = [ + "http://localhost:5173", # Vite dev server + "http://127.0.0.1:5173", # Vite dev server (alternative) + f"http://{host}:{port}", # Server's own origin + f"http://localhost:{port}", # Server on localhost + ] + + self.app.add_middleware( + CORSMiddleware, + allow_origins=allowed_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) # Initialize evaluation watcher self.evaluation_watcher = EvaluationWatcher(self.websocket_manager) - # Add WebSocket endpoint - self._setup_websocket_routes() + # Log all registered routes for debugging + logger.info("Registered routes:") + for route in self.app.routes: + path = getattr(route, "path", "UNKNOWN") + methods = getattr(route, "methods", {"UNKNOWN"}) + logger.info(f" {methods} {path}") # Subscribe to events and start listening for cross-process events event_bus.subscribe(self._handle_event) @@ -275,7 +314,7 @@ async def websocket_endpoint(websocket: WebSocket): await self.websocket_manager.connect(websocket) try: while True: - # Keep connection alive + # Keep connection alive (for evaluation row updates) await websocket.receive_text() except WebSocketDisconnect: self.websocket_manager.disconnect(websocket) @@ -283,6 +322,9 @@ async def websocket_endpoint(websocket: WebSocket): logger.error(f"WebSocket error: {e}") self.websocket_manager.disconnect(websocket) + def _setup_api_routes(self): + """Set up API routes.""" + @self.app.get("/api/status") async def status(): """Get server status including active connections.""" @@ -295,8 +337,75 @@ async def status(): # LogsServer inherits from ViteServer which doesn't expose watch_paths # Expose an empty list to satisfy consumers and type checker "watch_paths": [], + "elasticsearch_enabled": self.elasticsearch_client is not None, } + @self.app.get("/api/logs/{rollout_id}", response_model=LogsResponse, response_model_exclude_none=True) + async def get_logs( + rollout_id: str, + level: Optional[str] = Query(None, description="Filter by log level (DEBUG, INFO, WARNING, ERROR)"), + limit: int = Query(100, description="Maximum number of log entries to return"), + ) -> LogsResponse: + """Get logs for a specific rollout ID from Elasticsearch.""" + if not self.elasticsearch_client: + raise HTTPException(status_code=503, detail="Elasticsearch is not configured for this logs server") + + try: + # Search for logs by rollout_id + search_results = self.elasticsearch_client.search_by_match("rollout_id", rollout_id, size=limit) + + if not search_results or "hits" not in search_results: + # Return empty response using Pydantic model + return LogsResponse( + logs=[], + total=0, + rollout_id=rollout_id, + filtered_by_level=level, + ) + + log_entries = [] + for hit in search_results["hits"]["hits"]: + log_data = hit["_source"] + + # Filter by level if specified + if level and log_data.get("level") != level: + continue + + # Create LogEntry using Pydantic model for validation + try: + log_entry = LogEntry( + **log_data # Use ** to unpack the dict, Pydantic will handle field mapping + ) + log_entries.append(log_entry) + except Exception as e: + # Log the error but continue processing other entries + logger.warning(f"Failed to parse log entry: {e}, data: {log_data}") + continue + + # Sort by timestamp (most recent first) + log_entries.sort(key=lambda x: x.timestamp, reverse=True) + + # Get total count + total_hits = search_results["hits"]["total"] + if isinstance(total_hits, dict): + # Elasticsearch 7+ format + total_count = total_hits["value"] + else: + # Elasticsearch 6 format + total_count = total_hits + + # Return response using Pydantic model + return LogsResponse( + logs=log_entries, + total=total_count, + rollout_id=rollout_id, + filtered_by_level=level, + ) + + except Exception as e: + logger.error(f"Error retrieving logs for rollout {rollout_id}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to retrieve logs: {str(e)}") + def _handle_event(self, event_type: str, data: Any) -> None: """Handle events from the event bus.""" if event_type in [LOG_EVENT_TYPE]: @@ -353,7 +462,12 @@ def run(self): asyncio.run(self.run_async()) -def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[str] = None) -> FastAPI: +def create_app( + host: str = "localhost", + port: int = 8000, + build_dir: Optional[str] = None, + elasticsearch_config: Optional[ElasticsearchConfig] = None, +) -> FastAPI: """ Factory function to create a FastAPI app instance and start the server with async loops. @@ -364,6 +478,7 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st host: Host to bind to port: Port to bind to build_dir: Optional custom build directory path + elasticsearch_config: Optional Elasticsearch configuration for log querying Returns: FastAPI app instance with server running in background @@ -373,17 +488,17 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "vite-app", "dist") ) - server = LogsServer(host=host, port=port, build_dir=build_dir) + server = LogsServer(host=host, port=port, build_dir=build_dir, elasticsearch_config=elasticsearch_config) server.start_loops() return server.app # For backward compatibility and direct usage -def serve_logs(port: Optional[int] = None): +def serve_logs(port: Optional[int] = None, elasticsearch_config: Optional[ElasticsearchConfig] = None): """ Convenience function to create and run a LogsServer. """ - server = LogsServer(port=port) + server = LogsServer(port=port, elasticsearch_config=elasticsearch_config) server.run() diff --git a/eval_protocol/utils/vite_server.py b/eval_protocol/utils/vite_server.py index 8c91cadb..d1ad1b08 100644 --- a/eval_protocol/utils/vite_server.py +++ b/eval_protocol/utils/vite_server.py @@ -32,13 +32,13 @@ def __init__( host: str = "localhost", port: int = 8000, index_file: str = "index.html", - lifespan: Optional[Callable[[FastAPI], Any]] = None, + app: Optional[FastAPI] = None, ): self.build_dir = Path(build_dir) self.host = host self.port = port self.index_file = index_file - self.app = FastAPI(title="Vite SPA Server", lifespan=lifespan) + self.app = app if app is not None else FastAPI(title="Vite SPA Server") # Validate build directory exists if not self.build_dir.exists(): diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index 6a50f088..906d5486 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -434,3 +434,83 @@ def test_elasticsearch_direct_http_handler_search_by_status_code( ) print(f"Successfully verified search by status code {running_status.value} found {len(hits)} log messages") + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +def test_elasticsearch_direct_http_handler_rollout_id_from_extra_overrides_env( + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str +): + """Test that rollout_id in extra parameter overrides environment variable.""" + + # Create a different rollout_id to pass in extra + extra_rollout_id = f"extra-rollout-{time.time()}" + + # Generate a unique test message + test_message = f"Rollout ID override test message at {time.time()}" + + # Log with rollout_id in extra data (should override environment variable) + test_logger.info(test_message, extra={"rollout_id": extra_rollout_id}) + + # Give Elasticsearch time to process the document + time.sleep(3) + + # Search for logs with the extra rollout_id (not the environment one) + search_results = elasticsearch_client.search_by_term("rollout_id", extra_rollout_id, size=1) + + # Assert that we found our log message with the extra rollout_id + assert search_results is not None, "Search should return results" + assert "hits" in search_results, "Search response should contain 'hits'" + assert "total" in search_results["hits"], "Search hits should contain 'total'" + + total_hits = search_results["hits"]["total"] + if isinstance(total_hits, dict): + total_count = total_hits["value"] + else: + total_count = total_hits + + assert total_count > 0, f"Expected to find at least 1 log message with extra rollout_id, but found {total_count}" + + # Verify the content of the found document + hits = search_results["hits"]["hits"] + assert len(hits) > 0, "Expected at least one hit" + + found_document = hits[0]["_source"] + + # Verify the rollout_id field matches the extra parameter (not environment variable) + assert "rollout_id" in found_document, "Expected document to contain 'rollout_id' field" + assert found_document["rollout_id"] == extra_rollout_id, ( + f"Expected rollout_id '{extra_rollout_id}', got '{found_document['rollout_id']}'" + ) + + # Verify it's NOT the environment variable rollout_id + assert found_document["rollout_id"] != rollout_id, ( + f"Expected rollout_id to be overridden, but got environment rollout_id '{rollout_id}'" + ) + + # Verify other expected fields are still present + assert found_document["message"] == test_message, ( + f"Expected message '{test_message}', got '{found_document['message']}'" + ) + assert found_document["level"] == "INFO", f"Expected level 'INFO', got '{found_document['level']}'" + assert found_document["logger_name"] == "test_elasticsearch_logger", ( + f"Expected logger name 'test_elasticsearch_logger', got '{found_document['logger_name']}'" + ) + assert "@timestamp" in found_document, "Expected document to contain '@timestamp' field" + + # Verify that searching for the original environment rollout_id doesn't find this message + env_search_results = elasticsearch_client.search( + {"bool": {"must": [{"term": {"rollout_id": rollout_id}}, {"match": {"message": test_message}}]}}, size=1 + ) + + assert env_search_results is not None, "Environment rollout_id search should return results" + env_total_hits = env_search_results["hits"]["total"] + if isinstance(env_total_hits, dict): + env_count = env_total_hits["value"] + else: + env_count = env_total_hits + + assert env_count == 0, ( + f"Expected 0 results when searching for message with environment rollout_id, but found {env_count}" + ) + + print(f"Successfully verified rollout_id override: extra '{extra_rollout_id}' overrode environment '{rollout_id}'") diff --git a/tests/test_logs_server.py b/tests/test_logs_server.py index 92f8d8d7..0dcc1cf3 100644 --- a/tests/test_logs_server.py +++ b/tests/test_logs_server.py @@ -373,7 +373,7 @@ def test_serve_logs_port_parameter(self, temp_build_dir): serve_logs(port=test_port) # Verify that LogsServer was created with the correct port - mock_logs_server_class.assert_called_once_with(port=test_port) + mock_logs_server_class.assert_called_once_with(port=test_port, elasticsearch_config=None) # Verify that the run method was called on the instance mock_server_instance.run.assert_called_once() @@ -387,7 +387,7 @@ def test_serve_logs_default_port(self, temp_build_dir): serve_logs() # Verify that LogsServer was created with None port (which will use LogsServer's default of 8000) - mock_logs_server_class.assert_called_once_with(port=None) + mock_logs_server_class.assert_called_once_with(port=None, elasticsearch_config=None) # Verify that the run method was called on the instance mock_server_instance.run.assert_called_once() diff --git a/vite-app/src/components/EvaluationRow.tsx b/vite-app/src/components/EvaluationRow.tsx index 9fc2f9b6..65a3bbb8 100644 --- a/vite-app/src/components/EvaluationRow.tsx +++ b/vite-app/src/components/EvaluationRow.tsx @@ -5,6 +5,7 @@ import type { } from "../types/eval-protocol"; import { ChatInterface } from "./ChatInterface"; import { MetadataSection } from "./MetadataSection"; +import { LogsSection } from "./LogsSection"; import StatusIndicator from "./StatusIndicator"; import { state } from "../App"; import { TableCell, TableRowInteractive } from "./TableContainer"; @@ -372,6 +373,7 @@ const ExpandedContent = observer( {/* Right Column - Metadata */}