From 5b2b7b4ceb957a729cf786c9a1e25809c6a447b0 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 09:09:33 -0700 Subject: [PATCH 01/11] match height of logs --- vite-app/src/components/ChatInterface.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vite-app/src/components/ChatInterface.tsx b/vite-app/src/components/ChatInterface.tsx index 9d6c1862..0b220fad 100644 --- a/vite-app/src/components/ChatInterface.tsx +++ b/vite-app/src/components/ChatInterface.tsx @@ -87,7 +87,7 @@ export const ChatInterface = ({ messages }: ChatInterfaceProps) => { e.preventDefault(); const deltaY = e.clientY - initialMouseY; const newHeight = initialHeight + deltaY; - setChatHeight(Math.max(200, Math.min(800, newHeight))); // Min 200px, max 800px + setChatHeight(Math.max(200, Math.min(844, newHeight))); // Min 200px, max 844px } }; From 01e4ec763e5281e1c38d7675d3423f82e43d141b Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 09:10:33 -0700 Subject: [PATCH 02/11] properly create timestmap in log emission --- .../log_utils/elasticsearch_direct_http_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval_protocol/log_utils/elasticsearch_direct_http_handler.py b/eval_protocol/log_utils/elasticsearch_direct_http_handler.py index cd1a3765..735869b0 100644 --- a/eval_protocol/log_utils/elasticsearch_direct_http_handler.py +++ b/eval_protocol/log_utils/elasticsearch_direct_http_handler.py @@ -2,7 +2,7 @@ import os from concurrent.futures import ThreadPoolExecutor from typing import Optional, Any, Dict -from datetime import datetime +from datetime import datetime, timezone from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig from .elasticsearch_client import ElasticsearchClient @@ -36,8 +36,8 @@ def configure(self, elasticsearch_config: ElasticsearchConfig) -> None: def emit(self, record: logging.LogRecord) -> None: """Emit a log record by scheduling it for async transmission.""" try: - # Create proper ISO 8601 timestamp - timestamp = datetime.fromtimestamp(record.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ") + # Create proper ISO 8601 timestamp in UTC + timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ") rollout_id = self._get_rollout_id(record) logger.debug(f"Emitting log record: {record.getMessage()} with rollout_id: {rollout_id}") From cb115279a768f8ef93bb23c44c21641035f59082 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 09:12:50 -0700 Subject: [PATCH 03/11] added working test --- .../test_elasticsearch_direct_http_handler.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index 428f2dda..f7e72680 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -2,6 +2,7 @@ import logging import time import pytest +from datetime import datetime, timezone from eval_protocol.log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient @@ -514,3 +515,91 @@ def test_elasticsearch_direct_http_handler_rollout_id_from_extra_overrides_env( ) print(f"Successfully verified rollout_id override: extra '{extra_rollout_id}' overrode environment '{rollout_id}'") + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +def test_elasticsearch_direct_http_handler_timestamp_format( + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str +): + """Test that ElasticsearchDirectHttpHandler formats timestamps correctly with UTC timezone.""" + + # Generate a unique test message + test_message = f"Timestamp format test message at {time.time()}" + + # Record the time before logging to compare with the timestamp + before_log_time = datetime.now(timezone.utc) + + # Send the log message + test_logger.info(test_message) + + # Record the time after logging + after_log_time = datetime.now(timezone.utc) + + # Give Elasticsearch a moment to process the document + time.sleep(3) + + # Search for the document using the client + search_results = elasticsearch_client.search_by_match("message", test_message, size=1) + + # Assert that we found our log message + assert search_results is not None, "Search should return results" + assert "hits" in search_results, "Search response should contain 'hits'" + assert "total" in search_results["hits"], "Search hits should contain 'total'" + + total_hits = search_results["hits"]["total"] + if isinstance(total_hits, dict): + # Elasticsearch 7+ format + total_count = total_hits["value"] + else: + # Elasticsearch 6 format + total_count = total_hits + + assert total_count > 0, f"Expected to find at least 1 log message, but found {total_count}" + + # Verify the content of the found document + hits = search_results["hits"]["hits"] + assert len(hits) > 0, "Expected at least one hit" + + found_document = hits[0]["_source"] + assert "@timestamp" in found_document, "Expected document to contain '@timestamp' field" + + # Get the timestamp from the document + timestamp_str = found_document["@timestamp"] + + # Verify the timestamp format matches ISO 8601 with UTC timezone (Z suffix) + assert timestamp_str.endswith("Z"), f"Expected timestamp to end with 'Z' (UTC), got: {timestamp_str}" + + # Parse the timestamp to verify it's valid + try: + parsed_timestamp = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")) + except ValueError as e: + pytest.fail(f"Failed to parse timestamp '{timestamp_str}': {e}") + + # Verify the timestamp is timezone-aware (UTC) + assert parsed_timestamp.tzinfo is not None, "Expected timestamp to be timezone-aware" + utc_offset = parsed_timestamp.tzinfo.utcoffset(None) + assert utc_offset is not None and utc_offset.total_seconds() == 0, "Expected timestamp to be in UTC timezone" + + # Verify the timestamp is within reasonable bounds (between before and after log time) + # Allow for some margin due to processing time + from datetime import timedelta + + time_margin = timedelta(seconds=5) # 5 seconds margin + assert before_log_time - time_margin <= parsed_timestamp <= after_log_time + time_margin, ( + f"Expected timestamp {parsed_timestamp} to be between {before_log_time} and {after_log_time} " + f"(with {time_margin} margin)" + ) + + # Verify the timestamp format includes microseconds + assert "." in timestamp_str, "Expected timestamp to include microseconds" + assert timestamp_str.count(".") == 1, "Expected timestamp to have exactly one decimal point" + + # Verify the format matches the expected pattern: YYYY-MM-DDTHH:MM:SS.ffffffZ + import re + + iso_pattern = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{6}Z$" + assert re.match(iso_pattern, timestamp_str), f"Expected timestamp to match ISO 8601 pattern, got: {timestamp_str}" + + print(f"Successfully verified timestamp format: {timestamp_str}") + print(f"Parsed timestamp: {parsed_timestamp} (UTC)") + print(f"Timestamp is within expected time range: {before_log_time} <= {parsed_timestamp} <= {after_log_time}") From ce00763788c12a34733f11a2291546c4e5b5843a Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 09:22:37 -0700 Subject: [PATCH 04/11] add clear index / set/cleanup logic for all ecs tests --- .../log_utils/elasticsearch_client.py | 19 ++++++++++ .../test_elasticsearch_direct_http_handler.py | 35 +++++++++++++++++-- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/eval_protocol/log_utils/elasticsearch_client.py b/eval_protocol/log_utils/elasticsearch_client.py index a59269ea..860b90fe 100644 --- a/eval_protocol/log_utils/elasticsearch_client.py +++ b/eval_protocol/log_utils/elasticsearch_client.py @@ -100,6 +100,25 @@ def delete_index(self) -> bool: except Exception: return False + def clear_index(self) -> bool: + """Clear all documents from the index. + + Returns: + bool: True if successful, False otherwise + """ + try: + # Delete all documents by query + response = self._make_request( + "POST", f"{self.index_url}/_delete_by_query", json_data={"query": {"match_all": {}}} + ) + if response.status_code == 200: + # Refresh the index to ensure changes are visible + refresh_response = self._make_request("POST", f"{self.index_url}/_refresh") + return refresh_response.status_code == 200 + return False + except Exception: + return False + def get_mapping(self) -> Optional[Dict[str, Any]]: """Get the index mapping. diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index f7e72680..8156b4d2 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -32,11 +32,23 @@ def rollout_id(): def elasticsearch_config(): """Set up Elasticsearch and return configuration.""" import time + import uuid - index_name = f"test-logs-{int(time.time())}" + # Use a more unique index name to avoid conflicts between tests + index_name = f"test-logs-{int(time.time())}-{uuid.uuid4().hex[:8]}" setup = ElasticsearchSetup() config = setup.setup_elasticsearch(index_name) - return config + yield config + + # Clean up the index after the test + try: + # Create a client to clean up the index + from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient + + client = ElasticsearchClient(config) + client.delete_index() + except Exception as e: + print(f"Warning: Failed to cleanup Elasticsearch index {index_name}: {e}") @pytest.fixture @@ -78,7 +90,24 @@ def test_logger(elasticsearch_handler, elasticsearch_config, rollout_id: str): # Prevent propagation to avoid duplicate logs logger.propagate = False - return logger + yield logger + + # Clean up the logger handlers after the test + logger.handlers.clear() + + +@pytest.fixture(autouse=True) +def clear_elasticsearch_before_test( + elasticsearch_client: ElasticsearchClient, elasticsearch_config: ElasticsearchConfig +): + """Clear Elasticsearch index before each test to ensure clean state.""" + try: + # Clear all documents from the index before each test + success = elasticsearch_client.clear_index() + if not success: + print(f"Warning: Failed to clear Elasticsearch index {elasticsearch_config.index_name}") + except Exception as e: + print(f"Warning: Failed to clear Elasticsearch index before test: {e}") @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") From a32af57113a26e560f7d23a09a5e02ebd854fdc3 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 09:22:54 -0700 Subject: [PATCH 05/11] remve skip markers --- tests/logging/test_elasticsearch_direct_http_handler.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index 8156b4d2..df59f9de 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -110,7 +110,6 @@ def clear_elasticsearch_before_test( print(f"Warning: Failed to clear Elasticsearch index before test: {e}") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_sends_logs( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): @@ -160,7 +159,6 @@ def test_elasticsearch_direct_http_handler_sends_logs( print(f"Successfully verified log message in Elasticsearch: {test_message}") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_sorts_logs_chronologically( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): @@ -214,7 +212,6 @@ def test_elasticsearch_direct_http_handler_sorts_logs_chronologically( print(f"Timestamps in order: {found_timestamps}") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_includes_rollout_id( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): @@ -272,7 +269,6 @@ def test_elasticsearch_direct_http_handler_includes_rollout_id( print(f"Successfully verified log message with rollout_id '{rollout_id}' in Elasticsearch: {test_message}") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_search_by_rollout_id( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): @@ -341,7 +337,6 @@ def test_elasticsearch_direct_http_handler_search_by_rollout_id( print("Verified that search for different rollout_id returns 0 results") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_logs_status_info( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): @@ -407,7 +402,6 @@ def test_elasticsearch_direct_http_handler_logs_status_info( print(f"Successfully verified Status logging with code {test_status.code.value} in Elasticsearch: {test_message}") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_search_by_status_code( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): @@ -466,7 +460,6 @@ def test_elasticsearch_direct_http_handler_search_by_status_code( print(f"Successfully verified search by status code {running_status.value} found {len(hits)} log messages") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_rollout_id_from_extra_overrides_env( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): @@ -546,7 +539,6 @@ def test_elasticsearch_direct_http_handler_rollout_id_from_extra_overrides_env( print(f"Successfully verified rollout_id override: extra '{extra_rollout_id}' overrode environment '{rollout_id}'") -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_timestamp_format( elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): From 1648047f85ac2484c9abf6f37aaf1dca368e7dc5 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 09:24:53 -0700 Subject: [PATCH 06/11] add ecs tests --- .github/workflows/ci.yml | 1 + ...agate-status-test.yml => elasticsearch-tests.yml} | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) rename .github/workflows/{remote-rollout-processor-propagate-status-test.yml => elasticsearch-tests.yml} (73%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 903c7734..9b90b08b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -110,6 +110,7 @@ jobs: --ignore=tests/pytest/test_svgbench.py \ --ignore=tests/pytest/test_livesvgbench.py \ --ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \ + --ignore=tests/logging/test_elasticsearch_direct_http_handler.py \ --ignore=eval_protocol/benchmarks/ \ --cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10 diff --git a/.github/workflows/remote-rollout-processor-propagate-status-test.yml b/.github/workflows/elasticsearch-tests.yml similarity index 73% rename from .github/workflows/remote-rollout-processor-propagate-status-test.yml rename to .github/workflows/elasticsearch-tests.yml index d8080777..9eba33a8 100644 --- a/.github/workflows/remote-rollout-processor-propagate-status-test.yml +++ b/.github/workflows/elasticsearch-tests.yml @@ -1,4 +1,4 @@ -name: RemoteRolloutProcessor Propagate Status Test +name: Elasticsearch Tests on: push: @@ -13,8 +13,8 @@ on: workflow_dispatch: # Allow manual triggering jobs: - remote-rollout-processor-propagate-status-smoke-test: - name: Fireworks Propagate Status Smoke Test + elasticsearch-tests: + name: Elasticsearch Integration Tests runs-on: ubuntu-latest steps: @@ -36,10 +36,14 @@ jobs: - name: Install the project run: uv sync --locked --all-extras --dev - - name: Run RemoteRolloutProcessor Propagate Status Smoke Test + - name: Run Elasticsearch Tests env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning" run: | + # Run Elasticsearch direct HTTP handler tests + uv run pytest tests/logging/test_elasticsearch_direct_http_handler.py -v --tb=short + + # Run RemoteRolloutProcessor Propagate Status Smoke Test (also uses Elasticsearch) uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \ -v --tb=short From fd70d4fb5df9e2d259e28857bcbee8124c819a06 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 10:11:07 -0700 Subject: [PATCH 07/11] Enhance logging and debugging capabilities - Added a `--debug` flag to the logs command for enabling debug mode. - Updated logs command to display debug mode status. - Enhanced logging in the SqliteDatasetLoggerAdapter and SqliteEventBus for better traceability of events and errors. - Implemented debug logging in the WebSocketManager and LogsServer to provide detailed insights during WebSocket connections and broadcasts. - Enabled debug mode for all relevant loggers in the logs server system. --- eval_protocol/cli.py | 1 + eval_protocol/cli_commands/logs.py | 3 +- .../sqlite_dataset_logger_adapter.py | 9 +- eval_protocol/event_bus/sqlite_event_bus.py | 41 ++++- eval_protocol/utils/logs_server.py | 154 +++++++++++++++--- 5 files changed, 179 insertions(+), 29 deletions(-) diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 01cfb8c5..1d54cd65 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -300,6 +300,7 @@ def parse_args(args=None): # Logs command logs_parser = subparsers.add_parser("logs", help="Serve logs with file watching and real-time updates") logs_parser.add_argument("--port", type=int, default=8000, help="Port to bind to (default: 8000)") + logs_parser.add_argument("--debug", action="store_true", help="Enable debug mode") # Upload command upload_parser = subparsers.add_parser( diff --git a/eval_protocol/cli_commands/logs.py b/eval_protocol/cli_commands/logs.py index 414826bc..9f26f373 100644 --- a/eval_protocol/cli_commands/logs.py +++ b/eval_protocol/cli_commands/logs.py @@ -16,6 +16,7 @@ def logs_command(args): print(f"🌐 URL: http://localhost:{port}") print(f"šŸ”Œ WebSocket: ws://localhost:{port}/ws") print(f"šŸ‘€ Watching paths: {['current directory']}") + print(f"šŸ” Debug mode: {args.debug}") print("Press Ctrl+C to stop the server") print("-" * 50) @@ -25,7 +26,7 @@ def logs_command(args): elasticsearch_config = ElasticsearchSetup().setup_elasticsearch() try: - serve_logs(port=args.port, elasticsearch_config=elasticsearch_config) + serve_logs(port=args.port, elasticsearch_config=elasticsearch_config, debug=args.debug) return 0 except KeyboardInterrupt: print("\nšŸ›‘ Server stopped by user") diff --git a/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py b/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py index 37704266..5f360bfc 100644 --- a/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py +++ b/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py @@ -23,12 +23,19 @@ def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluati def log(self, row: "EvaluationRow") -> None: data = row.model_dump(exclude_none=True, mode="json") + rollout_id = data.get("execution_metadata", {}).get("rollout_id", "unknown") + logger.debug(f"[EVENT_BUS_EMIT] Starting to log row with rollout_id: {rollout_id}") + self._store.upsert_row(data=data) + logger.debug(f"[EVENT_BUS_EMIT] Successfully stored row in database for rollout_id: {rollout_id}") + try: + logger.debug(f"[EVENT_BUS_EMIT] Emitting event '{LOG_EVENT_TYPE}' for rollout_id: {rollout_id}") event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data)) + logger.debug(f"[EVENT_BUS_EMIT] Successfully emitted event for rollout_id: {rollout_id}") except Exception as e: # Avoid breaking storage due to event emission issues - logger.error(f"Failed to emit row_upserted event: {e}") + logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}") pass def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]: diff --git a/eval_protocol/event_bus/sqlite_event_bus.py b/eval_protocol/event_bus/sqlite_event_bus.py index 2925abe4..07442a53 100644 --- a/eval_protocol/event_bus/sqlite_event_bus.py +++ b/eval_protocol/event_bus/sqlite_event_bus.py @@ -35,58 +35,84 @@ def emit(self, event_type: str, data: Any) -> None: event_type: Type of event (e.g., "log") data: Event data """ + logger.debug(f"[CROSS_PROCESS_EMIT] Emitting event type: {event_type}") + # Call local listeners immediately + logger.debug(f"[CROSS_PROCESS_EMIT] Calling {len(self._listeners)} local listeners") super().emit(event_type, data) + logger.debug("[CROSS_PROCESS_EMIT] Completed local listener calls") # Publish to cross-process subscribers + logger.debug("[CROSS_PROCESS_EMIT] Publishing to cross-process subscribers") self._publish_cross_process(event_type, data) + logger.debug("[CROSS_PROCESS_EMIT] Completed cross-process publish") def _publish_cross_process(self, event_type: str, data: Any) -> None: """Publish event to cross-process subscribers via database.""" - self._db.publish_event(event_type, data, self._process_id) + logger.debug(f"[CROSS_PROCESS_PUBLISH] Publishing event {event_type} to database") + try: + self._db.publish_event(event_type, data, self._process_id) + logger.debug(f"[CROSS_PROCESS_PUBLISH] Successfully published event {event_type} to database") + except Exception as e: + logger.error(f"[CROSS_PROCESS_PUBLISH] Failed to publish event {event_type} to database: {e}") def start_listening(self) -> None: """Start listening for cross-process events.""" if self._running: + logger.debug("[CROSS_PROCESS_LISTEN] Already listening, skipping start") return + logger.debug("[CROSS_PROCESS_LISTEN] Starting cross-process event listening") self._running = True self._start_database_listener() + logger.debug("[CROSS_PROCESS_LISTEN] Started database listener thread") def stop_listening(self) -> None: """Stop listening for cross-process events.""" + logger.debug("[CROSS_PROCESS_LISTEN] Stopping cross-process event listening") self._running = False if self._listener_thread and self._listener_thread.is_alive(): + logger.debug("[CROSS_PROCESS_LISTEN] Waiting for listener thread to stop") self._listener_thread.join(timeout=1) + logger.debug("[CROSS_PROCESS_LISTEN] Listener thread stopped") def _start_database_listener(self) -> None: """Start database-based event listener.""" def database_listener(): + logger.debug("[CROSS_PROCESS_LISTENER] Starting database listener loop") last_cleanup = time.time() while self._running: try: # Get unprocessed events from other processes events = self._db.get_unprocessed_events(self._process_id) + if events: + logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events") for event in events: if not self._running: break try: + logger.debug( + f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}" + ) # Handle the event self._handle_cross_process_event(event["event_type"], event["data"]) + logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}") # Mark as processed self._db.mark_event_processed(event["event_id"]) + logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed") except Exception as e: - logger.debug(f"Failed to process event {event['event_id']}: {e}") + logger.debug(f"[CROSS_PROCESS_LISTENER] Failed to process event {event['event_id']}: {e}") # Clean up old events every hour current_time = time.time() if current_time - last_cleanup >= 3600: + logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events") self._db.cleanup_old_events() last_cleanup = current_time @@ -94,7 +120,7 @@ def database_listener(): time.sleep(0.1) except Exception as e: - logger.debug(f"Database listener error: {e}") + logger.debug(f"[CROSS_PROCESS_LISTENER] Database listener error: {e}") time.sleep(1) self._listener_thread = threading.Thread(target=database_listener, daemon=True) @@ -102,8 +128,13 @@ def database_listener(): def _handle_cross_process_event(self, event_type: str, data: Any) -> None: """Handle events received from other processes.""" - for listener in self._listeners: + logger.debug(f"[CROSS_PROCESS_HANDLE] Handling cross-process event type: {event_type}") + logger.debug(f"[CROSS_PROCESS_HANDLE] Calling {len(self._listeners)} listeners") + + for i, listener in enumerate(self._listeners): try: + logger.debug(f"[CROSS_PROCESS_HANDLE] Calling listener {i}") listener(event_type, data) + logger.debug(f"[CROSS_PROCESS_HANDLE] Successfully called listener {i}") except Exception as e: - logger.debug(f"Cross-process event listener failed for {event_type}: {e}") + logger.debug(f"[CROSS_PROCESS_HANDLE] Cross-process event listener {i} failed for {event_type}: {e}") diff --git a/eval_protocol/utils/logs_server.py b/eval_protocol/utils/logs_server.py index 078a26ad..b094b81f 100644 --- a/eval_protocol/utils/logs_server.py +++ b/eval_protocol/utils/logs_server.py @@ -30,6 +30,19 @@ logger = logging.getLogger(__name__) +def enable_debug_mode(): + """Enable debug mode for all relevant loggers in the logs server system.""" + # Set debug level for all relevant loggers + logger.setLevel(logging.DEBUG) + + # Set debug level for event bus logger + from eval_protocol.event_bus.logger import logger as event_bus_logger + + event_bus_logger.setLevel(logging.DEBUG) + + print("Debug mode enabled for all relevant loggers") + + class WebSocketManager: """Manages WebSocket connections and broadcasts messages.""" @@ -40,100 +53,158 @@ def __init__(self): self._lock = threading.Lock() async def connect(self, websocket: WebSocket): + logger.debug("[WEBSOCKET_CONNECT] New websocket connection attempt") await websocket.accept() with self._lock: self.active_connections.append(websocket) connection_count = len(self.active_connections) - logger.info(f"WebSocket connected. Total connections: {connection_count}") + logger.info(f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}") + + logger.debug("[WEBSOCKET_CONNECT] Reading logs for initialization") logs = default_logger.read() + logger.debug(f"[WEBSOCKET_CONNECT] Found {len(logs)} logs to send") + data = { "type": "initialize_logs", "logs": [log.model_dump(exclude_none=True, mode="json") for log in logs], } + logger.debug("[WEBSOCKET_CONNECT] Sending initialization data") await websocket.send_text(json.dumps(data)) + logger.debug("[WEBSOCKET_CONNECT] Successfully sent initialization data") def disconnect(self, websocket: WebSocket): + logger.debug("[WEBSOCKET_DISCONNECT] WebSocket disconnection") with self._lock: if websocket in self.active_connections: self.active_connections.remove(websocket) + logger.debug("[WEBSOCKET_DISCONNECT] Removed websocket from active connections") + else: + logger.debug("[WEBSOCKET_DISCONNECT] Websocket was not in active connections") connection_count = len(self.active_connections) - logger.info(f"WebSocket disconnected. Total connections: {connection_count}") + logger.info(f"[WEBSOCKET_DISCONNECT] WebSocket disconnected. Total connections: {connection_count}") def broadcast_row_upserted(self, row: "EvaluationRow"): """Broadcast a row-upsert event to all connected clients. Safe no-op if server loop is not running or there are no connections. """ + rollout_id = row.execution_metadata.rollout_id if row.execution_metadata else "unknown" + logger.debug(f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}") + + with self._lock: + active_connections_count = len(self.active_connections) + logger.debug(f"[WEBSOCKET_BROADCAST] Active connections: {active_connections_count}") + + if active_connections_count == 0: + logger.debug( + f"[WEBSOCKET_BROADCAST] No active connections, skipping broadcast for rollout_id: {rollout_id}" + ) + return + try: # Serialize pydantic model + logger.debug(f"[WEBSOCKET_BROADCAST] Serializing row for rollout_id: {rollout_id}") json_message = json.dumps({"type": "log", "row": row.model_dump(exclude_none=True, mode="json")}) + logger.debug( + f"[WEBSOCKET_BROADCAST] Successfully serialized message (length: {len(json_message)}) for rollout_id: {rollout_id}" + ) + # Queue the message for broadcasting in the main event loop + logger.debug(f"[WEBSOCKET_BROADCAST] Queuing message for broadcast for rollout_id: {rollout_id}") self._broadcast_queue.put(json_message) + logger.debug(f"[WEBSOCKET_BROADCAST] Successfully queued message for rollout_id: {rollout_id}") except Exception as e: - logger.error(f"Failed to serialize row for broadcast: {e}") + logger.error( + f"[WEBSOCKET_BROADCAST] Failed to serialize row for broadcast for rollout_id {rollout_id}: {e}" + ) async def _start_broadcast_loop(self): """Start the broadcast loop that processes queued messages.""" + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Starting broadcast loop") while True: try: # Wait for a message to be queued + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Waiting for message from queue") message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get) + logger.debug( + f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(str(message_data))})" + ) # Regular string message for all connections + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to all connections") await self._send_text_to_all_connections(str(message_data)) + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to all connections") except Exception as e: - logger.error(f"Error in broadcast loop: {e}") + logger.error(f"[WEBSOCKET_BROADCAST_LOOP] Error in broadcast loop: {e}") await asyncio.sleep(0.1) except asyncio.CancelledError: - logger.info("Broadcast loop cancelled") + logger.info("[WEBSOCKET_BROADCAST_LOOP] Broadcast loop cancelled") break async def _send_text_to_all_connections(self, text: str): with self._lock: connections = list(self.active_connections) + logger.debug(f"[WEBSOCKET_SEND] Attempting to send to {len(connections)} connections") + if not connections: + logger.debug("[WEBSOCKET_SEND] No connections available, skipping send") return tasks = [] failed_connections = [] - for connection in connections: + for i, connection in enumerate(connections): try: + logger.debug(f"[WEBSOCKET_SEND] Preparing to send to connection {i}") tasks.append(connection.send_text(text)) except Exception as e: - logger.error(f"Failed to send text to WebSocket: {e}") + logger.error(f"[WEBSOCKET_SEND] Failed to prepare send to WebSocket {i}: {e}") failed_connections.append(connection) # Execute all sends in parallel if tasks: + logger.debug(f"[WEBSOCKET_SEND] Executing {len(tasks)} parallel sends") results = await asyncio.gather(*tasks, return_exceptions=True) + logger.debug("[WEBSOCKET_SEND] Completed parallel sends") # Check for any exceptions that occurred during execution for i, result in enumerate(results): if isinstance(result, Exception): - logger.error(f"Failed to send text to WebSocket: {result}") + logger.error(f"[WEBSOCKET_SEND] Failed to send text to WebSocket {i}: {result}") failed_connections.append(connections[i]) + else: + logger.debug(f"[WEBSOCKET_SEND] Successfully sent to connection {i}") # Remove all failed connections - with self._lock: - for connection in failed_connections: - try: - self.active_connections.remove(connection) - except ValueError: - pass + if failed_connections: + logger.debug(f"[WEBSOCKET_SEND] Removing {len(failed_connections)} failed connections") + with self._lock: + for connection in failed_connections: + try: + self.active_connections.remove(connection) + except ValueError: + pass def start_broadcast_loop(self): """Start the broadcast loop in the current event loop.""" if self._broadcast_task is None or self._broadcast_task.done(): + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Creating new broadcast task") self._broadcast_task = asyncio.create_task(self._start_broadcast_loop()) + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Broadcast task created") + else: + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Broadcast task already running") def stop_broadcast_loop(self): """Stop the broadcast loop.""" if self._broadcast_task and not self._broadcast_task.done(): + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Cancelling broadcast task") self._broadcast_task.cancel() self._broadcast_task = None + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Broadcast task cancelled") + else: + logger.debug("[WEBSOCKET_BROADCAST_LOOP] No active broadcast task to stop") class EvaluationWatcher: @@ -260,7 +331,12 @@ def __init__( port: Optional[int] = 8000, index_file: str = "index.html", elasticsearch_config: Optional[ElasticsearchConfig] = None, + debug: bool = False, ): + # Enable debug mode if requested + if debug: + enable_debug_mode() + # Initialize WebSocket manager self.websocket_manager = WebSocketManager() @@ -304,9 +380,11 @@ def __init__( logger.info(f" {methods} {path}") # Subscribe to events and start listening for cross-process events + logger.debug("[LOGS_SERVER_INIT] Subscribing to event bus") event_bus.subscribe(self._handle_event) + logger.debug("[LOGS_SERVER_INIT] Successfully subscribed to event bus") - logger.info(f"LogsServer initialized on {host}:{port}") + logger.info(f"[LOGS_SERVER_INIT] LogsServer initialized on {host}:{port}") def _setup_websocket_routes(self): """Set up WebSocket routes for real-time communication.""" @@ -418,17 +496,34 @@ async def get_logs( def _handle_event(self, event_type: str, data: Any) -> None: """Handle events from the event bus.""" + logger.debug(f"[EVENT_BUS_RECEIVE] Received event type: {event_type}") + if event_type in [LOG_EVENT_TYPE]: from eval_protocol.models import EvaluationRow - data = EvaluationRow(**data) - self.websocket_manager.broadcast_row_upserted(data) + try: + logger.debug("[EVENT_BUS_RECEIVE] Processing LOG_EVENT_TYPE event") + data = EvaluationRow(**data) + rollout_id = data.execution_metadata.rollout_id if data.execution_metadata else "unknown" + logger.debug(f"[EVENT_BUS_RECEIVE] Successfully parsed EvaluationRow for rollout_id: {rollout_id}") + + logger.debug("[EVENT_BUS_RECEIVE] Broadcasting row_upserted to websocket manager") + self.websocket_manager.broadcast_row_upserted(data) + logger.debug(f"[EVENT_BUS_RECEIVE] Successfully queued broadcast for rollout_id: {rollout_id}") + except Exception as e: + logger.error(f"[EVENT_BUS_RECEIVE] Failed to process LOG_EVENT_TYPE event: {e}") + else: + logger.debug(f"[EVENT_BUS_RECEIVE] Ignoring event type: {event_type} (not LOG_EVENT_TYPE)") def start_loops(self): """Start the broadcast loop and evaluation watcher.""" + logger.debug("[LOGS_SERVER_LOOPS] Starting all loops") self.websocket_manager.start_broadcast_loop() + logger.debug("[LOGS_SERVER_LOOPS] Started websocket broadcast loop") self.evaluation_watcher.start() + logger.debug("[LOGS_SERVER_LOOPS] Started evaluation watcher") event_bus.start_listening() + logger.debug("[LOGS_SERVER_LOOPS] Started event bus listening") async def run_async(self): """ @@ -477,6 +572,7 @@ def create_app( port: int = 8000, build_dir: Optional[str] = None, elasticsearch_config: Optional[ElasticsearchConfig] = None, + debug: bool = False, ) -> FastAPI: """ Factory function to create a FastAPI app instance and start the server with async loops. @@ -498,17 +594,21 @@ def create_app( os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "vite-app", "dist") ) - server = LogsServer(host=host, port=port, build_dir=build_dir, elasticsearch_config=elasticsearch_config) + server = LogsServer( + host=host, port=port, build_dir=build_dir, elasticsearch_config=elasticsearch_config, debug=debug + ) server.start_loops() return server.app # For backward compatibility and direct usage -def serve_logs(port: Optional[int] = None, elasticsearch_config: Optional[ElasticsearchConfig] = None): +def serve_logs( + port: Optional[int] = None, elasticsearch_config: Optional[ElasticsearchConfig] = None, debug: bool = False +): """ Convenience function to create and run a LogsServer. """ - server = LogsServer(port=port, elasticsearch_config=elasticsearch_config) + server = LogsServer(port=port, elasticsearch_config=elasticsearch_config, debug=debug) server.run() @@ -519,17 +619,27 @@ def serve_logs(port: Optional[int] = None, elasticsearch_config: Optional[Elasti parser.add_argument("--host", default="localhost", help="Host to bind to (default: localhost)") parser.add_argument("--port", type=int, default=8000, help="Port to bind to (default: 8000)") parser.add_argument("--build-dir", help="Path to Vite build directory") + parser.add_argument("--debug", help="Set logger level to DEBUG") args = parser.parse_args() + if args.debug: + enable_debug_mode() + elasticsearch_config = ElasticsearchSetup().setup_elasticsearch() # Create server with command line arguments if args.build_dir: server = LogsServer( - host=args.host, port=args.port, build_dir=args.build_dir, elasticsearch_config=elasticsearch_config + host=args.host, + port=args.port, + build_dir=args.build_dir, + elasticsearch_config=elasticsearch_config, + debug=bool(args.debug), ) else: - server = LogsServer(host=args.host, port=args.port, elasticsearch_config=elasticsearch_config) + server = LogsServer( + host=args.host, port=args.port, elasticsearch_config=elasticsearch_config, debug=bool(args.debug) + ) server.run() From 174b26122dd324ea934a8f4406ab9251272385fe Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 10:57:14 -0700 Subject: [PATCH 08/11] Improve WebSocket connection handling - Updated the WebSocket connection logic to prevent multiple connections by checking for both OPEN and CONNECTING states. - This change addresses potential issues in React strict mode where multiple connection attempts could occur. --- vite-app/src/App.tsx | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vite-app/src/App.tsx b/vite-app/src/App.tsx index f322333d..fe55b4d3 100644 --- a/vite-app/src/App.tsx +++ b/vite-app/src/App.tsx @@ -23,8 +23,11 @@ const App = observer(() => { const reconnectAttemptsRef = useRef(0); const connectWebSocket = () => { - if (wsRef.current?.readyState === WebSocket.OPEN) { - return; // Already connected + if ( + wsRef.current?.readyState === WebSocket.OPEN || + wsRef.current?.readyState === WebSocket.CONNECTING + ) { + return; // Already connected or connecting. This will happen in React strict mode. } const ws = new WebSocket(getWebSocketUrl()); From 8e54f84c343a7c9dee1c2ecc0fff850f1025f628 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 11:27:45 -0700 Subject: [PATCH 09/11] Refactor SqliteEventBus for async event handling - Updated SqliteEventBus to use asyncio for cross-process event listening, replacing the previous threading implementation. - Changed the processed field in the database from CharField to BooleanField for better data integrity. - Adjusted event processing logic to accommodate the new async structure, ensuring events are handled correctly across processes. - Enhanced test cases to support async operations and validate cross-process event communication. --- eval_protocol/event_bus/sqlite_event_bus.py | 108 ++++++++--------- .../event_bus/sqlite_event_bus_database.py | 14 +-- test_event_bus_helper.py | 74 ++++++++++++ tests/test_event_bus.py | 112 ++++++++++++------ tests/test_event_bus_helper.py | 74 ++++++++++++ 5 files changed, 275 insertions(+), 107 deletions(-) create mode 100644 test_event_bus_helper.py create mode 100644 tests/test_event_bus_helper.py diff --git a/eval_protocol/event_bus/sqlite_event_bus.py b/eval_protocol/event_bus/sqlite_event_bus.py index 07442a53..88125a5b 100644 --- a/eval_protocol/event_bus/sqlite_event_bus.py +++ b/eval_protocol/event_bus/sqlite_event_bus.py @@ -1,3 +1,5 @@ +import asyncio +import os import threading import time from typing import Any, Optional @@ -16,17 +18,14 @@ def __init__(self, db_path: Optional[str] = None): # Use the same database as the evaluation row store if db_path is None: - import os - from eval_protocol.directory_utils import find_eval_protocol_dir eval_protocol_dir = find_eval_protocol_dir() db_path = os.path.join(eval_protocol_dir, "logs.db") - self._db = SqliteEventBusDatabase(db_path) + self._db: SqliteEventBusDatabase = SqliteEventBusDatabase(db_path) self._running = False - self._listener_thread: Optional[threading.Thread] = None - self._process_id = str(uuid4()) + self._process_id = str(os.getpid()) def emit(self, event_type: str, data: Any) -> None: """Emit an event to all subscribers. @@ -64,67 +63,54 @@ def start_listening(self) -> None: logger.debug("[CROSS_PROCESS_LISTEN] Starting cross-process event listening") self._running = True - self._start_database_listener() - logger.debug("[CROSS_PROCESS_LISTEN] Started database listener thread") + loop = asyncio.get_running_loop() + loop.create_task(self._database_listener_task()) + logger.debug("[CROSS_PROCESS_LISTEN] Started async database listener task") def stop_listening(self) -> None: """Stop listening for cross-process events.""" logger.debug("[CROSS_PROCESS_LISTEN] Stopping cross-process event listening") self._running = False - if self._listener_thread and self._listener_thread.is_alive(): - logger.debug("[CROSS_PROCESS_LISTEN] Waiting for listener thread to stop") - self._listener_thread.join(timeout=1) - logger.debug("[CROSS_PROCESS_LISTEN] Listener thread stopped") - - def _start_database_listener(self) -> None: - """Start database-based event listener.""" - - def database_listener(): - logger.debug("[CROSS_PROCESS_LISTENER] Starting database listener loop") - last_cleanup = time.time() - - while self._running: - try: - # Get unprocessed events from other processes - events = self._db.get_unprocessed_events(self._process_id) - if events: - logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events") - - for event in events: - if not self._running: - break - - try: - logger.debug( - f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}" - ) - # Handle the event - self._handle_cross_process_event(event["event_type"], event["data"]) - logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}") - - # Mark as processed - self._db.mark_event_processed(event["event_id"]) - logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed") - - except Exception as e: - logger.debug(f"[CROSS_PROCESS_LISTENER] Failed to process event {event['event_id']}: {e}") - - # Clean up old events every hour - current_time = time.time() - if current_time - last_cleanup >= 3600: - logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events") - self._db.cleanup_old_events() - last_cleanup = current_time - - # Small sleep to prevent busy waiting - time.sleep(0.1) - - except Exception as e: - logger.debug(f"[CROSS_PROCESS_LISTENER] Database listener error: {e}") - time.sleep(1) - - self._listener_thread = threading.Thread(target=database_listener, daemon=True) - self._listener_thread.start() + + async def _database_listener_task(self) -> None: + """Single database listener task that processes events and recreates itself.""" + if not self._running: + # this should end the task loop + logger.debug("[CROSS_PROCESS_LISTENER] Stopping database listener task") + return + + # Get unprocessed events from other processes + events = self._db.get_unprocessed_events(str(self._process_id)) + if events: + logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events") + else: + logger.debug(f"[CROSS_PROCESS_LISTENER] No unprocessed events found for process {self._process_id}") + + for event in events: + logger.debug( + f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}" + ) + # Handle the event + self._handle_cross_process_event(event["event_type"], event["data"]) + logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}") + + # Mark as processed + self._db.mark_event_processed(event["event_id"]) + logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed") + + # Clean up old events every hour + current_time = time.time() + if not hasattr(self, "_last_cleanup"): + self._last_cleanup = current_time + elif current_time - self._last_cleanup >= 3600: + logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events") + self._db.cleanup_old_events() + self._last_cleanup = current_time + + # Schedule the next task if still running + await asyncio.sleep(1.0) + loop = asyncio.get_running_loop() + loop.create_task(self._database_listener_task()) def _handle_cross_process_event(self, event_type: str, data: Any) -> None: """Handle events received from other processes.""" diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index f7a96f84..5d1f522a 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -2,7 +2,7 @@ from typing import Any, List from uuid import uuid4 -from peewee import CharField, DateTimeField, Model, SqliteDatabase +from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase from playhouse.sqlite_ext import JSONField from eval_protocol.event_bus.logger import logger @@ -25,7 +25,7 @@ class Event(BaseModel): # type: ignore data = JSONField() timestamp = DateTimeField() process_id = CharField() - processed = CharField(default="false") # Track if event has been processed + processed = BooleanField(default=False) # Track if event has been processed self._Event = Event self._db.connect() @@ -46,7 +46,7 @@ def publish_event(self, event_type: str, data: Any, process_id: str) -> None: data=serialized_data, timestamp=time.time(), process_id=process_id, - processed="false", + processed=False, ) except Exception as e: logger.warning(f"Failed to publish event to database: {e}") @@ -56,7 +56,7 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]: try: query = ( self._Event.select() - .where((self._Event.process_id != process_id) & (self._Event.processed == "false")) + .where((self._Event.process_id != process_id) & (~self._Event.processed)) .order_by(self._Event.timestamp) ) @@ -80,7 +80,7 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]: def mark_event_processed(self, event_id: str) -> None: """Mark an event as processed.""" try: - self._Event.update(processed="true").where(self._Event.event_id == event_id).execute() + self._Event.update(processed=True).where(self._Event.event_id == event_id).execute() except Exception as e: logger.debug(f"Failed to mark event as processed: {e}") @@ -88,8 +88,6 @@ def cleanup_old_events(self, max_age_hours: int = 24) -> None: """Clean up old processed events.""" try: cutoff_time = time.time() - (max_age_hours * 3600) - self._Event.delete().where( - (self._Event.processed == "true") & (self._Event.timestamp < cutoff_time) - ).execute() + self._Event.delete().where((self._Event.processed) & (self._Event.timestamp < cutoff_time)).execute() except Exception as e: logger.debug(f"Failed to cleanup old events: {e}") diff --git a/test_event_bus_helper.py b/test_event_bus_helper.py new file mode 100644 index 00000000..4bd97231 --- /dev/null +++ b/test_event_bus_helper.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +"""Helper script for testing event bus cross-process communication.""" + +import asyncio +import sys +import json +from eval_protocol.event_bus import SqliteEventBus +from eval_protocol.models import EvaluationRow, InputMetadata + + +async def listener_process(db_path: str): + """Run an event bus listener in a separate process.""" + try: + event_bus = SqliteEventBus(db_path=db_path) + + received_events = [] + + def test_listener(event_type: str, data): + received_events.append((event_type, data)) + + event_bus.subscribe(test_listener) + event_bus.start_listening() + + # Wait for events for up to 5 seconds + start_time = asyncio.get_event_loop().time() + while asyncio.get_event_loop().time() - start_time < 5.0: + await asyncio.sleep(0.1) + if received_events: + break + + # Output results to stdout + print(json.dumps(received_events)) + event_bus.stop_listening() + + except Exception as e: + print(f"Error in listener process: {e}", file=sys.stderr) + sys.exit(1) + + +async def emitter_process(db_path: str, event_type: str, data_json: str): + """Emit an event from a separate process.""" + try: + event_bus = SqliteEventBus(db_path=db_path) + + # Parse the data + if data_json: + data = json.loads(data_json) + else: + data = None + + event_bus.emit(event_type, data) + + except Exception as e: + print(f"Error in emitter process: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print("Usage: python test_event_bus_helper.py [event_type] [data_json]", file=sys.stderr) + sys.exit(1) + + mode = sys.argv[1] + db_path = sys.argv[2] + + if mode == "listener": + asyncio.run(listener_process(db_path)) + elif mode == "emitter": + event_type = sys.argv[3] + data_json = sys.argv[4] if len(sys.argv) > 4 else "" + asyncio.run(emitter_process(db_path, event_type, data_json)) + else: + print(f"Unknown mode: {mode}", file=sys.stderr) + sys.exit(1) diff --git a/tests/test_event_bus.py b/tests/test_event_bus.py index 49d74d12..6306a02f 100644 --- a/tests/test_event_bus.py +++ b/tests/test_event_bus.py @@ -1,9 +1,9 @@ +import asyncio import tempfile -import time from eval_protocol.event_bus import SqliteEventBus from eval_protocol.event_bus.event_bus import EventBus -from eval_protocol.models import EvaluationRow, InputMetadata +from eval_protocol.models import EvaluationRow, InputMetadata, Message class TestSqliteEventBus: @@ -37,70 +37,108 @@ def test_listener(event_type: str, data): os.unlink(db_path) - def test_cross_process_events(self): + async def test_cross_process_events(self): """Test cross-process event communication.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create two event buses (simulating different processes) - event_bus1 = SqliteEventBus(db_path=db_path) - event_bus2 = SqliteEventBus(db_path=db_path) + # Start listener process + listener_process = await asyncio.create_subprocess_exec( + "python", + "tests/test_event_bus_helper.py", + "listener", + db_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) - # Set up listener on event_bus2 - received_events = [] + # Give the listener time to start + await asyncio.sleep(0.5) - def test_listener(event_type: str, data): - received_events.append((event_type, data)) + # Emit event from a separate process + test_data = {"test": "cross_process"} + import json + + data_json = json.dumps(test_data) + + emitter_process = await asyncio.create_subprocess_exec( + "python", + "tests/test_event_bus_helper.py", + "emitter", + db_path, + "cross_process_event", + data_json, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) - event_bus2.subscribe(test_listener) - event_bus2.start_listening() + # Wait for emitter to complete + await emitter_process.wait() - # Emit event from event_bus1 - test_data = {"test": "cross_process"} - event_bus1.emit("cross_process_event", test_data) + # Wait for listener to complete and get results + stdout, stderr = await listener_process.communicate() - # Wait a bit for the event to be processed - time.sleep(0.2) + # Parse results + received_events = json.loads(stdout.decode()) - # Check that event_bus2 received the event + # Check that the event was received assert len(received_events) == 1 assert received_events[0][0] == "cross_process_event" assert received_events[0][1] == test_data - event_bus2.stop_listening() - finally: import os os.unlink(db_path) - def test_evaluation_row_events(self): + async def test_evaluation_row_events(self): """Test that EvaluationRow objects can be emitted and received.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - event_bus1 = SqliteEventBus(db_path=db_path) - event_bus2 = SqliteEventBus(db_path=db_path) + # Start listener process + listener_process = await asyncio.create_subprocess_exec( + "python", + "test_event_bus_helper.py", + "listener", + db_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) - received_events = [] + # Give the listener time to start + await asyncio.sleep(0.5) - def test_listener(event_type: str, data): - received_events.append((event_type, data)) + # Create and emit an EvaluationRow from a separate process + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], input_metadata=InputMetadata(row_id="test-123") + ) - event_bus2.subscribe(test_listener) - event_bus2.start_listening() + import json - # Create and emit an EvaluationRow - test_row = EvaluationRow( - messages=[{"role": "user", "content": "test"}], input_metadata=InputMetadata(row_id="test-123") + data_json = json.dumps(test_row.model_dump(mode="json")) + + emitter_process = await asyncio.create_subprocess_exec( + "python", + "test_event_bus_helper.py", + "emitter", + db_path, + "row_upserted", + data_json, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) - event_bus1.emit("row_upserted", test_row) + # Wait for emitter to complete + await emitter_process.wait() - # Wait for processing - time.sleep(0.2) + # Wait for listener to complete and get results + stdout, stderr = await listener_process.communicate() + + # Parse results + received_events = json.loads(stdout.decode()) # Check that the event was received assert len(received_events) == 1 @@ -111,14 +149,12 @@ def test_listener(event_type: str, data): event = EvaluationRow(**received_events[0][1]) assert event.input_metadata.row_id == "test-123" - event_bus2.stop_listening() - finally: import os os.unlink(db_path) - def test_process_isolation(self): + async def test_process_isolation(self): """Test that processes receive their own events locally but not via cross-process mechanism.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name @@ -138,7 +174,7 @@ def test_listener(event_type: str, data): event_bus.emit("self_event", {"test": "data"}) # Wait for processing - time.sleep(0.2) + await asyncio.sleep(1.0) # Should receive the event from its own process via local delivery assert len(received_events) == 1 diff --git a/tests/test_event_bus_helper.py b/tests/test_event_bus_helper.py new file mode 100644 index 00000000..4bd97231 --- /dev/null +++ b/tests/test_event_bus_helper.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +"""Helper script for testing event bus cross-process communication.""" + +import asyncio +import sys +import json +from eval_protocol.event_bus import SqliteEventBus +from eval_protocol.models import EvaluationRow, InputMetadata + + +async def listener_process(db_path: str): + """Run an event bus listener in a separate process.""" + try: + event_bus = SqliteEventBus(db_path=db_path) + + received_events = [] + + def test_listener(event_type: str, data): + received_events.append((event_type, data)) + + event_bus.subscribe(test_listener) + event_bus.start_listening() + + # Wait for events for up to 5 seconds + start_time = asyncio.get_event_loop().time() + while asyncio.get_event_loop().time() - start_time < 5.0: + await asyncio.sleep(0.1) + if received_events: + break + + # Output results to stdout + print(json.dumps(received_events)) + event_bus.stop_listening() + + except Exception as e: + print(f"Error in listener process: {e}", file=sys.stderr) + sys.exit(1) + + +async def emitter_process(db_path: str, event_type: str, data_json: str): + """Emit an event from a separate process.""" + try: + event_bus = SqliteEventBus(db_path=db_path) + + # Parse the data + if data_json: + data = json.loads(data_json) + else: + data = None + + event_bus.emit(event_type, data) + + except Exception as e: + print(f"Error in emitter process: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print("Usage: python test_event_bus_helper.py [event_type] [data_json]", file=sys.stderr) + sys.exit(1) + + mode = sys.argv[1] + db_path = sys.argv[2] + + if mode == "listener": + asyncio.run(listener_process(db_path)) + elif mode == "emitter": + event_type = sys.argv[3] + data_json = sys.argv[4] if len(sys.argv) > 4 else "" + asyncio.run(emitter_process(db_path, event_type, data_json)) + else: + print(f"Unknown mode: {mode}", file=sys.stderr) + sys.exit(1) From beae947e812ea4caa4a66ccd89db1a818b24a3e0 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 11:31:30 -0700 Subject: [PATCH 10/11] fix tests --- eval_protocol/utils/logs_server.py | 6 ------ tests/test_logs_server.py | 4 ++-- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/eval_protocol/utils/logs_server.py b/eval_protocol/utils/logs_server.py index b094b81f..d8a18638 100644 --- a/eval_protocol/utils/logs_server.py +++ b/eval_protocol/utils/logs_server.py @@ -95,12 +95,6 @@ def broadcast_row_upserted(self, row: "EvaluationRow"): active_connections_count = len(self.active_connections) logger.debug(f"[WEBSOCKET_BROADCAST] Active connections: {active_connections_count}") - if active_connections_count == 0: - logger.debug( - f"[WEBSOCKET_BROADCAST] No active connections, skipping broadcast for rollout_id: {rollout_id}" - ) - return - try: # Serialize pydantic model logger.debug(f"[WEBSOCKET_BROADCAST] Serializing row for rollout_id: {rollout_id}") diff --git a/tests/test_logs_server.py b/tests/test_logs_server.py index 0dcc1cf3..f17b3cf9 100644 --- a/tests/test_logs_server.py +++ b/tests/test_logs_server.py @@ -373,7 +373,7 @@ def test_serve_logs_port_parameter(self, temp_build_dir): serve_logs(port=test_port) # Verify that LogsServer was created with the correct port - mock_logs_server_class.assert_called_once_with(port=test_port, elasticsearch_config=None) + mock_logs_server_class.assert_called_once_with(port=test_port, elasticsearch_config=None, debug=False) # Verify that the run method was called on the instance mock_server_instance.run.assert_called_once() @@ -387,7 +387,7 @@ def test_serve_logs_default_port(self, temp_build_dir): serve_logs() # Verify that LogsServer was created with None port (which will use LogsServer's default of 8000) - mock_logs_server_class.assert_called_once_with(port=None, elasticsearch_config=None) + mock_logs_server_class.assert_called_once_with(port=None, elasticsearch_config=None, debug=False) # Verify that the run method was called on the instance mock_server_instance.run.assert_called_once() From fbc3bddf5ed8036d46c6ccdabad2ffa3ef537440 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 8 Oct 2025 11:47:12 -0700 Subject: [PATCH 11/11] increase timeout --- tests/remote_server/test_remote_fireworks_propagate_status.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 27ac977b..d924832d 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -87,7 +87,7 @@ def rows() -> List[EvaluationRow]: ), rollout_processor=RemoteRolloutProcessor( remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", - timeout_seconds=30, + timeout_seconds=120, ), ) async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow: