@@ -37,6 +37,7 @@ def create_openenv_vllm_rollout_func(
3737 # Environment configuration
3838 env_client_cls : Optional [Type [Any ]] = None ,
3939 tasks : List [str ] | None = None ,
40+ task_var : Optional [str ] = None ,
4041 miniwob_url : str | None = None ,
4142 docker_image : str = "browsergym-env:latest" ,
4243 env_base_url : Optional [str ] = None ,
@@ -66,52 +67,48 @@ def create_openenv_vllm_rollout_func(
6667 The environment side is configured via ``env_client_cls`` and the BrowserGym
6768 parameters (``tasks``, ``miniwob_url``, ``docker_image``, etc.).
6869 """
69- print (f"\n { '=' * 80 } " , flush = True )
70- print (f "[openenv_trl_vllm] create_openenv_vllm_rollout_func() CALLED" , flush = True )
70+ print (f"\n { '=' * 80 } " , flush = True )
71+ print ("[openenv_trl_vllm] create_openenv_vllm_rollout_func() CALLED" , flush = True )
7172 print (f" vllm_base_url: { vllm_base_url } " , flush = True )
7273 print (f" vllm_model: { vllm_model } " , flush = True )
7374 print (f" tasks: { tasks } " , flush = True )
7475 print (f" max_steps: { max_steps } " , flush = True )
75- print (f"{ '=' * 80 } " , flush = True )
76+ print (f"{ '=' * 80 } " , flush = True )
7677 sys .stdout .flush ()
77-
78+
7879 # Import VLLMPolicy
7980 from eval_protocol .mcp .execution .vllm_policy import VLLMPolicy
8081
8182 # Global-ish task rotation offset across rollout_func calls.
8283 # This lets us rotate tasks between GRPO steps instead of always
8384 # starting from tasks[0] when a new OpenEnvRolloutProcessor is created.
8485 task_cycle_index : int = 0
85-
86+
8687 def rollout_func (prompts : List [str ], trainer ) -> Dict [str , List ]:
8788 """Execute rollouts via OpenEnv + vLLM and return GRPO-compatible results."""
8889 print ("\n [OpenEnvVLLM] rollout_func called" , flush = True )
89-
90+
9091 # Extract args from trainer
9192 args = trainer .args
9293 processing_class = trainer .processing_class
93-
94+
9495 num_generations = getattr (args , "num_generations" , 8 )
9596 print (
96- f"[OpenEnvVLLM] Received { len (prompts )} prompts, "
97- f"{ num_generations } generations each" ,
97+ f"[OpenEnvVLLM] Received { len (prompts )} prompts, { num_generations } generations each" ,
9898 flush = True ,
9999 )
100-
100+
101101 # 1) Build evaluation rows
102102 evaluation_rows : List [EvaluationRow ] = []
103103 for prompt in prompts :
104104 for gen_idx in range (num_generations ):
105- evaluation_rows .append (
106- EvaluationRow (
107- messages = [Message (role = "user" , content = prompt )],
108- input_metadata = InputMetadata (
109- completion_params = {},
110- extra = {"generation_idx" : gen_idx }
111- ),
112- )
105+ row = EvaluationRow (
106+ messages = [Message (role = "user" , content = prompt )],
107+ input_metadata = InputMetadata (completion_params = {}),
113108 )
114-
109+ row .input_metadata .generation_idx = gen_idx # type: ignore[attr-defined]
110+ evaluation_rows .append (row )
111+
115112 # 2) Build processor config with VLLMPolicy
116113 # We'll pass trainer.vllm_client to VLLMPolicy
117114 base_params : Dict [str , Any ] = {
@@ -121,37 +118,33 @@ def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
121118 }
122119 if completion_params :
123120 base_params .update (completion_params )
124-
121+
125122 print (
126- f"[OpenEnvVLLM] Temperature={ base_params ['temperature' ]} , "
127- f"max_tokens={ base_params ['max_tokens' ]} " ,
123+ f"[OpenEnvVLLM] Temperature={ base_params ['temperature' ]} , max_tokens={ base_params ['max_tokens' ]} " ,
128124 flush = True ,
129125 )
130126 print ("[OpenEnvVLLM] Using TRL VLLMClient from trainer" , flush = True )
131-
132- max_concurrency = concurrency if concurrency is not None else getattr (
133- args , "per_device_train_batch_size" , 1
134- )
127+
128+ max_concurrency = concurrency if concurrency is not None else getattr (args , "per_device_train_batch_size" , 1 )
135129 print (
136- f"[OpenEnvVLLM] Max concurrency={ max_concurrency } , "
137- f"max_steps={ max_steps } " ,
130+ f"[OpenEnvVLLM] Max concurrency={ max_concurrency } , max_steps={ max_steps } " ,
138131 flush = True ,
139132 )
140-
133+
141134 config = RolloutProcessorConfig (
142135 completion_params = base_params ,
143136 mcp_config_path = "" ,
144137 semaphore = asyncio .Semaphore (max_concurrency ),
145138 steps = max_steps ,
146139 )
147-
140+
148141 # 3) Execute rollouts with VLLMPolicy
149142 print (
150143 f"[OpenEnvVLLM] Instantiating processor: "
151144 f"{ processor_cls .__name__ if processor_cls else 'OpenEnvRolloutProcessor' } " ,
152145 flush = True ,
153146 )
154-
147+
155148 # Create policy factory that uses trainer's vllm_client
156149 def vllm_policy_factory (model , temperature , max_tokens , base_url = None , ** kwargs ):
157150 """Factory that creates VLLMPolicy using trainer's vllm_client."""
@@ -164,7 +157,7 @@ def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs)
164157 top_k = kwargs .get ("top_k" ),
165158 ** kwargs ,
166159 )
167-
160+
168161 Processor = processor_cls or OpenEnvRolloutProcessor
169162 _kwargs : Dict [str , Any ] = dict (processor_kwargs or {})
170163 _kwargs .setdefault ("env_factory" , env_factory )
@@ -187,6 +180,7 @@ def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs)
187180 flush = True ,
188181 )
189182 _kwargs .setdefault ("tasks" , rotated_tasks )
183+ _kwargs .setdefault ("task_var" , task_var )
190184
191185 _kwargs .setdefault ("miniwob_url" , miniwob_url )
192186 _kwargs .setdefault ("docker_image" , docker_image )
@@ -202,48 +196,49 @@ def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs)
202196 _kwargs .setdefault ("viewport_height" , viewport_height )
203197 _kwargs .setdefault ("timeout_ms" , timeout_ms )
204198 _kwargs .setdefault ("num_generations" , num_generations )
205-
199+
206200 processor = Processor (** _kwargs )
207- print (f "[OpenEnvVLLM] Processor instantiated successfully" , flush = True )
208-
201+ print ("[OpenEnvVLLM] Processor instantiated successfully" , flush = True )
202+
209203 loop = asyncio .new_event_loop ()
210204 asyncio .set_event_loop (loop )
211205 try :
206+
212207 async def _run_all ():
213208 tasks_list = processor (evaluation_rows , config )
214209 return await asyncio .gather (* tasks_list )
215-
210+
216211 completed_rows = loop .run_until_complete (_run_all ())
217212 print (
218213 f"[OpenEnvVLLM] All rollouts completed: { len (completed_rows )} results" ,
219214 flush = True ,
220215 )
221216 finally :
222217 loop .close ()
223-
218+
224219 # 4) Convert to Wordle-style format (no splitting)
225220 # Each completed_row is one rollout with multiple turns
226221 # We .extend() tokens across turns, then .append() per rollout
227222 print (
228223 f"[OpenEnvVLLM] Converting { len (completed_rows )} rollouts to TRL format" ,
229224 flush = True ,
230225 )
231-
226+
232227 tokenizer = getattr (processing_class , "tokenizer" , None ) or processing_class
233228 encode_fn = getattr (tokenizer , "encode" , None )
234-
229+
235230 episode_prompt_ids : List [List [int ]] = []
236231 episode_completion_ids : List [List [int ]] = []
237232 episode_logprobs : List [List [float ]] = []
238233 step_rewards_all : List [List [float ]] = []
239-
234+
240235 for idx , row in enumerate (completed_rows ):
241236 # Accumulate tokens across all turns in this rollout
242237 prompt_ids : List [int ] = [] # .extend() for each turn
243238 completion_ids : List [int ] = [] # .extend() for each turn
244239 logprobs : List [float ] = [] # .extend() for each turn
245240 rewards : List [float ] = []
246-
241+
247242 # Go through all messages and accumulate tokens
248243 for msg in row .messages :
249244 if msg .role == "user" :
@@ -259,50 +254,50 @@ async def _run_all():
259254 content = msg .content or ""
260255 if isinstance (content , str ) and content .startswith ("__ep_step_rewards__:" ):
261256 import json
257+
262258 payload = content .split (":" , 1 )[1 ]
263259 rewards = json .loads (payload ) or []
264260 except Exception :
265261 pass
266-
267- # Fallback for rewards
268- if not rewards and hasattr ( row . execution_metadata , "extra" ) :
262+
263+ # Fallback for rewards (if extra field exists via model_config extra="allow")
264+ if not rewards :
269265 try :
270- rewards = row .execution_metadata .extra .get ("step_rewards" , []) or []
266+ extra = getattr (row .execution_metadata , "extra" , None )
267+ if isinstance (extra , dict ):
268+ rewards = extra .get ("step_rewards" , []) or []
271269 except Exception :
272270 pass
273-
271+
274272 # Append accumulated tokens for this episode
275273 episode_prompt_ids .append (prompt_ids if prompt_ids else [0 ])
276274 episode_completion_ids .append (completion_ids if completion_ids else [0 ])
277275 episode_logprobs .append (logprobs if logprobs else [0.0 ])
278276 step_rewards_all .append (rewards if rewards else [0.0 ])
279-
277+
280278 total_reward = sum (sum (r ) for r in step_rewards_all )
281279 avg_reward = total_reward / len (step_rewards_all ) if step_rewards_all else 0.0
282280 print (
283281 f"[OpenEnvVLLM] Total reward={ total_reward :.2f} , Avg reward={ avg_reward :.2f} " ,
284282 flush = True ,
285283 )
286- print (
287- f"[OpenEnvVLLM] Returning { len (episode_prompt_ids )} episodes" , flush = True
288- )
284+ print (f"[OpenEnvVLLM] Returning { len (episode_prompt_ids )} episodes" , flush = True )
289285 sys .stdout .flush ()
290-
286+
291287 # Return in Wordle format
292288 # Tokens: 2D arrays (accumulate across turns, one list per episode)
293289 # Rewards: 1D arrays (one scalar per episode)
294290 total_rewards = [sum (r ) for r in step_rewards_all ] # Sum step rewards per episode
295-
291+
296292 print (f"[OpenEnvVLLM] Episode rewards: { total_rewards } " , flush = True )
297-
293+
298294 return {
299295 "prompt_ids" : episode_prompt_ids , # List[List[int]] - tokens per episode
300296 "completion_ids" : episode_completion_ids , # List[List[int]] - tokens per episode
301297 "logprobs" : episode_logprobs , # List[List[float]] - logprobs per episode
302298 "step_rewards" : total_rewards , # List[float] - total reward per episode (1D!)
303299 }
304-
300+
305301 print (f"[openenv_trl_vllm] Returning rollout_func (type={ type (rollout_func )} )" , flush = True )
306302 sys .stdout .flush ()
307303 return rollout_func
308-
0 commit comments