Skip to content

Commit 375edfa

Browse files
authored
Merge pull request #16 from Agent-One-Lab/agents
Several bug fixes. Updated the monitor; it now supports configurable monitor types.
2 parents 075a0ad + 97566ab commit 375edfa

34 files changed

Lines changed: 2706 additions & 145 deletions

.gitignore

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,10 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
121121
*.lock
122122

123123
# data
124+
data/
124125
*.parquet
125126
agentfly/agents/data/*
127+
test_cache/
126128

127129
# local logs
128130
logs
@@ -133,6 +135,10 @@ data/
133135
test_cache/
134136
/*.jpg
135137
/*.png
138+
slurm/
139+
*.err
140+
*.out
141+
*.log
136142

137143
# Notebooks
138144
agentfly/tests/*.ipynb
@@ -146,3 +152,11 @@ test_outputs/
146152
agentfly/data/
147153
*.ipynb
148154

155+
# training scripts
156+
training_scripts/
157+
verl/training_scripts/
158+
159+
# training scripts
160+
training_scripts/
161+
verl/training_scripts/
162+

agentfly/agents/agent_base.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from abc import ABC, abstractmethod
22
from collections import defaultdict
3+
from datetime import datetime
34
import json
45
from .utils.messages import MessagesList
56
from ..templates.templates import get_template
6-
from ..__init__ import AGENT_DATA_DIR
7+
from .. import AGENT_DATA_DIR
78
from .llm_backends import (
89
AsyncVLLMBackend,
910
AsyncVerlBackend,
@@ -23,6 +24,7 @@
2324
import logging
2425
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
2526
from .utils.tokenizer import create_processor, create_tokenizer
27+
from ..utils.monitor import JsonlSink, Monitor, WandbSink
2628
try:
2729
from verl.protocol import DataProto
2830
except ImportError:
@@ -51,10 +53,12 @@ def __init__(
5153
backend_config: Any = None,
5254
reward_fn: Callable = None,
5355
log_file: str = "agent",
54-
project_name: str = None,
55-
run_name: str = None,
5656
streaming: str = "console",
5757
debug: bool = False,
58+
monitors: List[str] = [],
59+
wandb_project_name: str = None,
60+
wandb_run_name: str = None,
61+
local_cache_dir: str = None,
5862
**kwargs # To pass other unused arguments
5963
):
6064
"""
@@ -94,7 +98,6 @@ def __init__(
9498

9599
# Create appropriate tokenizer for trajectory processing
96100
self.tokenizer = create_tokenizer(model_name_or_path)
97-
98101
self.processor = create_processor(model_name_or_path)
99102

100103
self._reward_fn = reward_fn
@@ -104,8 +107,12 @@ def __init__(
104107
else:
105108
self.jinja_template = get_template(self.template).jinja_template()
106109

107-
self.project_name = project_name
108-
self.run_name = run_name
110+
self.wandb_project_name = wandb_project_name
111+
self.wandb_run_name = wandb_run_name
112+
self.local_cache_dir = local_cache_dir
113+
self.local_run_cache_dir = None
114+
self._initialize_monitor(monitors)
115+
109116
self.streaming_manager = StreamingManager()
110117
if streaming == "console":
111118
self.streaming_manager.add_observer(ConsoleStreamObserver())
@@ -177,6 +184,17 @@ def _preprocess_messages(self, messages: List[Dict]):
177184

178185
return messages_list.to_list()
179186

187+
def _initialize_monitor(self, monitors: List[str]) -> None:
    """Attach a metric sink to the global ``Monitor`` for each requested type.

    Args:
        monitors: Monitor type names. Supported values are ``"local"``
            (JSONL output under ``self.local_cache_dir``) and ``"wandb"``.

    Raises:
        ValueError: If ``local_cache_dir`` is unset when a ``"local"``
            monitor is requested, or if a monitor type is not supported.
    """
    for monitor in monitors:
        if monitor == "local":
            # Was an `assert`, which is silently stripped under `python -O`;
            # raise explicitly so the misconfiguration is always reported.
            if self.local_cache_dir is None:
                raise ValueError("local_cache_dir must be set when using local monitor.")
            # One run directory per (model basename, timestamp) so concurrent
            # or repeated runs never collide. (Dropped the redundant f-string
            # that previously wrapped os.path.join.)
            self.local_run_cache_dir = os.path.join(
                self.local_cache_dir,
                os.path.basename(self.model_name_or_path),
                datetime.now().strftime("%Y%m%d_%H%M%S"),
            )
            # NOTE(review): the trailing "/" suggests JsonlSink takes a
            # directory prefix rather than a file path — confirm against
            # the JsonlSink implementation.
            Monitor.add_sink("jsonl", JsonlSink(f"{self.local_run_cache_dir}/"))
        elif monitor == "wandb":
            Monitor.add_sink("wandb", WandbSink(project=self.wandb_project_name, run_name=self.wandb_run_name))
        else:
            raise ValueError(f"Monitor {monitor} is not supported.")
197+
180198
async def run(self,
181199
messages: Union[List[dict], np.ndarray, Dict],
182200
max_turns: int,
@@ -392,4 +410,4 @@ def get_verl_data_proto(self):
392410
batch = DataProto.from_single_dict(inputs, meta_info={"use_agent": True})
393411

394412
return batch
395-
413+

agentfly/agents/chain/chain_base.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def __init__(self):
137137
self.terminal_status = ["terminal", "finish"]
138138
self.global_step = 0
139139
self.finished_chains_count = 0
140-
self.initialize_monitor()
141140
self.monitor_info = defaultdict(list)
142141

143142
def reset(self) -> None:
@@ -333,7 +332,7 @@ async def _run_single_chain(self,
333332
await done_queue.put((chain_id, chain, current_node))
334333

335334
self.finished_chains_count += 1
336-
self.monitor_chain()
335+
self.monitor_chain(trajectory=current_node.messages.messages)
337336

338337
async def _generate_response(self, current_node, tools, depth, chain_id, enable_streaming):
339338
"""Generate response with optional streaming support."""
@@ -485,7 +484,6 @@ async def _finalize_chain(self, chain_id, chain, current_node, depth):
485484

486485
await self.release_resources(chain_id)
487486

488-
489487
async def release_resources(self, id: str) -> None:
490488
for tool in self.tools:
491489
if isinstance(tool, Tool):
@@ -498,10 +496,6 @@ async def set_tools(self, id: str, env_args: Dict[str, Any]) -> None:
498496
if isinstance(tool, Tool):
499497
await tool.set_env(id, env_args)
500498

501-
def initialize_monitor(self) -> None:
502-
Monitor.add_sink("jsonl", JsonlSink(f"{AGENT_DATA_DIR}/demo_metrics.jsonl"))
503-
Monitor.add_sink("wandb", WandbSink(project=self.project_name, run_name=self.run_name))
504-
505499
def monitor_step(self) -> None:
506500
messages = self.get_messages()
507501
avg_turns = 0
@@ -589,9 +583,19 @@ def monitor_step(self) -> None:
589583
emit(evt)
590584

591585

592-
def monitor_chain(self) -> None:
586+
def monitor_chain(self, trajectory) -> None:
    """Record per-chain metrics and persist the finished trajectory.

    Appends the running finished-chain count and, for every stateful
    pooled tool, its current used-environment size to ``monitor_info``,
    then emits the trajectory as a text metric event.

    Args:
        trajectory: The message list of the finished chain, serialized
            to JSON before being emitted.
    """
    info = self.monitor_info
    info['Agent/chains'].append(self.finished_chains_count)

    stateful_tools = (t for t in self.tools if t.is_stateful and t.pool_size > 0)
    for tool in stateful_tools:
        info[f"Agent/Tool/{tool.name}/used_env_size"].append(tool.used_env_size)

    # Trajectories are routed to the local JSONL sink only; shipping them
    # to wandb would consume too much bandwidth.
    serialized = json.dumps(serialize_for_json(trajectory), indent=2)
    emit(MetricEvent(
        sinks=["jsonl"],
        kind="text",
        name="Agent/rollout/trajectory",
        value=serialized,
        x=self.global_step,
        x_name="Agent/rollout/step",
    ))

0 commit comments

Comments
 (0)