diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 98c58f330..f61d26778 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -2,6 +2,7 @@ import asyncio import json +import logging import os import random import socket @@ -25,6 +26,8 @@ from agentlightning.store.base import LightningStore from agentlightning.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task +logger = logging.getLogger(__name__) + __all__ = [ "AgentModeDaemon", "get_left_padded_ids_and_attention_mask", @@ -387,7 +390,9 @@ def proxy(path: str): # type: ignore current_time = time.time() num_requests += 1 if current_time - last_request_time > 60 or num_requests == 1 or num_requests % 100 == 0: - print(f"Proxying {request.method} request to {target_server}. Request data: {request.get_data()}") + logger.debug( + "Proxying %s request to %s. Request data: %s", request.method, target_server, request.get_data() + ) last_request_time = current_time try: @@ -448,7 +453,7 @@ def run_app(): self._proxy_thread = threading.Thread(target=run_app, daemon=True) self._proxy_thread.start() - print(f"Proxy server running on port {self.proxy_port}") + logger.info("Proxy server running on port %s", self.proxy_port) async def _update_proxy_server_v1(self): model_name = self.train_information.get("model") @@ -484,12 +489,12 @@ def run_server(): self._server_thread.start() # Wait for the server's internal startup event to be set. - print("Waiting for AgentLightningServer to start...") + logger.info("Waiting for AgentLightningServer to start...") is_ready = self.server.startup_event.wait(timeout=20.0) # Wait up to 20s if not is_ready: raise RuntimeError("AgentLightningServer failed to start within the timeout period.") - print(f"AgentLightningServer control plane running on port {self.server_port}") + logger.info("AgentLightningServer control plane running on port %s", self.server_port) self._start_proxy_server_v0() else: @@ -604,22 +609,24 @@ def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[st try: future.result(timeout=300) # Wait for completion with a timeout except Exception as e: - print(f"Failed to set up data on server: {e}") + logger.error("Failed to set up data on server: %s", e) raise def _validate_data(self, rollout: RolloutLegacy): if rollout.final_reward is None: - print( - f"Warning: Reward is None for rollout {rollout.rollout_id}, will be auto-set to {self.reward_fillna_value}." + logger.warning( + "Reward is None for rollout %s, will be auto-set to %s.", + rollout.rollout_id, + self.reward_fillna_value, ) if rollout.triplets is None: - print(f"Warning: Triplet is None for rollout {rollout.rollout_id}.") + logger.warning("Triplet is None for rollout %s.", rollout.rollout_id) elif len(rollout.triplets) == 0: - print(f"Warning: Length of triplets is 0 for rollout {rollout.rollout_id}.") + logger.warning("Length of triplets is 0 for rollout %s.", rollout.rollout_id) elif any(not r.response.get("token_ids", []) for r in rollout.triplets): - print(f"Warning: Rollout {rollout.rollout_id} contains empty response: {rollout.triplets}") + logger.warning("Rollout %s contains empty response: %s", rollout.rollout_id, rollout.triplets) elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets): - print(f"Warning: Rollout {rollout.rollout_id} contains empty prompt: {rollout.triplets}") + logger.warning("Rollout %s contains empty prompt: %s", rollout.rollout_id, rollout.triplets) async def _validate_data_v1(self, rollout: Rollout) -> RolloutLegacy: """Convert Rollout to RolloutLegacy and validate. @@ -688,19 +695,19 @@ async def _async_run_until_finished(self, verbose: bool = True): else: self._validate_data(rollout) if rollout.rollout_id not in self._task_id_to_original_sample: - print(f"Warning: Received unknown rollout ID {rollout.rollout_id}, skipping.") + logger.warning("Received unknown rollout ID %s, skipping.", rollout.rollout_id) else: self._completed_rollouts_v0[rollout.rollout_id] = rollout if verbose: - print(f"Completed {len(self._completed_rollouts_v0)}/{self._total_tasks_queued} tasks...") + logger.info("Completed %d/%d tasks...", len(self._completed_rollouts_v0), self._total_tasks_queued) await asyncio.sleep(5) - print("All tasks finished.") + logger.info("All tasks finished.") def run_until_all_finished(self, verbose: bool = True): """Synchronously waits for all queued tasks to be completed and reported.""" if self._total_tasks_queued == 0: - print("Warning: No tasks were queued.") + logger.warning("No tasks were queued.") return if self.mode == "v0": @@ -716,7 +723,7 @@ def run_until_all_finished(self, verbose: bool = True): try: future.result() # Wait indefinitely for all tasks to complete except Exception as e: - print(f"Error while waiting for tasks to finish: {e}") + logger.error("Error while waiting for tasks to finish: %s", e) raise def get_test_metrics(self): @@ -733,7 +740,7 @@ def get_test_metrics(self): final_reward_raw: Optional[float] = rollout.final_reward final_reward = self._fillna_reward(rollout) if not rollout.triplets: - print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.") + logger.warning("No triplets found for test rollout %s.", rollout.rollout_id) sample_stat_list.append({"reward": final_reward, "has_reward": final_reward_raw is not None}) continue response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets] @@ -828,7 +835,7 @@ def get_train_data_batch( if not rollout.triplets: finished_id_to_final_reward[rollout_id] = final_reward - print(f"Warning: No triplets found for training rollout {rollout.rollout_id}, skipping.") + logger.warning("No triplets found for training rollout %s, skipping.", rollout.rollout_id) continue # The client should report triplets that contain prompt_ids and response_ids.