diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index 9f17f8ac..c18ee329 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -34,6 +34,7 @@ from .pytest.parameterize import DefaultParameterIdGenerator from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler from .log_utils.rollout_id_filter import RolloutIdFilter +from .log_utils.util import setup_rollout_logging_for_elasticsearch_handler from .types.remote_rollout_processor import ( InitRequest, @@ -68,6 +69,7 @@ __all__ = [ "ElasticsearchDirectHttpHandler", "RolloutIdFilter", + "setup_rollout_logging_for_elasticsearch_handler", "DataLoaderConfig", "Status", "RemoteRolloutProcessor", diff --git a/eval_protocol/log_utils/util.py b/eval_protocol/log_utils/util.py new file mode 100644 index 00000000..a72907b5 --- /dev/null +++ b/eval_protocol/log_utils/util.py @@ -0,0 +1,22 @@ +import os +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig +from .elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler + + +def setup_rollout_logging_for_elasticsearch_handler( + handler: ElasticsearchDirectHttpHandler, rollout_id: str, elastic_search_config: ElasticsearchConfig +) -> None: + """ + Set up the rollout logging context for the current process. Call this + whenever a new subprocess is created; it is useful when implementing your + own remote server for rollout processing. + + 1. Sets the EP_ROLLOUT_ID environment variable to ``rollout_id``. + 2. Calls ``handler.configure`` with ``elastic_search_config``. + """ + + # Intended to affect only this subprocess, so that logs emitted from it can + # be correlated to the rollout. + os.environ["EP_ROLLOUT_ID"] = rollout_id + + handler.configure(elasticsearch_config=elastic_search_config)