11from abc import ABC , abstractmethod
22from collections import defaultdict
3+ from datetime import datetime
34import json
45from .utils .messages import MessagesList
56from ..templates .templates import get_template
6- from ..__init__ import AGENT_DATA_DIR
7+ from .. import AGENT_DATA_DIR
78from .llm_backends import (
89 AsyncVLLMBackend ,
910 AsyncVerlBackend ,
2324import logging
2425from .chain .streaming_observer import ConsoleStreamObserver , StreamingManager
2526from .utils .tokenizer import create_processor , create_tokenizer
27+ from ..utils .monitor import JsonlSink , Monitor , WandbSink
2628try :
2729 from verl .protocol import DataProto
2830except ImportError :
@@ -51,10 +53,12 @@ def __init__(
5153 backend_config : Any = None ,
5254 reward_fn : Callable = None ,
5355 log_file : str = "agent" ,
54- project_name : str = None ,
55- run_name : str = None ,
5656 streaming : str = "console" ,
5757 debug : bool = False ,
58+ monitors : List [str ] = [],
59+ wandb_project_name : str = None ,
60+ wandb_run_name : str = None ,
61+ local_cache_dir : str = None ,
5862 ** kwargs # To pass other unused arguments
5963 ):
6064 """
@@ -94,7 +98,6 @@ def __init__(
9498
9599 # Create appropriate tokenizer for trajectory processing
96100 self .tokenizer = create_tokenizer (model_name_or_path )
97-
98101 self .processor = create_processor (model_name_or_path )
99102
100103 self ._reward_fn = reward_fn
@@ -104,8 +107,12 @@ def __init__(
104107 else :
105108 self .jinja_template = get_template (self .template ).jinja_template ()
106109
107- self .project_name = project_name
108- self .run_name = run_name
110+ self .wandb_project_name = wandb_project_name
111+ self .wandb_run_name = wandb_run_name
112+ self .local_cache_dir = local_cache_dir
113+ self .local_run_cache_dir = None
114+ self ._initialize_monitor (monitors )
115+
109116 self .streaming_manager = StreamingManager ()
110117 if streaming == "console" :
111118 self .streaming_manager .add_observer (ConsoleStreamObserver ())
@@ -177,6 +184,17 @@ def _preprocess_messages(self, messages: List[Dict]):
177184
178185 return messages_list .to_list ()
179186
187+ def _initialize_monitor (self , monitors : List [str ]) -> None :
188+ for monitor in monitors :
189+ if monitor == "local" :
190+ assert self .local_cache_dir is not None , "local_cache_dir must be set when using local monitor."
191+ self .local_run_cache_dir = f"{ os .path .join (self .local_cache_dir , os .path .basename (self .model_name_or_path ), datetime .now ().strftime ('%Y%m%d_%H%M%S' ))} "
192+ Monitor .add_sink ("jsonl" , JsonlSink (f"{ self .local_run_cache_dir } /" ))
193+ elif monitor == "wandb" :
194+ Monitor .add_sink ("wandb" , WandbSink (project = self .wandb_project_name , run_name = self .wandb_run_name ))
195+ else :
196+ raise ValueError (f"Monitor { monitor } is not supported." )
197+
180198 async def run (self ,
181199 messages : Union [List [dict ], np .ndarray , Dict ],
182200 max_turns : int ,
@@ -392,4 +410,4 @@ def get_verl_data_proto(self):
392410 batch = DataProto .from_single_dict (inputs , meta_info = {"use_agent" : True })
393411
394412 return batch
395-
413+
0 commit comments