44
55import requests
66
7- from eval_protocol .log_utils .elasticsearch_client import ElasticsearchClient
87from eval_protocol .models import EvaluationRow , Status
98from eval_protocol .data_loader .dynamic_data_loader import DynamicDataLoader
109from eval_protocol .types .remote_rollout_processor import (
1110 DataLoaderConfig ,
12- ElasticsearchConfig ,
1311)
12+ from eval_protocol .adapters .fireworks_tracing import FireworksTracingAdapter
13+
1414from .rollout_processor import RolloutProcessor
1515from .types import RolloutProcessorConfig
16- from .elasticsearch_setup import ElasticsearchSetup
1716from .tracing_utils import default_fireworks_output_data_loader , build_init_request , update_row_with_remote_trace
1817import logging
1918
2221logger = logging .getLogger (__name__ )
2322
2423
25- def create_elasticsearch_config_from_env () -> ElasticsearchConfig :
26- """Setup Elasticsearch config from environment variables."""
27- url = os .getenv ("ELASTICSEARCH_URL" )
28- api_key = os .getenv ("ELASTICSEARCH_API_KEY" )
29- index_name = os .getenv ("ELASTICSEARCH_INDEX_NAME" )
30-
31- if url is None :
32- raise ValueError ("ELASTICSEARCH_URL must be set" )
33- if api_key is None :
34- raise ValueError ("ELASTICSEARCH_API_KEY must be set" )
35- if index_name is None :
36- raise ValueError ("ELASTICSEARCH_INDEX_NAME must be set" )
37- return ElasticsearchConfig (
38- url = url ,
39- api_key = api_key ,
40- index_name = index_name ,
41- )
42-
43-
4424class RemoteRolloutProcessor (RolloutProcessor ):
4525 """
4626 Rollout processor that triggers a remote HTTP server to perform the rollout.
@@ -59,8 +39,6 @@ def __init__(
5939 poll_interval : float = 1.0 ,
6040 timeout_seconds : float = 120.0 ,
6141 output_data_loader : Optional [Callable [[DataLoaderConfig ], DynamicDataLoader ]] = None ,
62- disable_elastic_search_setup : bool = False ,
63- elastic_search_config : Optional [ElasticsearchConfig ] = None ,
6442 ):
6543 # Prefer constructor-provided configuration. These can be overridden via
6644 # config.kwargs at call time for backward compatibility.
@@ -74,21 +52,7 @@ def __init__(
7452 self ._poll_interval = poll_interval
7553 self ._timeout_seconds = timeout_seconds
7654 self ._output_data_loader = output_data_loader or default_fireworks_output_data_loader
77- self ._disable_elastic_search_setup = disable_elastic_search_setup
78- self ._elastic_search_config = elastic_search_config
79-
80- def setup (self ) -> None :
81- if self ._disable_elastic_search_setup :
82- logger .info ("Elasticsearch is disabled, skipping setup" )
83- return
84- logger .info ("Setting up Elasticsearch" )
85- self ._elastic_search_config = self ._setup_elastic_search ()
86- logger .info ("Elasticsearch setup complete" )
87-
88- def _setup_elastic_search (self ) -> ElasticsearchConfig :
89- """Set up Elasticsearch using the dedicated setup module."""
90- setup = ElasticsearchSetup ()
91- return setup .setup_elasticsearch ()
55+ self ._tracing_adapter = FireworksTracingAdapter (base_url = self ._model_base_url )
9256
9357 def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
9458 tasks : List [asyncio .Task [EvaluationRow ]] = []
@@ -123,7 +87,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
12387 if row .input_metadata .row_id is None :
12488 raise ValueError ("Row ID is required in RemoteRolloutProcessor" )
12589
126- init_payload = build_init_request (row , config , model_base_url , self . _elastic_search_config )
90+ init_payload = build_init_request (row , config , model_base_url )
12791
12892 # Fire-and-poll
12993 def _post_init () -> None :
@@ -153,10 +117,6 @@ def _get_status() -> Dict[str, Any]:
153117 r .raise_for_status ()
154118 return r .json ()
155119
156- elasticsearch_client = (
157- ElasticsearchClient (self ._elastic_search_config ) if self ._elastic_search_config else None
158- )
159-
160120 continue_polling_status = True
161121 while time .time () < deadline :
162122 try :
@@ -178,29 +138,41 @@ def _get_status() -> Dict[str, Any]:
178138 # For all other exceptions, raise them
179139 raise
180140
181- if not elasticsearch_client :
182- continue
183-
184- search_results = elasticsearch_client .search_by_status_code_not_in (
185- row .execution_metadata .rollout_id , [Status .Code .RUNNING ]
141+ # Search Fireworks tracing logs for completion
142+ completed_logs = self ._tracing_adapter .search_logs (
143+ tags = [f"rollout_id:{ row .execution_metadata .rollout_id } " ]
186144 )
187- hits = search_results ["hits" ]["hits" ] if search_results else []
145+ # Filter for logs that actually have status information
146+ status_logs = []
147+ for log in completed_logs :
148+ status_dict = log .get ("status" )
149+ if status_dict and isinstance (status_dict , dict ) and "code" in status_dict :
150+ status_logs .append (log )
151+
152+ if status_logs :
153+ # Use the first log with status information
154+ status_log = status_logs [0 ]
155+ status_dict = status_log .get ("status" )
156+
157+ logger .info (
158+ f"Found status log for rollout { row .execution_metadata .rollout_id } : { status_log .get ('message' , '' )} "
159+ )
188160
189- if hits :
190- # log all statuses found and update rollout status from the last hit
191- for hit in hits :
192- document = hit [ "_source" ]
193- logger .info (
194- f"Found log for rollout { row .execution_metadata .rollout_id } with status code { document [ ' status_code' ] } "
195- )
196- # Update rollout status from the document
197- if "status_code" in document :
198- row . rollout_status = Status (
199- code = Status . Code ( document [ "status_code" ]) ,
200- message = document . get ( "status_message" , "" ) ,
201- details = document . get ( "status_details" , []),
202- )
203- logger .info ("Stopping status polling for rollout %s" , row .execution_metadata .rollout_id )
161+ status_code = status_dict . get ( "code" )
162+ status_message = status_dict . get ( "message" , "" )
163+ status_details = status_dict . get ( "details" , [])
164+
165+ logger .info (
166+ f"Found Fireworks log for rollout { row .execution_metadata .rollout_id } with status code { status_code } "
167+ )
168+
169+ row . rollout_status = Status (
170+ code = Status . Code ( status_code ),
171+ message = status_message ,
172+ details = status_details ,
173+ )
174+
175+ logger .info ("Stopping polling for rollout %s" , row .execution_metadata .rollout_id )
204176 break
205177
206178 await asyncio .sleep (poll_interval )
0 commit comments