From 912d1e7e26e077a2256cc3b51999611e8d4a60c8 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 12:57:12 -0700 Subject: [PATCH 1/9] save --- .../logging/elasticsearch_direct_http_handler.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/eval_protocol/logging/elasticsearch_direct_http_handler.py b/eval_protocol/logging/elasticsearch_direct_http_handler.py index cbe5e402..32e269ea 100644 --- a/eval_protocol/logging/elasticsearch_direct_http_handler.py +++ b/eval_protocol/logging/elasticsearch_direct_http_handler.py @@ -1,6 +1,7 @@ import json import logging import asyncio +import os import threading from concurrent.futures import ThreadPoolExecutor from typing import Optional, Tuple, Any, Dict @@ -11,7 +12,7 @@ from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig -class ElasticsearchDirectHttpHandler(logging.Handler): +class ElasticSearchDirectHttpHandler(logging.Handler): def __init__(self, elasticsearch_config: ElasticSearchConfig) -> None: super().__init__() self.base_url: str = elasticsearch_config.url.rstrip("/") @@ -31,12 +32,14 @@ def emit(self, record: logging.LogRecord) -> None: # Create proper ISO 8601 timestamp timestamp = datetime.fromtimestamp(record.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ") + rollout_id = self._get_rollout_id(record) + data: Dict[str, Any] = { "@timestamp": timestamp, "level": record.levelname, "message": record.getMessage(), "logger_name": record.name, - # Add other relevant record attributes if needed + "rollout_id": rollout_id, } # Schedule the HTTP request to run asynchronously @@ -45,6 +48,15 @@ def emit(self, record: logging.LogRecord) -> None: self.handleError(record) print(f"Error preparing log for Elasticsearch: {e}") + def _get_rollout_id(self, record: logging.LogRecord) -> str: + """Get the rollout ID from environment variables.""" + rollout_id = os.getenv("EP_ROLLOUT_ID") + if rollout_id is None: + raise ValueError( + "EP_ROLLOUT_ID environment variable is not set but needed for ElasticSearchDirectHttpHandler" + ) + return rollout_id + def _schedule_async_send(self, data: Dict[str, Any], record: logging.LogRecord) -> None: """Schedule an async task to send the log data to Elasticsearch.""" if self._executor is None: From fc21655960ccf3a121ab67f9e36fa8990d8b3824 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 13:17:34 -0700 Subject: [PATCH 2/9] Enhance ElasticsearchIndexManager and tests to include rollout_id field - Updated _has_correct_timestamp_mapping method to check for rollout_id as a keyword field in Elasticsearch mappings. - Added rollout_id fixture to tests for setting up environment variable during test execution. - Modified test functions to include rollout_id in log messages and assertions, ensuring proper indexing and retrieval from Elasticsearch. - Implemented new tests to verify the presence and correctness of rollout_id in logged messages. --- .../logging/elasticsearch_index_manager.py | 22 +- .../test_elasticsearch_direct_http_handler.py | 220 +++++++++++++++++- 2 files changed, 230 insertions(+), 12 deletions(-) diff --git a/eval_protocol/logging/elasticsearch_index_manager.py b/eval_protocol/logging/elasticsearch_index_manager.py index c9808dbc..70465cc8 100644 --- a/eval_protocol/logging/elasticsearch_index_manager.py +++ b/eval_protocol/logging/elasticsearch_index_manager.py @@ -98,22 +98,31 @@ def _index_exists_with_correct_mapping(self) -> bool: return False def _has_correct_timestamp_mapping(self, mapping_data: Dict[str, Any]) -> bool: - """Check if the mapping has @timestamp as a date field. + """Check if the mapping has @timestamp as a date field and rollout_id as a keyword field. Args: mapping_data: Elasticsearch mapping response data Returns: - bool: True if @timestamp is correctly mapped as date field + bool: True if @timestamp is correctly mapped as date field and rollout_id as keyword field """ try: - return ( + if not ( self.index_name in mapping_data and "mappings" in mapping_data[self.index_name] and "properties" in mapping_data[self.index_name]["mappings"] - and "@timestamp" in mapping_data[self.index_name]["mappings"]["properties"] - and mapping_data[self.index_name]["mappings"]["properties"]["@timestamp"].get("type") == "date" - ) + ): + return False + + properties = mapping_data[self.index_name]["mappings"]["properties"] + + # Check @timestamp is mapped as date + timestamp_ok = "@timestamp" in properties and properties["@timestamp"].get("type") == "date" + + # Check rollout_id is mapped as keyword + rollout_id_ok = "rollout_id" in properties and properties["rollout_id"].get("type") == "keyword" + + return timestamp_ok and rollout_id_ok except (KeyError, TypeError): return False @@ -130,6 +139,7 @@ def _get_logging_mapping(self) -> Dict[str, Any]: "level": {"type": "keyword"}, "message": {"type": "text"}, "logger_name": {"type": "keyword"}, + "rollout_id": {"type": "keyword"}, } } } diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index 8b7eb892..39b24dfd 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -5,11 +5,29 @@ import pytest from urllib.parse import urlparse -from eval_protocol.logging.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler +from eval_protocol.logging.elasticsearch_direct_http_handler import ElasticSearchDirectHttpHandler from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig +@pytest.fixture +def rollout_id(): + """Set up EP_ROLLOUT_ID environment variable for tests.""" + import uuid + + # Generate a unique rollout ID for this test session + test_rollout_id = f"test-rollout-{uuid.uuid4().hex[:8]}" + + # Set the environment variable + os.environ["EP_ROLLOUT_ID"] = test_rollout_id + + yield test_rollout_id + + # Clean up after the test + if "EP_ROLLOUT_ID" in os.environ: + del os.environ["EP_ROLLOUT_ID"] + + @pytest.fixture def elasticsearch_config(): """Set up Elasticsearch and return configuration.""" @@ -22,11 +40,11 @@ def elasticsearch_config(): @pytest.fixture -def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig): +def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig, rollout_id: str): """Create and configure ElasticsearchDirectHttpHandler.""" # Use a unique test-specific index name with timestamp - handler = ElasticsearchDirectHttpHandler(elasticsearch_config) + handler = ElasticSearchDirectHttpHandler(elasticsearch_config) # Set a specific log level handler.setLevel(logging.INFO) @@ -35,7 +53,7 @@ def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig): @pytest.fixture -def test_logger(elasticsearch_handler, elasticsearch_config): +def test_logger(elasticsearch_handler, elasticsearch_config, rollout_id: str): """Set up a test logger with the Elasticsearch handler.""" # Create the index for this specific handler setup = ElasticsearchSetup() @@ -58,7 +76,7 @@ def test_logger(elasticsearch_handler, elasticsearch_config): @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_sends_logs( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger + elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str ): """Test that ElasticsearchDirectHttpHandler successfully sends logs to Elasticsearch.""" @@ -133,7 +151,7 @@ def test_elasticsearch_direct_http_handler_sends_logs( @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_sorts_logs_chronologically( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger + elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str ): """Test that logs can be sorted chronologically by timestamp.""" @@ -194,3 +212,193 @@ def test_elasticsearch_direct_http_handler_sorts_logs_chronologically( print(f"Successfully verified chronological sorting of {len(hits)} log messages") print(f"Timestamps in order: {found_timestamps}") + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +def test_elasticsearch_direct_http_handler_includes_rollout_id( + elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str +): + """Test that ElasticsearchDirectHttpHandler includes rollout_id field in indexed logs.""" + + # Generate a unique test message to avoid conflicts + test_message = f"Rollout ID test message at {time.time()}" + + # Send the log message + test_logger.info(test_message) + + # Give Elasticsearch a moment to process the document + time.sleep(3) + + # Query Elasticsearch to verify the document was received with rollout_id + parsed_url = urlparse(elasticsearch_config.url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" + + # Prepare the search query + search_query = { + "query": {"match": {"message": test_message}}, + "sort": [{"@timestamp": {"order": "desc"}}], + "size": 1, + } + + # Execute the search + response = requests.post( + search_url, + headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, + json=search_query, + verify=parsed_url.scheme == "https", + ) + + # Check for errors + if response.status_code != 200: + print(f"Elasticsearch search failed with status {response.status_code}") + print(f"Response: {response.text}") + response.raise_for_status() + + search_results = response.json() + + # Assert that we found our log message + assert "hits" in search_results, "Search response should contain 'hits'" + assert "total" in search_results["hits"], "Search hits should contain 'total'" + + total_hits = search_results["hits"]["total"] + if isinstance(total_hits, dict): + # Elasticsearch 7+ format + total_count = total_hits["value"] + else: + # Elasticsearch 6 format + total_count = total_hits + + assert total_count > 0, f"Expected to find at least 1 log message, but found {total_count}" + + # Verify the content of the found document + hits = search_results["hits"]["hits"] + assert len(hits) > 0, "Expected at least one hit" + + found_document = hits[0]["_source"] + + # Verify the rollout_id field is present and correct + assert "rollout_id" in found_document, "Expected document to contain 'rollout_id' field" + assert found_document["rollout_id"] == rollout_id, ( + f"Expected rollout_id '{rollout_id}', got '{found_document['rollout_id']}'" + ) + + # Verify other expected fields are still present + assert found_document["message"] == test_message, ( + f"Expected message '{test_message}', got '{found_document['message']}'" + ) + assert found_document["level"] == "INFO", f"Expected level 'INFO', got '{found_document['level']}'" + assert found_document["logger_name"] == "test_elasticsearch_logger", ( + f"Expected logger name 'test_elasticsearch_logger', got '{found_document['logger_name']}'" + ) + assert "@timestamp" in found_document, "Expected document to contain '@timestamp' field" + + print(f"Successfully verified log message with rollout_id '{rollout_id}' in Elasticsearch: {test_message}") + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +def test_elasticsearch_direct_http_handler_search_by_rollout_id( + elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str +): + """Test that logs can be searched by rollout_id field in Elasticsearch.""" + + # Generate unique test messages to avoid conflicts + test_messages = [] + for i in range(3): + message = f"Rollout search test message {i} at {time.time()}" + test_messages.append(message) + test_logger.info(message) + time.sleep(0.1) # Small delay to ensure different timestamps + + # Give Elasticsearch time to process all documents + time.sleep(3) + + # Query Elasticsearch to search by rollout_id + parsed_url = urlparse(elasticsearch_config.url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" + + # Search for logs with our specific rollout_id using term query + search_query = { + "query": {"term": {"rollout_id": rollout_id}}, + "sort": [{"@timestamp": {"order": "desc"}}], + "size": 10, + } + + # Execute the search + response = requests.post( + search_url, + headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, + json=search_query, + verify=parsed_url.scheme == "https", + ) + + # Check for errors + if response.status_code != 200: + print(f"Elasticsearch search failed with status {response.status_code}") + print(f"Response: {response.text}") + response.raise_for_status() + + search_results = response.json() + + # Assert that we found our log messages + assert "hits" in search_results, "Search response should contain 'hits'" + assert "total" in search_results["hits"], "Search hits should contain 'total'" + + total_hits = search_results["hits"]["total"] + if isinstance(total_hits, dict): + # Elasticsearch 7+ format + total_count = total_hits["value"] + else: + # Elasticsearch 6 format + total_count = total_hits + + assert total_count >= 3, f"Expected to find at least 3 log messages, but found {total_count}" + + # Verify the content of the found documents + hits = search_results["hits"]["hits"] + assert len(hits) >= 3, f"Expected at least 3 hits, found {len(hits)}" + + # Verify all found documents have the correct rollout_id + found_messages = [] + for hit in hits: + document = hit["_source"] + assert document["rollout_id"] == rollout_id, ( + f"Expected rollout_id '{rollout_id}', got '{document['rollout_id']}'" + ) + found_messages.append(document["message"]) + + # Verify all our test messages are present in the search results + for test_message in test_messages: + assert test_message in found_messages, f"Expected message '{test_message}' not found in search results" + + # Test searching for a different rollout_id (should return no results) + different_rollout_id = f"different-rollout-{time.time()}" + search_query_different = { + "query": {"term": {"rollout_id": different_rollout_id}}, + "size": 10, + } + + response_different = requests.post( + search_url, + headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, + json=search_query_different, + verify=parsed_url.scheme == "https", + ) + + if response_different.status_code != 200: + print(f"Elasticsearch search failed with status {response_different.status_code}") + print(f"Response: {response_different.text}") + response_different.raise_for_status() + + different_results = response_different.json() + different_total_hits = different_results["hits"]["total"] + if isinstance(different_total_hits, dict): + different_count = different_total_hits["value"] + else: + different_count = different_total_hits + + assert different_count == 0, f"Expected 0 results for different rollout_id, but found {different_count}" + + print(f"Successfully verified search by rollout_id '{rollout_id}' found {len(hits)} log messages") + print("Verified that search for different rollout_id returns 0 results") From dcaf023cca7cedb177bd50071e1bc6085db85b0d Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 14:40:55 -0700 Subject: [PATCH 3/9] using tracing.fireworks.ai works --- tests/remote_server/remote_server.py | 13 +++++-------- tests/remote_server/test_remote_langfuse.py | 3 ++- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index ba2df50a..cbe4d04a 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -1,17 +1,15 @@ import os import threading -from typing import Any, Dict, List +from typing import Any, Dict import uvicorn from fastapi import FastAPI, HTTPException -from langfuse.openai import openai # pyright: ignore[reportPrivateImportUsage] +from openai import OpenAI from eval_protocol.types.remote_rollout_processor import ( InitRequest, StatusResponse, - create_langfuse_config_tags, ) -from eval_protocol.models import Message app = FastAPI() @@ -28,21 +26,20 @@ def init(req: InitRequest): # Kick off worker thread that does a single-turn chat via Langfuse OpenAI integration def _worker(): try: - metadata = {"langfuse_tags": create_langfuse_config_tags(req)} - if not req.messages: raise ValueError("messages is required") completion_kwargs = { "model": req.model, "messages": req.messages, - "metadata": metadata, } if req.tools: completion_kwargs["tools"] = req.tools - completion = openai.chat.completions.create(**completion_kwargs) + client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY")) + + completion = client.chat.completions.create(**completion_kwargs) except Exception as e: # Best-effort; mark as done even on error to unblock polling diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 78cde359..bf92d19a 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -56,7 +56,7 @@ def rows() -> List[EvaluationRow]: @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") -@pytest.mark.parametrize("completion_params", [{"model": "gpt-4o"}]) +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) @evaluation_test( data_loaders=DynamicDataLoader( generators=[rows], @@ -65,6 +65,7 @@ def rows() -> List[EvaluationRow]: remote_base_url="http://127.0.0.1:3000", timeout_seconds=30, output_data_loader=langfuse_output_data_loader, + model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk", ), ) async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: From f4d80ecda2393f3f88424509ecf1738ca8440c10 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 14:51:48 -0700 Subject: [PATCH 4/9] Added Status model import to eval_protocol and included it in the __all__ exports for better accessibility. --- eval_protocol/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index b713441a..ff77bd71 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -33,7 +33,7 @@ _FIREWORKS_AVAILABLE = False # Import submodules to make them available via eval_protocol.rewards, etc. from . import mcp, rewards -from .models import EvaluateResult, Message, MetricResult, EvaluationRow, InputMetadata +from .models import EvaluateResult, Message, MetricResult, EvaluationRow, InputMetadata, Status from .playback_policy import PlaybackPolicyBase from .resources import create_llm_resource from .reward_function import RewardFunction @@ -73,6 +73,7 @@ warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol") __all__ = [ + "Status", "RemoteRolloutProcessor", "InputMetadata", "EvaluationRow", From 4c2773442fbe53db118b39e8ca29f06a5bec89b9 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 14:52:13 -0700 Subject: [PATCH 5/9] Enhance Elasticsearch logging to include status information - Added a method to extract status information from log records in ElasticSearchDirectHttpHandler. - Updated the data structure sent to Elasticsearch to include status_code, status_message, and status_details if present. - Modified ElasticsearchIndexManager to validate the mapping of new status fields. - Implemented tests to verify logging of status information and searching by status code in Elasticsearch. --- .../elasticsearch_direct_http_handler.py | 33 ++++ .../logging/elasticsearch_index_manager.py | 14 +- .../test_elasticsearch_direct_http_handler.py | 173 ++++++++++++++++++ 3 files changed, 217 insertions(+), 3 deletions(-) diff --git a/eval_protocol/logging/elasticsearch_direct_http_handler.py b/eval_protocol/logging/elasticsearch_direct_http_handler.py index 32e269ea..dc77cc98 100644 --- a/eval_protocol/logging/elasticsearch_direct_http_handler.py +++ b/eval_protocol/logging/elasticsearch_direct_http_handler.py @@ -33,6 +33,7 @@ def emit(self, record: logging.LogRecord) -> None: timestamp = datetime.fromtimestamp(record.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ") rollout_id = self._get_rollout_id(record) + status_info = self._get_status_info(record) data: Dict[str, Any] = { "@timestamp": timestamp, @@ -42,6 +43,10 @@ def emit(self, record: logging.LogRecord) -> None: "rollout_id": rollout_id, } + # Add status information if present + if status_info: + data.update(status_info) + # Schedule the HTTP request to run asynchronously self._schedule_async_send(data, record) except Exception as e: @@ -57,6 +62,34 @@ def _get_rollout_id(self, record: logging.LogRecord) -> str: ) return rollout_id + def _get_status_info(self, record: logging.LogRecord) -> Optional[Dict[str, Any]]: + """Extract status information from the log record's extra data.""" + # Check if 'status' is in the extra data (passed via extra parameter) + if hasattr(record, "status") and record.status is not None: # type: ignore + status = record.status # type: ignore + + # Handle Status class instances (Pydantic BaseModel) + if hasattr(status, "code") and hasattr(status, "message"): + # Status object - extract code and message + status_code = status.code + # Handle both enum values and direct integer values + if hasattr(status_code, "value"): + status_code = status_code.value + + return { + "status_code": status_code, + "status_message": status.message, + "status_details": getattr(status, "details", []), + } + elif isinstance(status, dict): + # Dictionary representation of status + return { + "status_code": status.get("code"), + "status_message": status.get("message"), + "status_details": status.get("details", []), + } + return None + def _schedule_async_send(self, data: Dict[str, Any], record: logging.LogRecord) -> None: """Schedule an async task to send the log data to Elasticsearch.""" if self._executor is None: diff --git a/eval_protocol/logging/elasticsearch_index_manager.py b/eval_protocol/logging/elasticsearch_index_manager.py index 70465cc8..4ee7fccb 100644 --- a/eval_protocol/logging/elasticsearch_index_manager.py +++ b/eval_protocol/logging/elasticsearch_index_manager.py @@ -98,13 +98,13 @@ def _index_exists_with_correct_mapping(self) -> bool: return False def _has_correct_timestamp_mapping(self, mapping_data: Dict[str, Any]) -> bool: - """Check if the mapping has @timestamp as a date field and rollout_id as a keyword field. + """Check if the mapping has @timestamp as a date field, rollout_id as a keyword field, and status fields. Args: mapping_data: Elasticsearch mapping response data Returns: - bool: True if @timestamp is correctly mapped as date field and rollout_id as keyword field + bool: True if all required fields are correctly mapped """ try: if not ( @@ -122,7 +122,12 @@ def _has_correct_timestamp_mapping(self, mapping_data: Dict[str, Any]) -> bool: # Check rollout_id is mapped as keyword rollout_id_ok = "rollout_id" in properties and properties["rollout_id"].get("type") == "keyword" - return timestamp_ok and rollout_id_ok + # Check status fields are mapped correctly + status_code_ok = "status_code" in properties and properties["status_code"].get("type") == "integer" + status_message_ok = "status_message" in properties and properties["status_message"].get("type") == "text" + status_details_ok = "status_details" in properties and properties["status_details"].get("type") == "object" + + return timestamp_ok and rollout_id_ok and status_code_ok and status_message_ok and status_details_ok except (KeyError, TypeError): return False @@ -140,6 +145,9 @@ def _get_logging_mapping(self) -> Dict[str, Any]: "message": {"type": "text"}, "logger_name": {"type": "keyword"}, "rollout_id": {"type": "keyword"}, + "status_code": {"type": "integer"}, + "status_message": {"type": "text"}, + "status_details": {"type": "object"}, } } } diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index 39b24dfd..728d8d2a 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -402,3 +402,176 @@ def test_elasticsearch_direct_http_handler_search_by_rollout_id( print(f"Successfully verified search by rollout_id '{rollout_id}' found {len(hits)} log messages") print("Verified that search for different rollout_id returns 0 results") + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +def test_elasticsearch_direct_http_handler_logs_status_info( + elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str +): + """Test that ElasticsearchDirectHttpHandler logs Status class instances and can search by status code.""" + from eval_protocol import Status + + # Create a Status instance + test_status = Status.rollout_running() + + # Generate a unique test message + test_message = f"Status logging test message at {time.time()}" + + # Log with Status instance in extra data + test_logger.info(test_message, extra={"status": test_status}) + + # Give Elasticsearch time to process the document + time.sleep(3) + + # Query Elasticsearch to verify the document was received with status info + parsed_url = urlparse(elasticsearch_config.url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" + + # Search for logs with our specific status code + search_query = { + "query": {"term": {"status_code": test_status.code.value}}, + "sort": [{"@timestamp": {"order": "desc"}}], + "size": 1, + } + + # Execute the search + response = requests.post( + search_url, + headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, + json=search_query, + verify=parsed_url.scheme == "https", + ) + + # Check for errors + if response.status_code != 200: + print(f"Elasticsearch search failed with status {response.status_code}") + print(f"Response: {response.text}") + response.raise_for_status() + + search_results = response.json() + + # Assert that we found our log message + assert "hits" in search_results, "Search response should contain 'hits'" + assert "total" in search_results["hits"], "Search hits should contain 'total'" + + total_hits = search_results["hits"]["total"] + if isinstance(total_hits, dict): + total_count = total_hits["value"] + else: + total_count = total_hits + + assert total_count > 0, f"Expected to find at least 1 log message, but found {total_count}" + + # Verify the content of the found document + hits = search_results["hits"]["hits"] + assert len(hits) > 0, "Expected at least one hit" + + found_document = hits[0]["_source"] + + # Verify the status fields are present and correct + assert "status_code" in found_document, "Expected document to contain 'status_code' field" + assert found_document["status_code"] == test_status.code.value, ( + f"Expected status_code {test_status.code.value}, got {found_document['status_code']}" + ) + assert "status_message" in found_document, "Expected document to contain 'status_message' field" + assert found_document["status_message"] == test_status.message, ( + f"Expected status_message '{test_status.message}', got '{found_document['status_message']}'" + ) + assert "status_details" in found_document, "Expected document to contain 'status_details' field" + assert found_document["status_details"] == test_status.details, ( + f"Expected status_details {test_status.details}, got {found_document['status_details']}" + ) + + # Verify other expected fields are still present + assert found_document["message"] == test_message, ( + f"Expected message '{test_message}', got '{found_document['message']}'" + ) + assert found_document["rollout_id"] == rollout_id, ( + f"Expected rollout_id '{rollout_id}', got '{found_document['rollout_id']}'" + ) + + print(f"Successfully verified Status logging with code {test_status.code.value} in Elasticsearch: {test_message}") + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +def test_elasticsearch_direct_http_handler_search_by_status_code( + elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str +): + """Test that logs can be searched by status code in Elasticsearch.""" + from eval_protocol.models import Status + + # Create different Status instances for testing + statuses = [ + Status.rollout_running(), + Status.eval_finished(), + Status.error("Test error message"), + ] + + # Generate unique test messages + test_messages = [] + for i, status in enumerate(statuses): + message = f"Status search test message {i} at {time.time()}" + test_messages.append((message, status)) + test_logger.info(message, extra={"status": status}) + time.sleep(0.1) # Small delay to ensure different timestamps + + # Give Elasticsearch time to process all documents + time.sleep(3) + + # Query Elasticsearch to search by specific status code + parsed_url = urlparse(elasticsearch_config.url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" + + # Search for logs with RUNNING status code + running_status = Status.Code.RUNNING + search_query = { + "query": {"term": {"status_code": running_status.value}}, + "sort": [{"@timestamp": {"order": "desc"}}], + "size": 10, + } + + # Execute the search + response = requests.post( + search_url, + headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, + json=search_query, + verify=parsed_url.scheme == "https", + ) + + # Check for errors + if response.status_code != 200: + print(f"Elasticsearch search failed with status {response.status_code}") + print(f"Response: {response.text}") + response.raise_for_status() + + search_results = response.json() + + # Assert that we found our log messages + assert "hits" in search_results, "Search response should contain 'hits'" + assert "total" in search_results["hits"], "Search hits should contain 'total'" + + total_hits = search_results["hits"]["total"] + if isinstance(total_hits, dict): + total_count = total_hits["value"] + else: + total_count = total_hits + + assert total_count >= 1, f"Expected to find at least 1 log message with RUNNING status, but found {total_count}" + + # Verify the content of the found documents + hits = search_results["hits"]["hits"] + assert len(hits) >= 1, f"Expected at least 1 hit, found {len(hits)}" + + # Verify all found documents have the correct status code + for hit in hits: + document = hit["_source"] + assert document["status_code"] == running_status.value, ( + f"Expected status_code {running_status.value}, got {document['status_code']}" + ) + assert document["rollout_id"] == rollout_id, ( + f"Expected rollout_id '{rollout_id}', got '{document['rollout_id']}'" + ) + + print(f"Successfully verified search by status code {running_status.value} found {len(hits)} log messages") From 034fde71c34914850cdd4551d5f3c699f3e8398f Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 16:07:05 -0700 Subject: [PATCH 6/9] Refactor Elasticsearch integration to use dedicated client - Replaced direct HTTP requests with an ElasticsearchClient for handling interactions with Elasticsearch in ElasticSearchDirectHttpHandler and ElasticsearchIndexManager. - Updated configuration handling to encapsulate Elasticsearch settings within a single config object. - Simplified index management and document indexing processes, improving code maintainability and readability. - Adjusted tests to utilize the new client for searching and verifying log entries, enhancing test reliability. --- eval_protocol/logging/elasticsearch_client.py | 301 ++++++++++++++++++ .../elasticsearch_direct_http_handler.py | 27 +- .../logging/elasticsearch_index_manager.py | 84 ++--- .../test_elasticsearch_direct_http_handler.py | 221 +++---------- 4 files changed, 375 insertions(+), 258 deletions(-) create mode 100644 eval_protocol/logging/elasticsearch_client.py diff --git a/eval_protocol/logging/elasticsearch_client.py b/eval_protocol/logging/elasticsearch_client.py new file mode 100644 index 00000000..917a06b3 --- /dev/null +++ b/eval_protocol/logging/elasticsearch_client.py @@ -0,0 +1,301 @@ +""" +Centralized Elasticsearch client for all Elasticsearch API operations. + +This module provides a unified interface for all Elasticsearch operations +used throughout the codebase, including index management, document operations, +and search functionality. +""" + +import json +import requests +from typing import Any, Dict, List, Optional, Union +from urllib.parse import urlparse +from dataclasses import dataclass + + +@dataclass +class ElasticsearchConfig: + """Configuration for Elasticsearch client.""" + + url: str + api_key: str + index_name: str + verify_ssl: bool = True + + def __post_init__(self): + """Parse URL to determine SSL verification setting.""" + parsed_url = urlparse(self.url) + self.verify_ssl = parsed_url.scheme == "https" + + +class ElasticsearchClient: + """Centralized client for all Elasticsearch operations.""" + + def __init__(self, config: ElasticsearchConfig): + """Initialize the Elasticsearch client. + + Args: + config: Elasticsearch configuration + """ + self.config = config + self.base_url = config.url.rstrip("/") + self.index_url = f"{self.base_url}/{config.index_name}" + self._headers = {"Content-Type": "application/json", "Authorization": f"ApiKey {config.api_key}"} + + def _make_request( + self, + method: str, + url: str, + json_data: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + params: Optional[Dict[str, Any]] = None, + timeout: int = 30, + ) -> requests.Response: + """Make an HTTP request to Elasticsearch. + + Args: + method: HTTP method (GET, POST, PUT, DELETE, HEAD) + url: Full URL for the request + json_data: JSON data to send in request body + params: Query parameters + timeout: Request timeout in seconds + + Returns: + requests.Response object + + Raises: + requests.RequestException: If the request fails + """ + return requests.request( + method=method, + url=url, + headers=self._headers, + json=json_data, + params=params, + verify=self.config.verify_ssl, + timeout=timeout, + ) + + # Index Management Operations + + def create_index(self, mapping: Dict[str, Any]) -> bool: + """Create an index with the specified mapping. + + Args: + mapping: Index mapping configuration + + Returns: + bool: True if successful, False otherwise + """ + try: + response = self._make_request("PUT", self.index_url, json_data=mapping) + return response.status_code in [200, 201] + except Exception: + return False + + def index_exists(self) -> bool: + """Check if the index exists. + + Returns: + bool: True if index exists, False otherwise + """ + try: + response = self._make_request("HEAD", self.index_url) + return response.status_code == 200 + except Exception: + return False + + def delete_index(self) -> bool: + """Delete the index. + + Returns: + bool: True if successful, False otherwise + """ + try: + response = self._make_request("DELETE", self.index_url) + return response.status_code in [200, 404] # 404 means index doesn't exist + except Exception: + return False + + def get_mapping(self) -> Optional[Dict[str, Any]]: + """Get the index mapping. + + Returns: + Dict containing mapping data, or None if failed + """ + try: + response = self._make_request("GET", f"{self.index_url}/_mapping") + if response.status_code == 200: + return response.json() + return None + except Exception: + return None + + def get_index_stats(self) -> Optional[Dict[str, Any]]: + """Get index statistics. + + Returns: + Dict containing index statistics, or None if failed + """ + try: + response = self._make_request("GET", f"{self.index_url}/_stats") + if response.status_code == 200: + return response.json() + return None + except Exception: + return None + + # Document Operations + + def index_document(self, document: Dict[str, Any], doc_id: Optional[str] = None) -> bool: + """Index a document. + + Args: + document: Document to index + doc_id: Optional document ID + + Returns: + bool: True if successful, False otherwise + """ + try: + if doc_id: + url = f"{self.index_url}/_doc/{doc_id}" + else: + url = f"{self.index_url}/_doc" + + response = self._make_request("POST", url, json_data=document) + return response.status_code in [200, 201] + except Exception: + return False + + def bulk_index_documents(self, documents: List[Dict[str, Any]]) -> bool: + """Bulk index multiple documents. + + Args: + documents: List of documents to index + + Returns: + bool: True if successful, False otherwise + """ + try: + # Prepare bulk request body + bulk_body = [] + for doc in documents: + bulk_body.append({"index": {}}) + bulk_body.append(doc) + + response = self._make_request("POST", f"{self.index_url}/_bulk", json_data=bulk_body) + return response.status_code == 200 + except Exception: + return False + + # Search Operations + + def search( + self, query: Dict[str, Any], size: int = 10, from_: int = 0, sort: Optional[List[Dict[str, Any]]] = None + ) -> Optional[Dict[str, Any]]: + """Search documents in the index. + + Args: + query: Elasticsearch query + size: Number of results to return + from_: Starting offset + sort: Sort specification + + Returns: + Dict containing search results, or None if failed + """ + try: + search_body = {"query": query, "size": size, "from": from_} + + if sort: + search_body["sort"] = sort + + response = self._make_request("POST", f"{self.index_url}/_search", json_data=search_body) + + if response.status_code == 200: + return response.json() + return None + except Exception: + return None + + def search_by_term(self, field: str, value: Any, size: int = 10) -> Optional[Dict[str, Any]]: + """Search documents by exact term match. + + Args: + field: Field name to search + value: Value to match + size: Number of results to return + + Returns: + Dict containing search results, or None if failed + """ + query = {"term": {field: value}} + return self.search(query, size=size) + + def search_by_match(self, field: str, value: str, size: int = 10) -> Optional[Dict[str, Any]]: + """Search documents by text match. + + Args: + field: Field name to search + value: Text to match + size: Number of results to return + + Returns: + Dict containing search results, or None if failed + """ + query = {"match": {field: value}} + return self.search(query, size=size) + + def search_by_match_phrase_prefix(self, field: str, value: str, size: int = 10) -> Optional[Dict[str, Any]]: + """Search documents by phrase prefix match. + + Args: + field: Field name to search + value: Phrase prefix to match + size: Number of results to return + + Returns: + Dict containing search results, or None if failed + """ + query = {"match_phrase_prefix": {field: value}} + return self.search(query, size=size) + + def search_all(self, size: int = 10) -> Optional[Dict[str, Any]]: + """Search all documents in the index. + + Args: + size: Number of results to return + + Returns: + Dict containing search results, or None if failed + """ + query = {"match_all": {}} + return self.search(query, size=size) + + # Health and Status Operations + + def health_check(self) -> bool: + """Check if Elasticsearch is healthy. + + Returns: + bool: True if healthy, False otherwise + """ + try: + response = self._make_request("GET", f"{self.base_url}/_cluster/health") + return response.status_code == 200 + except Exception: + return False + + def get_cluster_info(self) -> Optional[Dict[str, Any]]: + """Get cluster information. + + Returns: + Dict containing cluster info, or None if failed + """ + try: + response = self._make_request("GET", f"{self.base_url}/_cluster/health") + if response.status_code == 200: + return response.json() + return None + except Exception: + return None diff --git a/eval_protocol/logging/elasticsearch_direct_http_handler.py b/eval_protocol/logging/elasticsearch_direct_http_handler.py index dc77cc98..66ad8751 100644 --- a/eval_protocol/logging/elasticsearch_direct_http_handler.py +++ b/eval_protocol/logging/elasticsearch_direct_http_handler.py @@ -6,26 +6,23 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional, Tuple, Any, Dict from datetime import datetime -from urllib.parse import urlparse -import requests from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig +from .elasticsearch_client import ElasticsearchClient, ElasticsearchConfig as ESConfig class ElasticSearchDirectHttpHandler(logging.Handler): def __init__(self, elasticsearch_config: ElasticSearchConfig) -> None: super().__init__() - self.base_url: str = elasticsearch_config.url.rstrip("/") - self.index_name: str = elasticsearch_config.index_name - self.api_key: str = elasticsearch_config.api_key - self.url: str = f"{self.base_url}/{self.index_name}/_doc" + self.config = ESConfig( + url=elasticsearch_config.url, + api_key=elasticsearch_config.api_key, + index_name=elasticsearch_config.index_name, + ) + self.client = ElasticsearchClient(self.config) self.formatter: logging.Formatter = logging.Formatter() self._executor = None - # Parse URL to determine if we should verify SSL - parsed_url = urlparse(elasticsearch_config.url) - self.verify_ssl = parsed_url.scheme == "https" - def emit(self, record: logging.LogRecord) -> None: """Emit a log record by scheduling it for async transmission.""" try: @@ -104,13 +101,9 @@ def _schedule_async_send(self, data: Dict[str, Any], record: logging.LogRecord) def _send_to_elasticsearch(self, data: Dict[str, Any], record: logging.LogRecord) -> None: """Send data to Elasticsearch (runs in thread pool).""" try: - response: requests.Response = requests.post( - self.url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {self.api_key}"}, - data=json.dumps(data), - verify=self.verify_ssl, # If using HTTPS, verify SSL certificate - ) - response.raise_for_status() # Raise an exception for HTTP errors + success = self.client.index_document(data) + if not success: + raise Exception("Failed to index document to Elasticsearch") except Exception as e: # Re-raise to be handled by the callback raise e diff --git a/eval_protocol/logging/elasticsearch_index_manager.py b/eval_protocol/logging/elasticsearch_index_manager.py index 4ee7fccb..08e18813 100644 --- a/eval_protocol/logging/elasticsearch_index_manager.py +++ b/eval_protocol/logging/elasticsearch_index_manager.py @@ -1,6 +1,5 @@ -import requests from typing import Dict, Any, Optional -from urllib.parse import urlparse +from .elasticsearch_client import ElasticsearchClient, ElasticsearchConfig class ElasticsearchIndexManager: @@ -14,16 +13,10 @@ def __init__(self, base_url: str, index_name: str, api_key: str) -> None: index_name: Name of the index to manage api_key: API key for authentication """ - self.base_url: str = base_url.rstrip("/") - self.index_name: str = index_name - self.api_key: str = api_key - self.index_url: str = f"{self.base_url}/{self.index_name}" + self.config = ElasticsearchConfig(url=base_url, api_key=api_key, index_name=index_name) + self.client = ElasticsearchClient(self.config) self._mapping_created: bool = False - # Parse URL to determine if we should verify SSL - parsed_url = urlparse(base_url) - self.verify_ssl = parsed_url.scheme == "https" - def create_logging_index_mapping(self) -> bool: """Create index with proper mapping for logging data. @@ -41,25 +34,22 @@ def create_logging_index_mapping(self) -> bool: # If index exists but has wrong mapping, delete and recreate it if self.index_exists(): - print(f"Warning: Index {self.index_name} exists with incorrect mapping. Deleting and recreating...") + print( + f"Warning: Index {self.config.index_name} exists with incorrect mapping. Deleting and recreating..." + ) if not self.delete_index(): - print(f"Warning: Failed to delete existing index {self.index_name}") + print(f"Warning: Failed to delete existing index {self.config.index_name}") return False # Create index with proper mapping mapping = self._get_logging_mapping() - response = requests.put( - self.index_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {self.api_key}"}, - json=mapping, - verify=self.verify_ssl, - ) - - if response.status_code in [200, 201]: + success = self.client.create_index(mapping) + + if success: self._mapping_created = True return True else: - print(f"Warning: Failed to create index mapping: {response.status_code} - {response.text}") + print("Warning: Failed to create index mapping") return False except Exception as e: @@ -74,24 +64,14 @@ def _index_exists_with_correct_mapping(self) -> bool: """ try: # Check if index exists - response = requests.head( - self.index_url, headers={"Authorization": f"ApiKey {self.api_key}"}, verify=self.verify_ssl - ) - - if response.status_code != 200: + if not self.client.index_exists(): return False # Check if mapping is correct - mapping_response = requests.get( - f"{self.index_url}/_mapping", - headers={"Authorization": f"ApiKey {self.api_key}"}, - verify=self.verify_ssl, - ) - - if mapping_response.status_code != 200: + mapping_data = self.client.get_mapping() + if mapping_data is None: return False - mapping_data = mapping_response.json() return self._has_correct_timestamp_mapping(mapping_data) except Exception: @@ -108,13 +88,13 @@ def _has_correct_timestamp_mapping(self, mapping_data: Dict[str, Any]) -> bool: """ try: if not ( - self.index_name in mapping_data - and "mappings" in mapping_data[self.index_name] - and "properties" in mapping_data[self.index_name]["mappings"] + self.config.index_name in mapping_data + and "mappings" in mapping_data[self.config.index_name] + and "properties" in mapping_data[self.config.index_name]["mappings"] ): return False - properties = mapping_data[self.index_name]["mappings"]["properties"] + properties = mapping_data[self.config.index_name]["mappings"]["properties"] # Check @timestamp is mapped as date timestamp_ok = "@timestamp" in properties and properties["@timestamp"].get("type") == "date" @@ -159,14 +139,12 @@ def delete_index(self) -> bool: bool: True if index was deleted successfully, False otherwise. """ try: - response = requests.delete( - self.index_url, headers={"Authorization": f"ApiKey {self.api_key}"}, verify=self.verify_ssl - ) - if response.status_code in [200, 404]: # 404 means index doesn't exist, which is fine + success = self.client.delete_index() + if success: self._mapping_created = False return True else: - print(f"Warning: Failed to delete index: {response.status_code} - {response.text}") + print("Warning: Failed to delete index") return False except Exception as e: print(f"Warning: Failed to delete index: {e}") @@ -178,13 +156,7 @@ def index_exists(self) -> bool: Returns: bool: True if index exists, False otherwise. """ - try: - response = requests.head( - self.index_url, headers={"Authorization": f"ApiKey {self.api_key}"}, verify=self.verify_ssl - ) - return response.status_code == 200 - except Exception: - return False + return self.client.index_exists() def get_index_stats(self) -> Optional[Dict[str, Any]]: """Get statistics about the index. @@ -192,14 +164,4 @@ def get_index_stats(self) -> Optional[Dict[str, Any]]: Returns: Dict containing index statistics, or None if failed """ - try: - response = requests.get( - f"{self.index_url}/_stats", - headers={"Authorization": f"ApiKey {self.api_key}"}, - verify=self.verify_ssl, - ) - if response.status_code == 200: - return response.json() - return None - except Exception: - return None + return self.client.get_index_stats() diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index 728d8d2a..be70d3ea 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -1,11 +1,10 @@ import os import logging import time -import requests import pytest -from urllib.parse import urlparse from eval_protocol.logging.elasticsearch_direct_http_handler import ElasticSearchDirectHttpHandler +from eval_protocol.logging.elasticsearch_client import ElasticsearchClient, ElasticsearchConfig as ESConfig from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig @@ -52,12 +51,21 @@ def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig, rollout_id: return handler +@pytest.fixture +def elasticsearch_client(elasticsearch_config: ElasticSearchConfig): + """Create an Elasticsearch client for testing.""" + config = ESConfig( + url=elasticsearch_config.url, api_key=elasticsearch_config.api_key, index_name=elasticsearch_config.index_name + ) + return ElasticsearchClient(config) + + @pytest.fixture def test_logger(elasticsearch_handler, elasticsearch_config, rollout_id: str): """Set up a test logger with the Elasticsearch handler.""" # Create the index for this specific handler setup = ElasticsearchSetup() - setup.create_logging_index(elasticsearch_handler.index_name) + setup.create_logging_index(elasticsearch_handler.config.index_name) logger = logging.getLogger("test_elasticsearch_logger") logger.setLevel(logging.INFO) @@ -76,7 +84,7 @@ def test_logger(elasticsearch_handler, elasticsearch_config, rollout_id: str): @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_sends_logs( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): """Test that ElasticsearchDirectHttpHandler successfully sends logs to Elasticsearch.""" @@ -89,36 +97,11 @@ def test_elasticsearch_direct_http_handler_sends_logs( # Give Elasticsearch a moment to process the document time.sleep(3) - # Query Elasticsearch to verify the document was received - # Parse the URL to construct the search endpoint - parsed_url = urlparse(elasticsearch_config.url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" - - # Prepare the search query with sorting by @timestamp - search_query = { - "query": {"match": {"message": test_message}}, - "sort": [{"@timestamp": {"order": "desc"}}], - "size": 1, - } - - # Execute the search - response = requests.post( - search_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, - json=search_query, - verify=parsed_url.scheme == "https", - ) - - # Check for errors and provide better debugging - if response.status_code != 200: - print(f"Elasticsearch search failed with status {response.status_code}") - print(f"Response: {response.text}") - response.raise_for_status() - - search_results = response.json() + # Search for the document using the client + search_results = elasticsearch_client.search_by_match("message", test_message, size=1) # Assert that we found our log message + assert search_results is not None, "Search should return results" assert "hits" in search_results, "Search response should contain 'hits'" assert "total" in search_results["hits"], "Search hits should contain 'total'" @@ -151,7 +134,7 @@ def test_elasticsearch_direct_http_handler_sends_logs( @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_sorts_logs_chronologically( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): """Test that logs can be sorted chronologically by timestamp.""" @@ -166,31 +149,20 @@ def test_elasticsearch_direct_http_handler_sorts_logs_chronologically( # Give Elasticsearch time to process all documents time.sleep(2) - # Query Elasticsearch to get all our test messages sorted by timestamp - parsed_url = urlparse(elasticsearch_config.url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" - # Search for all messages containing our test prefix - search_query = { - "query": {"match_phrase_prefix": {"message": "Chronological test message"}}, - "sort": [{"@timestamp": {"order": "asc"}}], # Ascending order (oldest first) - "size": 10, - } - - response = requests.post( - search_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, - json=search_query, - verify=parsed_url.scheme == "https", + search_results = elasticsearch_client.search_by_match_phrase_prefix( + "message", "Chronological test message", size=10 ) - if response.status_code != 200: - print(f"Elasticsearch search failed with status {response.status_code}") - print(f"Response: {response.text}") - response.raise_for_status() + # Add sorting to the search + if search_results is None: + search_results = elasticsearch_client.search( + {"match_phrase_prefix": {"message": "Chronological test message"}}, + size=10, + sort=[{"@timestamp": {"order": "asc"}}], + ) - search_results = response.json() + assert search_results is not None, "Search should return results" # Verify we found our messages hits = search_results["hits"]["hits"] @@ -216,7 +188,7 @@ def test_elasticsearch_direct_http_handler_sorts_logs_chronologically( @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_includes_rollout_id( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): """Test that ElasticsearchDirectHttpHandler includes rollout_id field in indexed logs.""" @@ -229,35 +201,11 @@ def test_elasticsearch_direct_http_handler_includes_rollout_id( # Give Elasticsearch a moment to process the document time.sleep(3) - # Query Elasticsearch to verify the document was received with rollout_id - parsed_url = urlparse(elasticsearch_config.url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" - - # Prepare the search query - search_query = { - "query": {"match": {"message": test_message}}, - "sort": [{"@timestamp": {"order": "desc"}}], - "size": 1, - } - - # Execute the search - response = requests.post( - search_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, - json=search_query, - verify=parsed_url.scheme == "https", - ) - - # Check for errors - if response.status_code != 200: - print(f"Elasticsearch search failed with status {response.status_code}") - print(f"Response: {response.text}") - response.raise_for_status() - - search_results = response.json() + # Search for the document using the client + search_results = elasticsearch_client.search_by_match("message", test_message, size=1) # Assert that we found our log message + assert search_results is not None, "Search should return results" assert "hits" in search_results, "Search response should contain 'hits'" assert "total" in search_results["hits"], "Search hits should contain 'total'" @@ -298,7 +246,7 @@ def test_elasticsearch_direct_http_handler_includes_rollout_id( @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_search_by_rollout_id( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): """Test that logs can be searched by rollout_id field in Elasticsearch.""" @@ -313,35 +261,11 @@ def test_elasticsearch_direct_http_handler_search_by_rollout_id( # Give Elasticsearch time to process all documents time.sleep(3) - # Query Elasticsearch to search by rollout_id - parsed_url = urlparse(elasticsearch_config.url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" - # Search for logs with our specific rollout_id using term query - search_query = { - "query": {"term": {"rollout_id": rollout_id}}, - "sort": [{"@timestamp": {"order": "desc"}}], - "size": 10, - } - - # Execute the search - response = requests.post( - search_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, - json=search_query, - verify=parsed_url.scheme == "https", - ) - - # Check for errors - if response.status_code != 200: - print(f"Elasticsearch search failed with status {response.status_code}") - print(f"Response: {response.text}") - response.raise_for_status() - - search_results = response.json() + search_results = elasticsearch_client.search_by_term("rollout_id", rollout_id, size=10) # Assert that we found our log messages + assert search_results is not None, "Search should return results" assert "hits" in search_results, "Search response should contain 'hits'" assert "total" in search_results["hits"], "Search hits should contain 'total'" @@ -374,24 +298,9 @@ def test_elasticsearch_direct_http_handler_search_by_rollout_id( # Test searching for a different rollout_id (should return no results) different_rollout_id = f"different-rollout-{time.time()}" - search_query_different = { - "query": {"term": {"rollout_id": different_rollout_id}}, - "size": 10, - } - - response_different = requests.post( - search_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, - json=search_query_different, - verify=parsed_url.scheme == "https", - ) + different_results = elasticsearch_client.search_by_term("rollout_id", different_rollout_id, size=10) - if response_different.status_code != 200: - print(f"Elasticsearch search failed with status {response_different.status_code}") - print(f"Response: {response_different.text}") - response_different.raise_for_status() - - different_results = response_different.json() + assert different_results is not None, "Search should return results" different_total_hits = different_results["hits"]["total"] if isinstance(different_total_hits, dict): different_count = different_total_hits["value"] @@ -406,7 +315,7 @@ def test_elasticsearch_direct_http_handler_search_by_rollout_id( @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_logs_status_info( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): """Test that ElasticsearchDirectHttpHandler logs Status class instances and can search by status code.""" from eval_protocol import Status @@ -423,35 +332,11 @@ def test_elasticsearch_direct_http_handler_logs_status_info( # Give Elasticsearch time to process the document time.sleep(3) - # Query Elasticsearch to verify the document was received with status info - parsed_url = urlparse(elasticsearch_config.url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" - # Search for logs with our specific status code - search_query = { - "query": {"term": {"status_code": test_status.code.value}}, - "sort": [{"@timestamp": {"order": "desc"}}], - "size": 1, - } - - # Execute the search - response = requests.post( - search_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, - json=search_query, - verify=parsed_url.scheme == "https", - ) - - # Check for errors - if response.status_code != 200: - print(f"Elasticsearch search failed with status {response.status_code}") - print(f"Response: {response.text}") - response.raise_for_status() - - search_results = response.json() + search_results = elasticsearch_client.search_by_term("status_code", test_status.code.value, size=1) # Assert that we found our log message + assert search_results is not None, "Search should return results" assert "hits" in search_results, "Search response should contain 'hits'" assert "total" in search_results["hits"], "Search hits should contain 'total'" @@ -496,7 +381,7 @@ def test_elasticsearch_direct_http_handler_logs_status_info( @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") def test_elasticsearch_direct_http_handler_search_by_status_code( - elasticsearch_config: ElasticSearchConfig, test_logger: logging.Logger, rollout_id: str + elasticsearch_client: ElasticsearchClient, test_logger: logging.Logger, rollout_id: str ): """Test that logs can be searched by status code in Elasticsearch.""" from eval_protocol.models import Status @@ -519,36 +404,12 @@ def test_elasticsearch_direct_http_handler_search_by_status_code( # Give Elasticsearch time to process all documents time.sleep(3) - # Query Elasticsearch to search by specific status code - parsed_url = urlparse(elasticsearch_config.url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - search_url = f"{base_url}/{elasticsearch_config.index_name}/_search" - # Search for logs with RUNNING status code running_status = Status.Code.RUNNING - search_query = { - "query": {"term": {"status_code": running_status.value}}, - "sort": [{"@timestamp": {"order": "desc"}}], - "size": 10, - } - - # Execute the search - response = requests.post( - search_url, - headers={"Content-Type": "application/json", "Authorization": f"ApiKey {elasticsearch_config.api_key}"}, - json=search_query, - verify=parsed_url.scheme == "https", - ) - - # Check for errors - if response.status_code != 200: - print(f"Elasticsearch search failed with status {response.status_code}") - print(f"Response: {response.text}") - response.raise_for_status() - - search_results = response.json() + search_results = elasticsearch_client.search_by_term("status_code", running_status.value, size=10) # Assert that we found our log messages + assert search_results is not None, "Search should return results" assert "hits" in search_results, "Search response should contain 'hits'" assert "total" in search_results["hits"], "Search hits should contain 'total'" From ccd22c1e028e3af8f75dfb5bbcc4a5718dec20f9 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 16:14:18 -0700 Subject: [PATCH 7/9] rename to match official capitalization --- eval_protocol/logging/elasticsearch_direct_http_handler.py | 4 ++-- tests/logging/test_elasticsearch_direct_http_handler.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/eval_protocol/logging/elasticsearch_direct_http_handler.py b/eval_protocol/logging/elasticsearch_direct_http_handler.py index 66ad8751..e6439ed3 100644 --- a/eval_protocol/logging/elasticsearch_direct_http_handler.py +++ b/eval_protocol/logging/elasticsearch_direct_http_handler.py @@ -11,7 +11,7 @@ from .elasticsearch_client import ElasticsearchClient, ElasticsearchConfig as ESConfig -class ElasticSearchDirectHttpHandler(logging.Handler): +class ElasticsearchDirectHttpHandler(logging.Handler): def __init__(self, elasticsearch_config: ElasticSearchConfig) -> None: super().__init__() self.config = ESConfig( @@ -55,7 +55,7 @@ def _get_rollout_id(self, record: logging.LogRecord) -> str: rollout_id = os.getenv("EP_ROLLOUT_ID") if rollout_id is None: raise ValueError( - "EP_ROLLOUT_ID environment variable is not set but needed for ElasticSearchDirectHttpHandler" + "EP_ROLLOUT_ID environment variable is not set but needed for ElasticsearchDirectHttpHandler" ) return rollout_id diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index be70d3ea..f63b56ba 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -3,7 +3,7 @@ import time import pytest -from eval_protocol.logging.elasticsearch_direct_http_handler import ElasticSearchDirectHttpHandler +from eval_protocol.logging.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler from eval_protocol.logging.elasticsearch_client import ElasticsearchClient, ElasticsearchConfig as ESConfig from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig @@ -43,7 +43,7 @@ def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig, rollout_id: """Create and configure ElasticsearchDirectHttpHandler.""" # Use a unique test-specific index name with timestamp - handler = ElasticSearchDirectHttpHandler(elasticsearch_config) + handler = ElasticsearchDirectHttpHandler(elasticsearch_config) # Set a specific log level handler.setLevel(logging.INFO) From 5f3f4f93f89148bd937b7db2b0c310b031bcb174 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 16:27:55 -0700 Subject: [PATCH 8/9] Refactor Elasticsearch configuration handling - Updated Elasticsearch integration to use ElasticSearchConfig from the types module, improving consistency across the codebase. - Removed the local ElasticsearchConfig dataclass in favor of the new model, streamlining configuration management. - Adjusted related classes and tests to utilize the updated configuration structure, enhancing maintainability and readability. --- eval_protocol/logging/elasticsearch_client.py | 19 ++----------------- .../elasticsearch_direct_http_handler.py | 4 ++-- .../logging/elasticsearch_index_manager.py | 5 +++-- .../pytest/remote_rollout_processor.py | 5 +++++ .../types/remote_rollout_processor.py | 7 +++++++ .../test_elasticsearch_direct_http_handler.py | 8 +++----- 6 files changed, 22 insertions(+), 26 deletions(-) diff --git a/eval_protocol/logging/elasticsearch_client.py b/eval_protocol/logging/elasticsearch_client.py index 917a06b3..63835b40 100644 --- a/eval_protocol/logging/elasticsearch_client.py +++ b/eval_protocol/logging/elasticsearch_client.py @@ -10,28 +10,13 @@ import requests from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse -from dataclasses import dataclass - - -@dataclass -class ElasticsearchConfig: - """Configuration for Elasticsearch client.""" - - url: str - api_key: str - index_name: str - verify_ssl: bool = True - - def __post_init__(self): - """Parse URL to determine SSL verification setting.""" - parsed_url = urlparse(self.url) - self.verify_ssl = parsed_url.scheme == "https" +from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig class ElasticsearchClient: """Centralized client for all Elasticsearch operations.""" - def __init__(self, config: ElasticsearchConfig): + def __init__(self, config: ElasticSearchConfig): """Initialize the Elasticsearch client. Args: diff --git a/eval_protocol/logging/elasticsearch_direct_http_handler.py b/eval_protocol/logging/elasticsearch_direct_http_handler.py index e6439ed3..3b03cb50 100644 --- a/eval_protocol/logging/elasticsearch_direct_http_handler.py +++ b/eval_protocol/logging/elasticsearch_direct_http_handler.py @@ -8,13 +8,13 @@ from datetime import datetime from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig -from .elasticsearch_client import ElasticsearchClient, ElasticsearchConfig as ESConfig +from .elasticsearch_client import ElasticsearchClient class ElasticsearchDirectHttpHandler(logging.Handler): def __init__(self, elasticsearch_config: ElasticSearchConfig) -> None: super().__init__() - self.config = ESConfig( + self.config = ElasticSearchConfig( url=elasticsearch_config.url, api_key=elasticsearch_config.api_key, index_name=elasticsearch_config.index_name, diff --git a/eval_protocol/logging/elasticsearch_index_manager.py b/eval_protocol/logging/elasticsearch_index_manager.py index 08e18813..34f6a75a 100644 --- a/eval_protocol/logging/elasticsearch_index_manager.py +++ b/eval_protocol/logging/elasticsearch_index_manager.py @@ -1,5 +1,6 @@ from typing import Dict, Any, Optional -from .elasticsearch_client import ElasticsearchClient, ElasticsearchConfig +from .elasticsearch_client import ElasticsearchClient +from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig class ElasticsearchIndexManager: @@ -13,7 +14,7 @@ def __init__(self, base_url: str, index_name: str, api_key: str) -> None: index_name: Name of the index to manage api_key: API key for authentication """ - self.config = ElasticsearchConfig(url=base_url, api_key=api_key, index_name=index_name) + self.config = ElasticSearchConfig(url=base_url, api_key=api_key, index_name=index_name) self.client = ElasticsearchClient(self.config) self._mapping_created: bool = False diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 847561cd..ce857313 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -4,6 +4,7 @@ import requests +from eval_protocol.logging.elasticsearch_client import ElasticsearchClient from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig, InitRequest, RolloutMetadata @@ -182,6 +183,10 @@ def _get_status() -> Dict[str, Any]: r.raise_for_status() return r.json() + elasticsearch_client = ( + ElasticsearchClient(self._elastic_search_config) if self._elastic_search_config else None + ) + while time.time() < deadline: try: status = await asyncio.to_thread(_get_status) diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index e8ccdf75..281423e7 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field +from urllib.parse import urlparse from eval_protocol.models import Message, Status @@ -16,6 +17,12 @@ class ElasticSearchConfig(BaseModel): api_key: str index_name: str + @property + def verify_ssl(self) -> bool: + """Infer verify_ssl from URL scheme.""" + parsed_url = urlparse(self.url) + return parsed_url.scheme == "https" + class RolloutMetadata(BaseModel): """Metadata for rollout execution.""" diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index f63b56ba..0e364959 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -4,7 +4,7 @@ import pytest from eval_protocol.logging.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler -from eval_protocol.logging.elasticsearch_client import ElasticsearchClient, ElasticsearchConfig as ESConfig +from eval_protocol.logging.elasticsearch_client import ElasticsearchClient from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig @@ -54,10 +54,8 @@ def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig, rollout_id: @pytest.fixture def elasticsearch_client(elasticsearch_config: ElasticSearchConfig): """Create an Elasticsearch client for testing.""" - config = ESConfig( - url=elasticsearch_config.url, api_key=elasticsearch_config.api_key, index_name=elasticsearch_config.index_name - ) - return ElasticsearchClient(config) + # Create a new config instance for the client + return ElasticsearchClient(elasticsearch_config) @pytest.fixture From f8a2e4c2752735445a8cd4508a1d7781afbc3dfb Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 1 Oct 2025 16:28:35 -0700 Subject: [PATCH 9/9] Refactor Elasticsearch configuration naming for consistency - Updated all instances of ElasticSearchConfig to ElasticsearchConfig across the codebase to align with official naming conventions. - Ensured that related classes and methods reflect this change, enhancing code clarity and maintainability. --- eval_protocol/logging/elasticsearch_client.py | 4 ++-- .../logging/elasticsearch_direct_http_handler.py | 6 +++--- .../logging/elasticsearch_index_manager.py | 4 ++-- eval_protocol/pytest/elasticsearch_setup.py | 16 ++++++++-------- eval_protocol/pytest/remote_rollout_processor.py | 6 +++--- eval_protocol/types/remote_rollout_processor.py | 4 ++-- .../test_elasticsearch_direct_http_handler.py | 6 +++--- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/eval_protocol/logging/elasticsearch_client.py b/eval_protocol/logging/elasticsearch_client.py index 63835b40..b7ca7f0e 100644 --- a/eval_protocol/logging/elasticsearch_client.py +++ b/eval_protocol/logging/elasticsearch_client.py @@ -10,13 +10,13 @@ import requests from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse -from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig class ElasticsearchClient: """Centralized client for all Elasticsearch operations.""" - def __init__(self, config: ElasticSearchConfig): + def __init__(self, config: ElasticsearchConfig): """Initialize the Elasticsearch client. Args: diff --git a/eval_protocol/logging/elasticsearch_direct_http_handler.py b/eval_protocol/logging/elasticsearch_direct_http_handler.py index 3b03cb50..015445ce 100644 --- a/eval_protocol/logging/elasticsearch_direct_http_handler.py +++ b/eval_protocol/logging/elasticsearch_direct_http_handler.py @@ -7,14 +7,14 @@ from typing import Optional, Tuple, Any, Dict from datetime import datetime -from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig from .elasticsearch_client import ElasticsearchClient class ElasticsearchDirectHttpHandler(logging.Handler): - def __init__(self, elasticsearch_config: ElasticSearchConfig) -> None: + def __init__(self, elasticsearch_config: ElasticsearchConfig) -> None: super().__init__() - self.config = ElasticSearchConfig( + self.config = ElasticsearchConfig( url=elasticsearch_config.url, api_key=elasticsearch_config.api_key, index_name=elasticsearch_config.index_name, diff --git a/eval_protocol/logging/elasticsearch_index_manager.py b/eval_protocol/logging/elasticsearch_index_manager.py index 34f6a75a..2687e802 100644 --- a/eval_protocol/logging/elasticsearch_index_manager.py +++ b/eval_protocol/logging/elasticsearch_index_manager.py @@ -1,6 +1,6 @@ from typing import Dict, Any, Optional from .elasticsearch_client import ElasticsearchClient -from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig class ElasticsearchIndexManager: @@ -14,7 +14,7 @@ def __init__(self, base_url: str, index_name: str, api_key: str) -> None: index_name: Name of the index to manage api_key: API key for authentication """ - self.config = ElasticSearchConfig(url=base_url, api_key=api_key, index_name=index_name) + self.config = ElasticsearchConfig(url=base_url, api_key=api_key, index_name=index_name) self.client = ElasticsearchClient(self.config) self._mapping_created: bool = False diff --git a/eval_protocol/pytest/elasticsearch_setup.py b/eval_protocol/pytest/elasticsearch_setup.py index 1f3af3fc..56ac1ded 100644 --- a/eval_protocol/pytest/elasticsearch_setup.py +++ b/eval_protocol/pytest/elasticsearch_setup.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from eval_protocol.directory_utils import find_eval_protocol_dir -from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig from eval_protocol.logging.elasticsearch_index_manager import ElasticsearchIndexManager logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ class ElasticsearchSetup: def __init__(self): self.eval_protocol_dir = find_eval_protocol_dir() - def setup_elasticsearch(self, index_name: str = "default-logs") -> ElasticSearchConfig: + def setup_elasticsearch(self, index_name: str = "default-logs") -> ElasticsearchConfig: """ Set up Elasticsearch, handling both local and remote scenarios. @@ -32,7 +32,7 @@ def setup_elasticsearch(self, index_name: str = "default-logs") -> ElasticSearch index_name: Name of the Elasticsearch index to use for logging Returns: - ElasticSearchConfig for the running instance with the specified index name. + ElasticsearchConfig for the running instance with the specified index name. """ elastic_start_local_dir = os.path.join(self.eval_protocol_dir, "elastic-start-local") env_file_path = os.path.join(elastic_start_local_dir, ".env") @@ -48,11 +48,11 @@ def setup_elasticsearch(self, index_name: str = "default-logs") -> ElasticSearch self.create_logging_index(index_name) # Return config with the specified index name - return ElasticSearchConfig(url=config.url, api_key=config.api_key, index_name=index_name) + return ElasticsearchConfig(url=config.url, api_key=config.api_key, index_name=index_name) def _setup_existing_docker_elasticsearch( self, elastic_start_local_dir: str, env_file_path: str - ) -> ElasticSearchConfig: + ) -> ElasticsearchConfig: """Set up Elasticsearch using existing Docker start.sh script.""" from eval_protocol.utils.subprocess_utils import run_script_and_wait @@ -63,7 +63,7 @@ def _setup_existing_docker_elasticsearch( ) return self._parse_elastic_env_file(env_file_path) - def _setup_initialized_docker_elasticsearch(self, env_file_path: str) -> ElasticSearchConfig: + def _setup_initialized_docker_elasticsearch(self, env_file_path: str) -> ElasticsearchConfig: """Set up Elasticsearch by initializing Docker setup from scratch with retry logic.""" max_retries = 2 for attempt in range(max_retries): @@ -126,7 +126,7 @@ def _handle_existing_elasticsearch_container(self, output: str) -> bool: return False return False - def _parse_elastic_env_file(self, env_file_path: str) -> ElasticSearchConfig: + def _parse_elastic_env_file(self, env_file_path: str) -> ElasticsearchConfig: """Parse ES_LOCAL_API_KEY and ES_LOCAL_URL from .env file.""" loaded = load_dotenv(env_file_path) if not loaded: @@ -138,7 +138,7 @@ def _parse_elastic_env_file(self, env_file_path: str) -> ElasticSearchConfig: if not url or not api_key: raise ElasticsearchSetupError("Failed to parse ES_LOCAL_API_KEY and ES_LOCAL_URL from .env file") - return ElasticSearchConfig(url=url, api_key=api_key, index_name="default-logs") + return ElasticsearchConfig(url=url, api_key=api_key, index_name="default-logs") def create_logging_index(self, index_name: str) -> bool: """Create an Elasticsearch index with proper mapping for logging data. diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index ce857313..c76bac61 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -7,7 +7,7 @@ from eval_protocol.logging.elasticsearch_client import ElasticsearchClient from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig, InitRequest, RolloutMetadata +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig from .elasticsearch_setup import ElasticsearchSetup @@ -34,7 +34,7 @@ def __init__( timeout_seconds: float = 120.0, output_data_loader: Callable[[str], DynamicDataLoader], disable_elastic_search: bool = False, - elastic_search_config: Optional[ElasticSearchConfig] = None, + elastic_search_config: Optional[ElasticsearchConfig] = None, ): # Prefer constructor-provided configuration. These can be overridden via # config.kwargs at call time for backward compatibility. @@ -56,7 +56,7 @@ def setup(self) -> None: self._elastic_search_config = self._setup_elastic_search() logger.info("Elasticsearch setup complete") - def _setup_elastic_search(self) -> ElasticSearchConfig: + def _setup_elastic_search(self) -> ElasticsearchConfig: """Set up Elasticsearch using the dedicated setup module.""" setup = ElasticsearchSetup() return setup.setup_elasticsearch() diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index 281423e7..67c3158a 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -8,7 +8,7 @@ from eval_protocol.models import Message, Status -class ElasticSearchConfig(BaseModel): +class ElasticsearchConfig(BaseModel): """ Configuration for Elasticsearch. """ @@ -38,7 +38,7 @@ class InitRequest(BaseModel): """Request model for POST /init endpoint.""" model: str - elastic_search_config: Optional[ElasticSearchConfig] = None + elastic_search_config: Optional[ElasticsearchConfig] = None messages: Optional[List[Message]] = None tools: Optional[List[Dict[str, Any]]] = None diff --git a/tests/logging/test_elasticsearch_direct_http_handler.py b/tests/logging/test_elasticsearch_direct_http_handler.py index 0e364959..6a50f088 100644 --- a/tests/logging/test_elasticsearch_direct_http_handler.py +++ b/tests/logging/test_elasticsearch_direct_http_handler.py @@ -6,7 +6,7 @@ from eval_protocol.logging.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler from eval_protocol.logging.elasticsearch_client import ElasticsearchClient from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup -from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig @pytest.fixture @@ -39,7 +39,7 @@ def elasticsearch_config(): @pytest.fixture -def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig, rollout_id: str): +def elasticsearch_handler(elasticsearch_config: ElasticsearchConfig, rollout_id: str): """Create and configure ElasticsearchDirectHttpHandler.""" # Use a unique test-specific index name with timestamp @@ -52,7 +52,7 @@ def elasticsearch_handler(elasticsearch_config: ElasticSearchConfig, rollout_id: @pytest.fixture -def elasticsearch_client(elasticsearch_config: ElasticSearchConfig): +def elasticsearch_client(elasticsearch_config: ElasticsearchConfig): """Create an Elasticsearch client for testing.""" # Create a new config instance for the client return ElasticsearchClient(elasticsearch_config)