Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions agentlightning/verl/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
import logging
import os
import random
import socket
Expand All @@ -25,6 +26,8 @@
from agentlightning.store.base import LightningStore
from agentlightning.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task

logger = logging.getLogger(__name__)

__all__ = [
"AgentModeDaemon",
"get_left_padded_ids_and_attention_mask",
Expand Down Expand Up @@ -387,7 +390,9 @@ def proxy(path: str): # type: ignore
current_time = time.time()
num_requests += 1
if current_time - last_request_time > 60 or num_requests == 1 or num_requests % 100 == 0:
print(f"Proxying {request.method} request to {target_server}. Request data: {request.get_data()}")
logger.debug(
"Proxying %s request to %s. Request data: %s", request.method, target_server, request.get_data()
)
last_request_time = current_time
Comment on lines 390 to 396

try:
Expand Down Expand Up @@ -448,7 +453,7 @@ def run_app():

self._proxy_thread = threading.Thread(target=run_app, daemon=True)
self._proxy_thread.start()
print(f"Proxy server running on port {self.proxy_port}")
logger.info("Proxy server running on port %s", self.proxy_port)

async def _update_proxy_server_v1(self):
model_name = self.train_information.get("model")
Expand Down Expand Up @@ -484,12 +489,12 @@ def run_server():
self._server_thread.start()

# Wait for the server's internal startup event to be set.
print("Waiting for AgentLightningServer to start...")
logger.info("Waiting for AgentLightningServer to start...")
is_ready = self.server.startup_event.wait(timeout=20.0) # Wait up to 20s
if not is_ready:
raise RuntimeError("AgentLightningServer failed to start within the timeout period.")

print(f"AgentLightningServer control plane running on port {self.server_port}")
logger.info("AgentLightningServer control plane running on port %s", self.server_port)

self._start_proxy_server_v0()
else:
Expand Down Expand Up @@ -604,22 +609,24 @@ def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[st
try:
future.result(timeout=300) # Wait for completion with a timeout
except Exception as e:
print(f"Failed to set up data on server: {e}")
logger.error("Failed to set up data on server: %s", e)
raise
Comment on lines 609 to 613

def _validate_data(self, rollout: RolloutLegacy):
if rollout.final_reward is None:
print(
f"Warning: Reward is None for rollout {rollout.rollout_id}, will be auto-set to {self.reward_fillna_value}."
logger.warning(
"Reward is None for rollout %s, will be auto-set to %s.",
rollout.rollout_id,
self.reward_fillna_value,
)
if rollout.triplets is None:
print(f"Warning: Triplet is None for rollout {rollout.rollout_id}.")
logger.warning("Triplet is None for rollout %s.", rollout.rollout_id)
elif len(rollout.triplets) == 0:
print(f"Warning: Length of triplets is 0 for rollout {rollout.rollout_id}.")
logger.warning("Length of triplets is 0 for rollout %s.", rollout.rollout_id)
elif any(not r.response.get("token_ids", []) for r in rollout.triplets):
print(f"Warning: Rollout {rollout.rollout_id} contains empty response: {rollout.triplets}")
logger.warning("Rollout %s contains empty response: %s", rollout.rollout_id, rollout.triplets)
elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets):
print(f"Warning: Rollout {rollout.rollout_id} contains empty prompt: {rollout.triplets}")
logger.warning("Rollout %s contains empty prompt: %s", rollout.rollout_id, rollout.triplets)

async def _validate_data_v1(self, rollout: Rollout) -> RolloutLegacy:
"""Convert Rollout to RolloutLegacy and validate.
Expand Down Expand Up @@ -688,19 +695,19 @@ async def _async_run_until_finished(self, verbose: bool = True):
else:
self._validate_data(rollout)
if rollout.rollout_id not in self._task_id_to_original_sample:
print(f"Warning: Received unknown rollout ID {rollout.rollout_id}, skipping.")
logger.warning("Received unknown rollout ID %s, skipping.", rollout.rollout_id)
else:
self._completed_rollouts_v0[rollout.rollout_id] = rollout
if verbose:
print(f"Completed {len(self._completed_rollouts_v0)}/{self._total_tasks_queued} tasks...")
logger.info("Completed %d/%d tasks...", len(self._completed_rollouts_v0), self._total_tasks_queued)
await asyncio.sleep(5)

print("All tasks finished.")
logger.info("All tasks finished.")

def run_until_all_finished(self, verbose: bool = True):
"""Synchronously waits for all queued tasks to be completed and reported."""
if self._total_tasks_queued == 0:
print("Warning: No tasks were queued.")
logger.warning("No tasks were queued.")
return

if self.mode == "v0":
Expand All @@ -716,7 +723,7 @@ def run_until_all_finished(self, verbose: bool = True):
try:
future.result() # Wait indefinitely for all tasks to complete
except Exception as e:
print(f"Error while waiting for tasks to finish: {e}")
logger.error("Error while waiting for tasks to finish: %s", e)
raise

def get_test_metrics(self):
Expand All @@ -733,7 +740,7 @@ def get_test_metrics(self):
final_reward_raw: Optional[float] = rollout.final_reward
final_reward = self._fillna_reward(rollout)
if not rollout.triplets:
print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.")
logger.warning("No triplets found for test rollout %s.", rollout.rollout_id)
sample_stat_list.append({"reward": final_reward, "has_reward": final_reward_raw is not None})
continue
response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
Expand Down Expand Up @@ -828,7 +835,7 @@ def get_train_data_batch(

if not rollout.triplets:
finished_id_to_final_reward[rollout_id] = final_reward
print(f"Warning: No triplets found for training rollout {rollout.rollout_id}, skipping.")
logger.warning("No triplets found for training rollout %s, skipping.", rollout.rollout_id)
continue

# The client should report triplets that contain prompt_ids and response_ids.
Expand Down